Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions src/langchain_google_cloud_sql_pg/chat_message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@


async def _aget_messages(
engine: PostgresEngine, session_id: str, table_name: str
engine: PostgresEngine,
session_id: str,
table_name: str,
schema_name: str = "public",
) -> List[BaseMessage]:
"""Retrieve the messages from PostgreSQL."""
query = f"""SELECT data, type FROM "{table_name}" WHERE session_id = :session_id ORDER BY id;"""
query = f"""SELECT data, type FROM "{schema_name}"."{table_name}" WHERE session_id = :session_id ORDER BY id;"""
results = await engine._afetch(query, {"session_id": session_id})
if not results:
return []
Expand All @@ -49,6 +52,7 @@ def __init__(
session_id: str,
table_name: str,
messages: List[BaseMessage],
schema_name: str = "public",
):
"""PostgresChatMessageHistory constructor.

Expand All @@ -58,6 +62,7 @@ def __init__(
session_id (str): Retrieve the table content with this session ID.
table_name (str): Table name that stores the chat message history.
messages (List[BaseMessage]): Messages to store.
schema_name (str, optional): Database schema name of the chat message history table. Defaults to "public".

Raises:
Exception: If constructor is directly called by the user.
Expand All @@ -70,73 +75,80 @@ def __init__(
self.session_id = session_id
self.table_name = table_name
self.messages = messages
self.schema_name = schema_name

@classmethod
async def create(
cls,
engine: PostgresEngine,
session_id: str,
table_name: str,
schema_name: str = "public",
) -> PostgresChatMessageHistory:
"""Create a new PostgresChatMessageHistory instance.

Args:
engine (PostgresEngine): Postgres engine to use.
session_id (str): Retrieve the table content with this session ID.
table_name (str): Table name that stores the chat message history.
schema_name (str, optional): Schema name for the chat message history table. Defaults to "public".

Raises:
IndexError: If the table provided does not contain required schema.

Returns:
PostgresChatMessageHistory: A newly created instance of PostgresChatMessageHistory.
"""
table_schema = await engine._aload_table_schema(table_name)
table_schema = await engine._aload_table_schema(table_name, schema_name)
column_names = table_schema.columns.keys()

required_columns = ["id", "session_id", "data", "type"]

if not (all(x in column_names for x in required_columns)):
raise IndexError(
f"Table '{table_name}' has incorrect schema. Got "
f"Table '{schema_name}'.'{table_name}' has incorrect schema. Got "
f"column names '{column_names}' but required column names "
f"'{required_columns}'.\nPlease create table with following schema:"
f"\nCREATE TABLE {table_name} ("
f"\nCREATE TABLE {schema_name}.{table_name} ("
"\n id INT AUTO_INCREMENT PRIMARY KEY,"
"\n session_id TEXT NOT NULL,"
"\n data JSON NOT NULL,"
"\n type TEXT NOT NULL"
"\n);"
)
messages = await _aget_messages(engine, session_id, table_name)
return cls(cls.__create_key, engine, session_id, table_name, messages)
messages = await _aget_messages(engine, session_id, table_name, schema_name)
return cls(
cls.__create_key, engine, session_id, table_name, messages, schema_name
)

@classmethod
def create_sync(
cls,
engine: PostgresEngine,
session_id: str,
table_name: str,
schema_name: str = "public",
) -> PostgresChatMessageHistory:
"""Create a new PostgresChatMessageHistory instance.

Args:
engine (PostgresEngine): Postgres engine to use.
session_id (str): Retrieve the table content with this session ID.
table_name (str): Table name that stores the chat message history.
schema_name (str, optional): Database schema name for the chat message history table. Defaults to "public".

Raises:
IndexError: If the table provided does not contain required schema.

Returns:
PostgresChatMessageHistory: A newly created instance of PostgresChatMessageHistory.
"""
coro = cls.create(engine, session_id, table_name)
coro = cls.create(engine, session_id, table_name, schema_name)
return engine._run_as_sync(coro)

async def aadd_message(self, message: BaseMessage) -> None:
"""Append the message to the record in PostgreSQL"""
query = f"""INSERT INTO "{self.table_name}"(session_id, data, type)
query = f"""INSERT INTO "{self.schema_name}"."{self.table_name}"(session_id, data, type)
VALUES (:session_id, :data, :type);
"""
await self.engine._aexecute(
Expand All @@ -148,7 +160,7 @@ async def aadd_message(self, message: BaseMessage) -> None:
},
)
self.messages = await _aget_messages(
self.engine, self.session_id, self.table_name
self.engine, self.session_id, self.table_name, self.schema_name
)

def add_message(self, message: BaseMessage) -> None:
Expand All @@ -166,7 +178,7 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None:

async def aclear(self) -> None:
"""Clear session memory from PostgreSQL"""
query = f"""DELETE FROM "{self.table_name}" WHERE session_id = :session_id;"""
query = f"""DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id;"""
await self.engine._aexecute(query, {"session_id": self.session_id})
self.messages = []

Expand All @@ -177,7 +189,7 @@ def clear(self) -> None:
async def async_messages(self) -> None:
"""Retrieve the messages from Postgres."""
self.messages = await _aget_messages(
self.engine, self.session_id, self.table_name
self.engine, self.session_id, self.table_name, self.schema_name
)

