Skip to content

SessionIndexing

Michael Bayer edited this page Dec 5, 2025 · 5 revisions

SessionIndexing

This recipe presents a generalized way to "index" objects in memory as they are added to, or loaded into, Sessions, so that they can later be retrieved based on particular criteria. One use case is writing before_flush() event handlers that need to inspect particular subsets of the objects in a Session without a SQL round trip, typically for performance reasons.

The technique is fairly simplistic. It does not account for objects being mutated while in the Session in a way that would change how they are indexed. Handling that would also require intercepting attribute change events and re-indexing the affected object in the target index.

In contrast to this recipe, it is of course vastly simpler to just use the Session normally. You emit a query against the database and correlate its results against what's already in the Session's identity map. The use case here is specifically one of avoiding those round trips.

import weakref
import collections
from sqlalchemy import event
from sqlalchemy.orm import Session


class Index:
    """An in-memory 'index' of objects in sessions.

    Listens for objects being attached to sessions (``after_attach``) or
    loaded from the database (``load``) and indexes them according to a
    series of user-defined "indexing" functions.

    An indexing function is registered under a name with :meth:`indexed`.
    The index for that name is then queried as
    ``indexes.<name>(session, value)``. That call returns the set of
    objects currently in ``session`` whose indexing function returned
    ``value``.

    The per-session index is stored in ``session.info["_index"]``. It holds
    objects weakly, so it never keeps an object alive on its own. Objects
    are indexed once, when attached or loaded. Mutating an indexed attribute
    afterwards does not re-index the object.

    """

    def _index_object(self, session, instance, cls, name, fn):
        """Add ``instance`` to ``session``'s index under ``(name, fn(instance))``.

        Does nothing unless ``instance`` is an instance of ``cls``. The
        attach listener is registered on ``Session`` and fires for every
        attached object, so this method does the class filtering.
        """
        if not isinstance(instance, cls):
            return

        # Created lazily, one per session. WeakSet values mean the index
        # does not extend the lifetime of the objects it tracks.
        _index = session.info.get("_index")
        if _index is None:
            _index = session.info["_index"] = collections.defaultdict(
                weakref.WeakSet
            )

        _index[(name, fn(instance))].add(instance)

    def indexed(self, cls, name):
        """Register the decorated function as the indexing function ``name``.

        The decorated function receives an instance of ``cls`` and returns a
        hashable value (it becomes part of a dict key) to index the instance
        under. Instances are indexed when loaded from query results (the
        ``load`` event, propagated to subclasses of ``cls``). They are also
        indexed when attached to any :class:`.Session` (the ``after_attach``
        event). The decorated function is returned unchanged.
        """

        def decorate(fn):
            # index objects as they are loaded from query results
            @event.listens_for(cls, "load", propagate=True)
            def object_loaded(instance, ctx):
                self._index_object(ctx.session, instance, cls, name, fn)

            # index objects as they are attached to a session (e.g. add())
            @event.listens_for(Session, "after_attach")
            def index_object(session, instance):
                self._index_object(session, instance, cls, name, fn)

            return fn

        return decorate

    def __getattr__(self, name):
        """Return a lookup function for the index registered as ``name``.

        ``indexes.<name>(session, value)`` returns a new ``set`` of the
        objects indexed under ``(name, value)`` that are still present in
        ``session`` (pending or persistent). Objects that were expunged, or
        whose session was closed, are filtered out. Any non-dunder name is
        accepted, so an unregistered name simply yields empty sets.

        Dunder names raise :class:`AttributeError`. This keeps protocol
        probes such as ``copy.deepcopy`` or ``hasattr(obj, "__deepcopy__")``
        from being handed a lookup function.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)

        def go(sess, value):
            by_session = sess.info.get("_index")
            if by_session is None:
                return set()
            # .get() rather than [] so that a miss does not insert an empty
            # WeakSet into the defaultdict
            candidates = by_session.get((name, value), ())
            # Test membership per candidate rather than materializing the
            # whole identity map on every lookup.
            return {obj for obj in candidates if obj in sess}

        go.__name__ = name
        return go


# Module-level registry: declare indexes with @indexes.indexed(cls, name)
# and query them with indexes.<name>(session, value).
indexes = Index()

if __name__ == "__main__":
    # demonstration

    from sqlalchemy import Column, Integer, String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "user"

        id = Column(Integer, primary_key=True)
        name = Column(String)

    class Address(Base):
        __tablename__ = "address"

        id = Column(Integer, primary_key=True)
        name = Column(String)

    @indexes.indexed(User, "user_byname")
    def index_user_byname(obj):
        return obj.name

    @indexes.indexed(Address, "address_byname")
    def index_address_byname(obj):
        return obj.name

    a1, a2, a3 = User(name="a"), User(name="a"), User(name="a")
    b1, b2, b3 = User(name="b"), User(name="b"), User(name="b")
    c1, c2, c3 = User(name="c"), User(name="c"), User(name="c")
    d1, d2, d3 = User(name="d"), User(name="d"), User(name="d")
    e1, e2, e3 = User(name="e"), User(name="e"), User(name="e")

    ad_a, ad_b, ad_c = Address(name="a"), Address(name="b"), Address(name="c")

    # objects added to sessions are indexed via the "after_attach" event
    s1, s2, s3 = Session(), Session(), Session()

    s1.add_all([a1, b1, b2, d2, e3, ad_c])
    s2.add_all([a2, c2, e1, e2, ad_a])
    s3.add_all([b3, c1, d1, d3, ad_b])

    assert indexes.user_byname(s1, "b") == {b1, b2}
    assert indexes.user_byname(s2, "e") == {e1, e2}
    assert indexes.user_byname(s2, "c") == {c2}
    assert indexes.user_byname(s3, "b") == {b3}
    assert indexes.address_byname(s3, "b") == {ad_b}
    assert indexes.address_byname(s3, "c") == set()
    assert indexes.address_byname(s1, "c") == {ad_c}

    # lookups only return objects still present in the session
    s2.expunge(e2)
    assert indexes.user_byname(s2, "e") == {e1}
    s2.close()
    assert indexes.user_byname(s2, "e") == set()

    s1.close()
    s3.close()

    # objects that arrive from the database are indexed via the "load" event
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as s4:
        s4.add_all([User(name="f"), User(name="f"), User(name="g")])
        s4.commit()

    with Session(engine) as s5:
        # Keep strong references to the loaded objects. The index holds
        # them weakly (and so, by default, does the identity map).
        loaded = s5.scalars(select(User)).all()
        expected = {u for u in loaded if u.name == "f"}
        assert len(expected) == 2
        assert indexes.user_byname(s5, "f") == expected

Clone this wiki locally