I've decided to use langchain_google_cloud_sql_pg as a checkpointer for a langgraph application. The checkpointer is initialized like this:
from langchain_google_cloud_sql_pg import PostgresEngine
from langchain_google_cloud_sql_pg import PostgresSaver
from .config import settings
async def checkpointer():
    engine = await PostgresEngine.afrom_instance(
        project_id=settings.cp_project_id,
        region=settings.cp_region,
        instance=settings.cp_instance,
        database=settings.cp_database,
        user=settings.cp_username,
        password=settings.cp_password,
    )
    engine.init_checkpoint_table()
    checkpointer = await PostgresSaver.create(
        engine=engine
    )
    return checkpointer
The graph is then initialized like this:
async def initialize():
    return builder.compile(checkpointer=await checkpointer())
And chatting with the agent is done through a FastAPI endpoint:
response = await graph.ainvoke(input_state, config=config)
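For context, here is how the `config` in that call is built in my app. This is a sketch with a hypothetical helper name; the assumption (which held with `MemorySaver`) is that the `metadata` dict passed in the run config gets merged into each checkpoint's metadata, which is what later allows `filter={"user_id": ...}` to match:

```python
def build_run_config(thread_id: str, user_id: str) -> dict:
    """Hypothetical helper: build the config passed to graph.ainvoke.

    Assumption: langgraph merges the run config's `metadata` into each
    saved checkpoint's metadata, so `user_id` becomes filterable later.
    """
    return {
        "configurable": {"thread_id": thread_id},
        "metadata": {"user_id": user_id},
    }

# Usage inside the endpoint (sketch):
#     config = build_run_config(thread_id, user_id)
#     response = await graph.ainvoke(input_state, config=config)
```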
My application needs to support multiple users, each with their own list of threads. To fetch the list of threads per user, I added an endpoint to my FastAPI application:
@app.get("/threads")
async def get_threads_for_user(
    user_id: str = Depends(get_user_id)
) -> Dict[str, Any]:
    """Retrieve all thread IDs for a given user, ordered by first appearance."""
    checkpoints = graph.checkpointer.alist(config=None, filter={"user_id": user_id})
    threads = []
    seen = set()
    async for checkpoint in checkpoints:
        thread_id = checkpoint.metadata.get("thread_id")
        if thread_id and thread_id not in seen:
            seen.add(thread_id)
            threads.append(thread_id)
    return {
        "data": {
            "thread_ids": threads
        }
    }
This is probably not the most efficient way to get all the thread_ids for a given user, but I didn't find any alternative.
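For what it's worth, since every checkpoint row carries `thread_id` as a column, the deduplication could in principle be pushed into SQL instead of iterating over full checkpoints in Python. This is only a sketch under assumptions: it reuses the table and column names from the generated SQL in the error further down (and the `convert_from` cast from the fix at the end of this post), and you would execute it with whatever Postgres driver you already have:

```python
import json

# Hypothetical direct query against the checkpoints table. GROUP BY plus
# min(checkpoint_id) approximates the "first appearance" ordering of the
# endpoint above, assuming checkpoint_id sorts chronologically (the library
# itself orders by checkpoint_id).
THREADS_FOR_USER_SQL = """
SELECT thread_id
FROM "public"."checkpoints"
WHERE convert_from(metadata, 'UTF8')::jsonb @> $1::jsonb
GROUP BY thread_id
ORDER BY min(checkpoint_id)
"""

def user_filter_param(user_id: str) -> str:
    # Bound as $1: jsonb containment (@>) matches rows whose metadata
    # contains this key/value pair.
    return json.dumps({"user_id": user_id})
```

Usage would be something like `rows = await conn.fetch(THREADS_FOR_USER_SQL, user_filter_param(user_id))` with an asyncpg connection. I haven't benchmarked this, but it avoids deserializing every checkpoint in the application.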
Up until now, I used langgraph's built-in MemorySaver (in memory), and wanted to go a step further with Cloud SQL. The above endpoint worked perfectly with MemorySaver as the checkpointer.
After switching to langchain_google_cloud_sql_pg, I noticed an error when filtering on metadata (i.e., using the filter parameter of the alist method). It comes down to the following error:
sqlalchemy.exc.DBAPIError: (sqlalchemy.dialects.postgresql.asyncpg.Error) <class 'asyncpg.exceptions.InvalidTextRepresentationError'>: invalid input syntax for type json
DETAIL: Token "What" is invalid.
[SQL:
SELECT
thread_id,
checkpoint,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
metadata,
type,
(
SELECT array_agg(array[cw.task_id::text::bytea, cw.channel::bytea, cw.type::bytea, cw.blob] order by cw.task_id, cw.idx)
FROM "public"."checkpoints_writes" cw
where cw.thread_id = c.thread_id
AND cw.checkpoint_ns = c.checkpoint_ns
AND cw.checkpoint_id = c.checkpoint_id
) AS pending_writes,
(
SELECT array_agg(array[cw.type::bytea, cw.blob] order by cw.task_path, cw.task_id, cw.idx)
FROM "public"."checkpoints_writes" cw
WHERE cw.thread_id = c.thread_id
AND cw.checkpoint_ns = c.checkpoint_ns
AND cw.checkpoint_id = c.parent_checkpoint_id
AND cw.channel = '__pregel_tasks'
) AS pending_sends
FROM "public"."checkpoints" c
WHERE encode(metadata,'escape')::jsonb @> $1 ORDER BY checkpoint_id DESC]
[parameters: ('{"user_id": "1234567890"}',)]
I was able to pinpoint this to the last WHERE clause; it is caused by an encoding issue.
In the class AsyncPostgresSaver, there's the method _search_where, which has the following lines:
if filter:
    wheres.append("encode(metadata,'escape')::jsonb @> :metadata ")
    param_values.update({"metadata": f"{json.dumps(filter)}"})
Modifying this part to the lines below fixes the issue, and permits filtering by user_id:
if filter:
    wheres.append("convert_from(metadata,'UTF8')::jsonb @> :metadata ")
    param_values.update({"metadata": f"{json.dumps(filter)}"})
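A plausible explanation for why `encode(metadata, 'escape')` breaks: Postgres's escape format renders non-printable and non-ASCII bytes as backslash-octal sequences (`\342`, `\303`, ...), which are not valid JSON string escapes, whereas `convert_from(metadata, 'UTF8')` decodes the bytea back to the original UTF-8 text. A small Python model of the two behaviours (my rough approximation of the escape format, not the library's code):

```python
import json

def pg_escape(data: bytes) -> str:
    """Rough model of Postgres encode(bytea, 'escape')."""
    out = []
    for b in data:
        if b == 0x5C:               # backslash is doubled
            out.append("\\\\")
        elif 0x20 <= b <= 0x7E:     # printable ASCII passes through
            out.append(chr(b))
        else:                       # everything else becomes \nnn (octal)
            out.append("\\%03o" % b)
    return "".join(out)

# Checkpoint metadata is stored as bytea holding UTF-8 encoded JSON.
metadata = json.dumps(
    {"user_id": "1234567890", "message": "What’s the weather?"},
    ensure_ascii=False,
).encode("utf-8")

decoded = metadata.decode("utf-8")  # what convert_from(metadata, 'UTF8') yields
escaped = pg_escape(metadata)       # curly quote becomes \342\200\231

json.loads(decoded)                 # parses fine
try:
    json.loads(escaped)             # \3 is not a valid JSON escape
except json.JSONDecodeError:
    print("escape-encoded metadata is not valid JSON")
```

Any metadata value containing a non-ASCII character (here a curly apostrophe in a user message) is enough to make the escape-encoded text unparseable as JSON, which matches the error only appearing once real conversation metadata was stored.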
I also tried the filter option with other fields; that likewise failed until I applied the modification above.