def sync_messages(self) -> None:
Expand Down
48 changes: 38 additions & 10 deletions src/langchain_google_cloud_sql_pg/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ async def ainit_vectorstore_table(
self,
table_name: str,
vector_size: int,
schema_name: str = "public",
content_column: str = "content",
embedding_column: str = "embedding",
metadata_columns: List[Column] = [],
Expand All @@ -387,6 +388,8 @@ async def ainit_vectorstore_table(
Args:
table_name (str): The Postgres database table name.
vector_size (int): Vector size for the embedding model to be used.
schema_name (str): Name of the database schema in which the table is created.
Default: "public".
content_column (str): Name of the column to store document content.
Default: "content".
embedding_column (str) : Name of the column to store vector embeddings.
Expand All @@ -407,9 +410,9 @@ async def ainit_vectorstore_table(
await self._aexecute("CREATE EXTENSION IF NOT EXISTS vector")

if overwrite_existing:
await self._aexecute(f'DROP TABLE IF EXISTS "{table_name}"')
await self._aexecute(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"')

query = f"""CREATE TABLE "{table_name}"(
query = f"""CREATE TABLE "{schema_name}"."{table_name}"(
"{id_column}" UUID PRIMARY KEY,
"{content_column}" TEXT NOT NULL,
"{embedding_column}" vector({vector_size}) NOT NULL"""
Expand All @@ -426,6 +429,7 @@ def init_vectorstore_table(
self,
table_name: str,
vector_size: int,
schema_name: str = "public",
content_column: str = "content",
embedding_column: str = "embedding",
metadata_columns: List[Column] = [],
Expand All @@ -440,6 +444,8 @@ def init_vectorstore_table(
Args:
table_name (str): The Postgres database table name.
vector_size (int): Vector size for the embedding model to be used.
schema_name (str): Name of the database schema in which the table is created.
Default: "public".
content_column (str): Name of the column to store document content.
Default: "content".
embedding_column (str) : Name of the column to store vector embeddings.
Expand All @@ -458,6 +464,7 @@ def init_vectorstore_table(
self.ainit_vectorstore_table(
table_name,
vector_size,
schema_name,
content_column,
embedding_column,
metadata_columns,
Expand All @@ -468,41 +475,51 @@ def init_vectorstore_table(
)
)

async def ainit_chat_history_table(self, table_name: str) -> None:
async def ainit_chat_history_table(
self, table_name: str, schema_name: str = "public"
) -> None:
"""Create a Cloud SQL table to store chat history.

Args:
table_name (str): Table name to store chat history.
schema_name (str): Name of the database schema in which the chat history table is created.
Default: "public".

Returns:
None
"""
create_table_query = f"""CREATE TABLE IF NOT EXISTS "{table_name}"(
create_table_query = f"""CREATE TABLE IF NOT EXISTS "{schema_name}"."{table_name}"(
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL,
data JSONB NOT NULL,
type TEXT NOT NULL
);"""
await self._aexecute(create_table_query)

def init_chat_history_table(self, table_name: str) -> None:
def init_chat_history_table(
self, table_name: str, schema_name: str = "public"
) -> None:
"""Create a Cloud SQL table to store chat history.

Args:
table_name (str): Table name to store chat history.
schema_name (str): Name of the database schema in which the chat history table is created.
Default: "public".

Returns:
None
"""
return self._run_as_sync(
self.ainit_chat_history_table(
table_name,
schema_name,
)
)

async def ainit_document_table(
self,
table_name: str,
schema_name: str = "public",
content_column: str = "page_content",
metadata_columns: List[Column] = [],
metadata_json_column: str = "langchain_metadata",
Expand All @@ -513,6 +530,8 @@ async def ainit_document_table(

Args:
table_name (str): The PgSQL database table name.
schema_name (str): Name of the database schema in which the table is created.
Default: "public".
content_column (str): Name of the column to store document content.
Default: "page_content".
metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns
Expand All @@ -526,7 +545,7 @@ async def ainit_document_table(
:class:`DuplicateTableError <asyncpg.exceptions.DuplicateTableError>`: if table already exists.
"""

query = f"""CREATE TABLE "{table_name}"(
query = f"""CREATE TABLE "{schema_name}"."{table_name}"(
{content_column} TEXT NOT NULL
"""
for column in metadata_columns:
Expand All @@ -542,6 +561,7 @@ async def ainit_document_table(
def init_document_table(
self,
table_name: str,
schema_name: str = "public",
content_column: str = "page_content",
metadata_columns: List[Column] = [],
metadata_json_column: str = "langchain_metadata",
Expand All @@ -552,6 +572,8 @@ def init_document_table(

Args:
table_name (str): The PgSQL database table name.
schema_name (str): Name of the database schema in which the table is created.
Default: "public".
content_column (str): Name of the column to store document content.
metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns
to create for custom metadata. Optional.
Expand All @@ -561,6 +583,7 @@ def init_document_table(
return self._run_as_sync(
self.ainit_document_table(
table_name,
schema_name,
content_column,
metadata_columns,
metadata_json_column,
Expand All @@ -571,6 +594,7 @@ def init_document_table(
async def _aload_table_schema(
self,
table_name: str,
schema_name: str = "public",
) -> Table:
"""
Load table schema from existing table in PgSQL database.
Expand All @@ -580,11 +604,15 @@ async def _aload_table_schema(
metadata = MetaData()
async with self._engine.connect() as conn:
try:
await conn.run_sync(metadata.reflect, only=[table_name])
await conn.run_sync(
metadata.reflect, schema=schema_name, only=[table_name]
)
except InvalidRequestError as e:
raise ValueError(f"Table, {table_name}, does not exist: " + str(e))
raise ValueError(
f"Table, '{schema_name}'.'{table_name}', does not exist: " + str(e)
)

table = Table(table_name, metadata)
table = Table(table_name, metadata, schema=schema_name)
# Extract the schema information
schema = []
for column in table.columns:
Expand All @@ -597,4 +625,4 @@ async def _aload_table_schema(
}
)

return metadata.tables[table_name]
return metadata.tables[f"{schema_name}.{table_name}"]
Loading