
Method _search_where from AsyncPostgresSaver doesn't encode the filters correctly #294

@svelezdevilla

Description

I've decided to use langchain_google_cloud_sql_pg as a checkpointer for a langgraph application. The checkpointer is initialized like this:

from langchain_google_cloud_sql_pg import PostgresEngine
from langchain_google_cloud_sql_pg import PostgresSaver

from .config import settings

async def checkpointer():
    engine = await PostgresEngine.afrom_instance(
        project_id=settings.cp_project_id,
        region=settings.cp_region,
        instance=settings.cp_instance,
        database=settings.cp_database,
        user=settings.cp_username,
        password=settings.cp_password
    )
    engine.init_checkpoint_table()
    checkpointer = await PostgresSaver.create(
        engine=engine
    )

    return checkpointer

The graph is then initialized like this:

async def initialize():
    return builder.compile(checkpointer=await checkpointer())

And chatting with the agent is done through a FastAPI endpoint:

response = await graph.ainvoke(input_state, config=config)
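For context, user_id only shows up in the checkpoint metadata because I pass it through the run config; as far as I understand, langgraph merges the config's metadata dict into each checkpoint's metadata. A sketch with placeholder values (the exact keys come from my application, not the library):

```python
# Sketch of the run config; thread_id/user_id values are placeholders.
# Passing "metadata" in the config is what makes user_id appear in each
# checkpoint's metadata, so the alist(filter=...) call below can match it.
user_id = "1234567890"
thread_id = "thread-1"

config = {
    "configurable": {"thread_id": thread_id},
    "metadata": {"user_id": user_id, "thread_id": thread_id},
}
```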

My application needs to support multiple users, each with their own list of threads. To fetch the list of threads per user, I added an endpoint to my FastAPI application:

@app.get("/threads")
async def get_threads_for_user(
    user_id: str = Depends(get_user_id)
) -> Dict[str, Any]:
    """Retrieve all thread IDs for a given user, ordered by first appearance."""
    checkpoints = graph.checkpointer.alist(config=None, filter={"user_id": user_id})

    threads = []
    seen = set()

    async for checkpoint in checkpoints:
        thread_id = checkpoint.metadata.get("thread_id")
        if thread_id and thread_id not in seen:
            seen.add(thread_id)
            threads.append(thread_id)

    return {
        "data": {
            "thread_ids": threads
        }
    }

This is probably not the most efficient way to get all the thread_ids for a given user, but I didn't find any alternative.
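One alternative would be to query the checkpoints table directly and let PostgreSQL do the deduplication. This is only a sketch, not part of the library's API: it assumes the default "public"."checkpoints" table with metadata stored as bytea-encoded JSON, and it uses the convert_from form rather than encode for the reasons described below.

```python
import json

# Hypothetical direct query for the per-user thread list, ordered by each
# thread's earliest checkpoint. Assumes the default checkpoints table.
THREADS_SQL = """
SELECT thread_id, MIN(checkpoint_id) AS first_seen
FROM "public"."checkpoints"
WHERE convert_from(metadata, 'UTF8')::jsonb @> $1
GROUP BY thread_id
ORDER BY first_seen
"""

def build_filter_param(user_id: str) -> str:
    # Same serialization the saver uses for the containment filter.
    return json.dumps({"user_id": user_id})
```

With asyncpg, for example, this would be `rows = await conn.fetch(THREADS_SQL, build_filter_param(user_id))`.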

Until now, I had used langgraph's built-in MemorySaver (in memory) and wanted to go a step further with Cloud SQL. The endpoint above worked perfectly with MemorySaver as the checkpointer.

After switching to langchain_google_cloud_sql_pg, I noticed an error when filtering on metadata (i.e., when using the filter parameter of the alist method). It comes down to the following error:

sqlalchemy.exc.DBAPIError: (sqlalchemy.dialects.postgresql.asyncpg.Error) <class 'asyncpg.exceptions.InvalidTextRepresentationError'>: invalid input syntax for type json
DETAIL:  Token "What" is invalid.
[SQL: 
        SELECT
            thread_id,
            checkpoint,
            checkpoint_ns,
            checkpoint_id,
            parent_checkpoint_id,
            metadata,
            type,
            (
                SELECT array_agg(array[cw.task_id::text::bytea, cw.channel::bytea, cw.type::bytea, cw.blob] order by cw.task_id, cw.idx)
                FROM "public"."checkpoints_writes" cw
                where cw.thread_id = c.thread_id
                    AND cw.checkpoint_ns = c.checkpoint_ns
                    AND cw.checkpoint_id = c.checkpoint_id
            ) AS pending_writes,
            (
                SELECT array_agg(array[cw.type::bytea, cw.blob] order by cw.task_path, cw.task_id, cw.idx)
                FROM "public"."checkpoints_writes" cw
                WHERE cw.thread_id = c.thread_id
                    AND cw.checkpoint_ns = c.checkpoint_ns
                    AND cw.checkpoint_id = c.parent_checkpoint_id
                    AND cw.channel = '__pregel_tasks'
            ) AS pending_sends
        FROM "public"."checkpoints" c
        WHERE encode(metadata,'escape')::jsonb @> $1  ORDER BY checkpoint_id DESC]
[parameters: ('{"user_id": "1234567890"}',)]

I was able to pinpoint the failure to the last WHERE clause, which is caused by an encoding issue.

In the class AsyncPostgresSaver, there's the method _search_where, which has the following lines:

        if filter:
            wheres.append("encode(metadata,'escape')::jsonb @> :metadata ")
            param_values.update({"metadata": f"{json.dumps(filter)}"})

Modifying this part to the lines below fixes the issue, and permits filtering by user_id:

        if filter:
            wheres.append("convert_from(metadata,'UTF8')::jsonb @> :metadata ")
            param_values.update({"metadata": f"{json.dumps(filter)}"})

I also tried the filter option with other metadata fields; none of them worked until I applied the modification above.
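The root cause can be reproduced outside PostgreSQL. encode(bytea, 'escape') renders any byte outside printable ASCII as a backslash-octal sequence (e.g. \302), which is not a valid JSON escape, so the subsequent ::jsonb cast fails as soon as the metadata contains a newline or a non-ASCII character. convert_from(bytea, 'UTF8') simply decodes the bytes as text, leaving the JSON intact. A minimal Python simulation of the two expressions (pg_escape below mimics PostgreSQL's escape output; it is illustrative, not a library function):

```python
import json

def pg_escape(data: bytes) -> str:
    """Mimic PostgreSQL's encode(bytea, 'escape'): printable ASCII passes
    through, backslash doubles, everything else becomes \\nnn octal."""
    out = []
    for b in data:
        if b == 0x5C:            # backslash
            out.append("\\\\")
        elif 0x20 <= b <= 0x7E:  # printable ASCII
            out.append(chr(b))
        else:
            out.append("\\%03o" % b)
    return "".join(out)

# Metadata containing a non-ASCII character, serialized as stored bytes.
metadata = json.dumps(
    {"user_id": "1234567890", "question": "¿What now?"}, ensure_ascii=False
).encode("utf-8")

escaped = pg_escape(metadata)       # what encode(metadata, 'escape') yields
decoded = metadata.decode("utf-8")  # what convert_from(metadata, 'UTF8') yields

json.loads(decoded)                 # parses fine
try:
    json.loads(escaped)             # "\302\277" is not a valid JSON escape
    parse_failed = False
except json.JSONDecodeError:
    parse_failed = True
```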

Metadata

Labels:
- api: cloudsql-postgres (issues related to the googleapis/langchain-google-cloud-sql-pg-python API)
- priority: p1 (important issue which blocks shipping the next release; will be fixed prior to the next release)
- type: bug (error or flaw in code with unintended results or allowing sub-optimal usage patterns)
