diff --git a/CHANGELOG.md b/CHANGELOG.md index 5539b0ee..b62c7d1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [0.9.0](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/compare/v0.8.0...v0.9.0) (2024-09-05) + + +### Features + +* Add support for custom schema names ([#191](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/191)) ([1e0566a](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/1e0566af98bf24c711315a791336ba212d240acd)) + ## [0.8.0](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/compare/v0.7.0...v0.8.0) (2024-09-04) diff --git a/src/langchain_google_cloud_sql_pg/chat_message_history.py b/src/langchain_google_cloud_sql_pg/chat_message_history.py index 4ce9f5f0..0150fa63 100644 --- a/src/langchain_google_cloud_sql_pg/chat_message_history.py +++ b/src/langchain_google_cloud_sql_pg/chat_message_history.py @@ -24,10 +24,13 @@ async def _aget_messages( - engine: PostgresEngine, session_id: str, table_name: str + engine: PostgresEngine, + session_id: str, + table_name: str, + schema_name: str = "public", ) -> List[BaseMessage]: """Retrieve the messages from PostgreSQL.""" - query = f"""SELECT data, type FROM "{table_name}" WHERE session_id = :session_id ORDER BY id;""" + query = f"""SELECT data, type FROM "{schema_name}"."{table_name}" WHERE session_id = :session_id ORDER BY id;""" results = await engine._afetch(query, {"session_id": session_id}) if not results: return [] @@ -49,6 +52,7 @@ def __init__( session_id: str, table_name: str, messages: List[BaseMessage], + schema_name: str = "public", ): """PostgresChatMessageHistory constructor. @@ -58,6 +62,7 @@ def __init__( session_id (str): Retrieve the table content with this session ID. table_name (str): Table name that stores the chat message history. messages (List[BaseMessage]): Messages to store. + schema_name (str, optional): Database schema name of the chat message history table. Defaults to "public". Raises: Exception: If constructor is directly called by the user. @@ -70,6 +75,7 @@ def __init__( self.session_id = session_id self.table_name = table_name self.messages = messages + self.schema_name = schema_name @classmethod async def create( @@ -77,6 +83,7 @@ async def create( engine: PostgresEngine, session_id: str, table_name: str, + schema_name: str = "public", ) -> PostgresChatMessageHistory: """Create a new PostgresChatMessageHistory instance. @@ -84,6 +91,7 @@ async def create( engine (PostgresEngine): Postgres engine to use. session_id (str): Retrieve the table content with this session ID. table_name (str): Table name that stores the chat message history. + schema_name (str, optional): Schema name for the chat message history table. Defaults to "public". Raises: IndexError: If the table provided does not contain required schema. @@ -91,25 +99,27 @@ async def create( Returns: PostgresChatMessageHistory: A newly created instance of PostgresChatMessageHistory. """ - table_schema = await engine._aload_table_schema(table_name) + table_schema = await engine._aload_table_schema(table_name, schema_name) column_names = table_schema.columns.keys() required_columns = ["id", "session_id", "data", "type"] if not (all(x in column_names for x in required_columns)): raise IndexError( - f"Table '{table_name}' has incorrect schema. Got " + f"Table '{schema_name}'.'{table_name}' has incorrect schema. Got " f"column names '{column_names}' but required column names " f"'{required_columns}'.\nPlease create table with following schema:" - f"\nCREATE TABLE {table_name} (" + f"\nCREATE TABLE {schema_name}.{table_name} (" "\n id INT AUTO_INCREMENT PRIMARY KEY," "\n session_id TEXT NOT NULL," "\n data JSON NOT NULL," "\n type TEXT NOT NULL" "\n);" ) - messages = await _aget_messages(engine, session_id, table_name) - return cls(cls.__create_key, engine, session_id, table_name, messages) + messages = await _aget_messages(engine, session_id, table_name, schema_name) + return cls( + cls.__create_key, engine, session_id, table_name, messages, schema_name + ) @classmethod def create_sync( @@ -117,6 +127,7 @@ def create_sync( engine: PostgresEngine, session_id: str, table_name: str, + schema_name: str = "public", ) -> PostgresChatMessageHistory: """Create a new PostgresChatMessageHistory instance. @@ -124,6 +135,7 @@ def create_sync( engine (PostgresEngine): Postgres engine to use. session_id (str): Retrieve the table content with this session ID. table_name (str): Table name that stores the chat message history. + schema_name (str, optional): Database schema name for the chat message history table. Defaults to "public". Raises: IndexError: If the table provided does not contain required schema. @@ -131,12 +143,12 @@ def create_sync( Returns: PostgresChatMessageHistory: A newly created instance of PostgresChatMessageHistory. """ - coro = cls.create(engine, session_id, table_name) + coro = cls.create(engine, session_id, table_name, schema_name) return engine._run_as_sync(coro) async def aadd_message(self, message: BaseMessage) -> None: """Append the message to the record in PostgreSQL""" - query = f"""INSERT INTO "{self.table_name}"(session_id, data, type) + query = f"""INSERT INTO "{self.schema_name}"."{self.table_name}"(session_id, data, type) VALUES (:session_id, :data, :type); """ await self.engine._aexecute( @@ -148,7 +160,7 @@ async def aadd_message(self, message: BaseMessage) -> None: }, ) self.messages = await _aget_messages( - self.engine, self.session_id, self.table_name + self.engine, self.session_id, self.table_name, self.schema_name ) def add_message(self, message: BaseMessage) -> None: @@ -166,7 +178,7 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None: async def aclear(self) -> None: """Clear session memory from PostgreSQL""" - query = f"""DELETE FROM "{self.table_name}" WHERE session_id = :session_id;""" + query = f"""DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id;""" await self.engine._aexecute(query, {"session_id": self.session_id}) self.messages = [] @@ -177,7 +189,7 @@ def clear(self) -> None: async def async_messages(self) -> None: """Retrieve the messages from Postgres.""" self.messages = await _aget_messages( - self.engine, self.session_id, self.table_name + self.engine, self.session_id, self.table_name, self.schema_name ) def sync_messages(self) -> None: diff --git a/src/langchain_google_cloud_sql_pg/engine.py b/src/langchain_google_cloud_sql_pg/engine.py index da284af9..dc3137cd 100644 --- a/src/langchain_google_cloud_sql_pg/engine.py +++ b/src/langchain_google_cloud_sql_pg/engine.py @@ -373,6 +373,7 @@ async def ainit_vectorstore_table( self, table_name: str, vector_size: int, + schema_name: str = "public", content_column: str = "content", embedding_column: str = "embedding", metadata_columns: List[Column] = [], @@ -387,6 +388,8 @@ async def ainit_vectorstore_table( Args: table_name (str): The Postgres database table name. vector_size (int): Vector size for the embedding model to be used. + schema_name (str): The schema name to store Postgres database table. + Default: "public". content_column (str): Name of the column to store document content. Default: "page_content". embedding_column (str) : Name of the column to store vector embeddings. @@ -407,9 +410,9 @@ async def ainit_vectorstore_table( await self._aexecute("CREATE EXTENSION IF NOT EXISTS vector") if overwrite_existing: - await self._aexecute(f'DROP TABLE IF EXISTS "{table_name}"') + await self._aexecute(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"') - query = f"""CREATE TABLE "{table_name}"( + query = f"""CREATE TABLE "{schema_name}"."{table_name}"( "{id_column}" UUID PRIMARY KEY, "{content_column}" TEXT NOT NULL, "{embedding_column}" vector({vector_size}) NOT NULL""" @@ -426,6 +429,7 @@ def init_vectorstore_table( self, table_name: str, vector_size: int, + schema_name: str = "public", content_column: str = "content", embedding_column: str = "embedding", metadata_columns: List[Column] = [], @@ -440,6 +444,8 @@ def init_vectorstore_table( Args: table_name (str): The Postgres database table name. vector_size (int): Vector size for the embedding model to be used. + schema_name (str): The schema name to store Postgres database table. + Default: "public". content_column (str): Name of the column to store document content. Default: "page_content". embedding_column (str) : Name of the column to store vector embeddings. @@ -458,6 +464,7 @@ def init_vectorstore_table( self.ainit_vectorstore_table( table_name, vector_size, + schema_name, content_column, embedding_column, metadata_columns, @@ -468,16 +475,20 @@ def init_vectorstore_table( ) ) - async def ainit_chat_history_table(self, table_name: str) -> None: + async def ainit_chat_history_table( + self, table_name: str, schema_name: str = "public" + ) -> None: """Create a Cloud SQL table to store chat history. Args: table_name (str): Table name to store chat history. + schema_name (str): Schema name to store chat history table. + Default: "public". Returns: None """ - create_table_query = f"""CREATE TABLE IF NOT EXISTS "{table_name}"( + create_table_query = f"""CREATE TABLE IF NOT EXISTS "{schema_name}"."{table_name}"( id SERIAL PRIMARY KEY, session_id TEXT NOT NULL, data JSONB NOT NULL, @@ -485,11 +496,15 @@ async def ainit_chat_history_table(self, table_name: str) -> None: );""" await self._aexecute(create_table_query) - def init_chat_history_table(self, table_name: str) -> None: + def init_chat_history_table( + self, table_name: str, schema_name: str = "public" + ) -> None: """Create a Cloud SQL table to store chat history. Args: table_name (str): Table name to store chat history. + schema_name (str): Schema name to store chat history table. + Default: "public". Returns: None @@ -497,12 +512,14 @@ def init_chat_history_table(self, table_name: str) -> None: return self._run_as_sync( self.ainit_chat_history_table( table_name, + schema_name, ) ) async def ainit_document_table( self, table_name: str, + schema_name: str = "public", content_column: str = "page_content", metadata_columns: List[Column] = [], metadata_json_column: str = "langchain_metadata", @@ -513,6 +530,8 @@ async def ainit_document_table( Args: table_name (str): The PgSQL database table name. + schema_name (str): The schema name to store PgSQL database table. + Default: "public". content_column (str): Name of the column to store document content. Default: "page_content". metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns @@ -526,7 +545,7 @@ async def ainit_document_table( :class:`DuplicateTableError `: if table already exists. """ - query = f"""CREATE TABLE "{table_name}"( + query = f"""CREATE TABLE "{schema_name}"."{table_name}"( {content_column} TEXT NOT NULL """ for column in metadata_columns: @@ -542,6 +561,7 @@ async def ainit_document_table( def init_document_table( self, table_name: str, + schema_name: str = "public", content_column: str = "page_content", metadata_columns: List[Column] = [], metadata_json_column: str = "langchain_metadata", @@ -552,6 +572,8 @@ def init_document_table( Args: table_name (str): The PgSQL database table name. + schema_name (str): The schema name to store PgSQL database table. + Default: "public". content_column (str): Name of the column to store document content. metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns to create for custom metadata. Optional. @@ -561,6 +583,7 @@ def init_document_table( return self._run_as_sync( self.ainit_document_table( table_name, + schema_name, content_column, metadata_columns, metadata_json_column, @@ -571,6 +594,7 @@ def init_document_table( async def _aload_table_schema( self, table_name: str, + schema_name: str = "public", ) -> Table: """ Load table schema from existing table in PgSQL database. @@ -580,11 +604,15 @@ async def _aload_table_schema( metadata = MetaData() async with self._engine.connect() as conn: try: - await conn.run_sync(metadata.reflect, only=[table_name]) + await conn.run_sync( + metadata.reflect, schema=schema_name, only=[table_name] + ) except InvalidRequestError as e: - raise ValueError(f"Table, {table_name}, does not exist: " + str(e)) + raise ValueError( + f"Table, '{schema_name}'.'{table_name}', does not exist: " + str(e) + ) - table = Table(table_name, metadata) + table = Table(table_name, metadata, schema=schema_name) # Extract the schema information schema = [] for column in table.columns: @@ -597,4 +625,4 @@ async def _aload_table_schema( } ) - return metadata.tables[table_name] + return metadata.tables[f"{schema_name}.{table_name}"] diff --git a/src/langchain_google_cloud_sql_pg/loader.py b/src/langchain_google_cloud_sql_pg/loader.py index 92dd7941..39ee8935 100644 --- a/src/langchain_google_cloud_sql_pg/loader.py +++ b/src/langchain_google_cloud_sql_pg/loader.py @@ -157,6 +157,7 @@ async def create( engine: PostgresEngine, query: Optional[str] = None, table_name: Optional[str] = None, + schema_name: str = "public", content_columns: Optional[List[str]] = None, metadata_columns: Optional[List[str]] = None, metadata_json_column: Optional[str] = None, @@ -169,6 +170,7 @@ async def create( engine (PostgresEngine):AsyncEngine with pool connection to the postgres database query (Optional[str], optional): SQL query. Defaults to None. table_name (Optional[str], optional): Name of table to query. Defaults to None. + schema_name (str, optional): Database schema name of the table. Defaults to "public". content_columns (Optional[List[str]], optional): Column that represent a Document's page_content. Defaults to the first column. metadata_columns (Optional[List[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "langchain_metadata". @@ -201,7 +203,7 @@ async def create( formatter = text_formatter if not query: - query = f'SELECT * FROM "{table_name}"' + query = f'SELECT * FROM "{schema_name}"."{table_name}"' stmt = sqlalchemy.text(query) async with engine._engine.connect() as connection: @@ -250,6 +252,7 @@ def create_sync( engine: PostgresEngine, query: Optional[str] = None, table_name: Optional[str] = None, + schema_name: str = "public", content_columns: Optional[List[str]] = None, metadata_columns: Optional[List[str]] = None, metadata_json_column: Optional[str] = None, @@ -262,6 +265,7 @@ def create_sync( engine (PostgresEngine):AsyncEngine with pool connection to the postgres database query (Optional[str], optional): SQL query. Defaults to None. table_name (Optional[str], optional): Name of table to query. Defaults to None. + schema_name (str, optional): Database schema name of the table. Defaults to "public". content_columns (Optional[List[str]], optional): Column that represent a Document's page_content. Defaults to the first column. metadata_columns (Optional[List[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "langchain_metadata". @@ -275,6 +279,7 @@ def create_sync( engine, query, table_name, + schema_name, content_columns, metadata_columns, metadata_json_column, @@ -344,6 +349,7 @@ def __init__( engine: PostgresEngine, table_name: str, content_column: str, + schema_name: str = "public", metadata_columns: List[str] = [], metadata_json_column: Optional[str] = None, ): @@ -354,6 +360,7 @@ def __init__( engine (PostgresEngine): AsyncEngine with pool connection to the postgres database table_name (Optional[str], optional): Name of table to query. Defaults to None. content_columns (Optional[List[str]], optional): Column that represent a Document's page_content. Defaults to the first column. + schema_name (str, optional): Database schema name of the table. Defaults to "public". metadata_columns (Optional[List[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "langchain_metadata". @@ -367,6 +374,7 @@ def __init__( self.engine = engine self.table_name = table_name self.content_column = content_column + self.schema_name = schema_name self.metadata_columns = metadata_columns self.metadata_json_column = metadata_json_column @@ -375,6 +383,7 @@ async def create( cls, engine: PostgresEngine, table_name: str, + schema_name: str = "public", content_column: str = DEFAULT_CONTENT_COL, metadata_columns: List[str] = [], metadata_json_column: Optional[str] = DEFAULT_METADATA_COL, @@ -384,6 +393,7 @@ async def create( Args: engine (PostgresEngine):AsyncEngine with pool connection to the postgres database table_name (Optional[str], optional): Name of table to query. Defaults to None. + schema_name (str, optional): Database schema name of the table. Defaults to "public". content_columns (Optional[List[str]], optional): Column that represent a Document's page_content. Defaults to the first column. metadata_columns (Optional[List[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "langchain_metadata". @@ -391,7 +401,7 @@ async def create( Returns: PostgresDocumentSaver """ - table_schema = await engine._aload_table_schema(table_name) + table_schema = await engine._aload_table_schema(table_name, schema_name) column_names = table_schema.columns.keys() if content_column not in column_names: raise ValueError(f"Content column, {content_column}, does not exist.") @@ -423,6 +433,7 @@ async def create( engine, table_name, content_column, + schema_name, metadata_columns, metadata_json_column, ) @@ -432,6 +443,7 @@ def create_sync( cls, engine: PostgresEngine, table_name: str, + schema_name: str = "public", content_column: str = DEFAULT_CONTENT_COL, metadata_columns: List[str] = [], metadata_json_column: str = DEFAULT_METADATA_COL, @@ -441,6 +453,7 @@ def create_sync( Args: engine (PostgresEngine):AsyncEngine with pool connection to the postgres database table_name (Optional[str], optional): Name of table to query. Defaults to None. + schema_name (str, optional): Database schema name of the table. Defaults to "public". content_columns (Optional[List[str]], optional): Column that represent a Document's page_content. Defaults to the first column. metadata_columns (Optional[List[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "langchain_metadata". @@ -451,6 +464,7 @@ def create_sync( coro = cls.create( engine, table_name, + schema_name, content_column, metadata_columns, metadata_json_column, @@ -478,7 +492,7 @@ async def aadd_documents(self, docs: List[Document]) -> None: row[key] = json.dumps(value) # Create list of column names - insert_stmt = f'INSERT INTO "{self.table_name}"({self.content_column}' + insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"({self.content_column}' values_stmt = f"VALUES (:{self.content_column}" # Add metadata @@ -536,7 +550,7 @@ async def adelete(self, docs: List[Document]) -> None: where_conditions_list.append(f"{key} = :{key}") where_conditions = " AND ".join(where_conditions_list) - stmt = f'DELETE FROM "{self.table_name}" WHERE {where_conditions};' + stmt = f'DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE {where_conditions};' values = {} for key, value in row.items(): if type(value) is int: diff --git a/src/langchain_google_cloud_sql_pg/vectorstore.py b/src/langchain_google_cloud_sql_pg/vectorstore.py index a4a0b53a..964e9df9 100644 --- a/src/langchain_google_cloud_sql_pg/vectorstore.py +++ b/src/langchain_google_cloud_sql_pg/vectorstore.py @@ -47,6 +47,7 @@ def __init__( engine: PostgresEngine, embedding_service: Embeddings, table_name: str, + schema_name: str = "public", content_column: str = "content", embedding_column: str = "embedding", metadata_columns: List[str] = [], @@ -64,6 +65,7 @@ def __init__( engine (PostgresEngine): Connection pool engine for managing connections to Postgres database. embedding_service (Embeddings): Text embedding model to use. table_name (str): Name of the existing table or the table to be created. + schema_name (str, optional): Database schema name of the table. Defaults to "public". content_column (str): Column that represent a Document’s page_content. Defaults to "content". embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". metadata_columns (List[str]): Column(s) that represent a document's metadata. @@ -87,6 +89,7 @@ def __init__( self.engine = engine self.embedding_service = embedding_service self.table_name = table_name + self.schema_name = schema_name self.content_column = content_column self.embedding_column = embedding_column self.metadata_columns = metadata_columns @@ -104,6 +107,7 @@ async def create( engine: PostgresEngine, embedding_service: Embeddings, table_name: str, + schema_name: str = "public", content_column: str = "content", embedding_column: str = "embedding", metadata_columns: List[str] = [], @@ -122,6 +126,7 @@ async def create( engine (PostgresEngine): Connection pool engine for managing connections to Cloud SQL for PostgreSQL database. embedding_service (Embeddings): Text embedding model to use. table_name (str): Name of an existing table or table to be created. + schema_name (str, optional): Database schema name of the table. Defaults to "public". content_column (str): Column that represent a Document's page_content. Defaults to "content". embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". metadata_columns (List[str]): Column(s) that represent a document's metadata. @@ -142,7 +147,7 @@ async def create( "Can not use both metadata_columns and ignore_metadata_columns." ) # Get field type information - stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}'" + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}' AND table_schema = '{schema_name}'" results = await engine._afetch(stmt) columns = {} for field in results: @@ -190,6 +195,7 @@ async def create( engine, embedding_service, table_name, + schema_name, content_column, embedding_column, metadata_columns, @@ -208,6 +214,7 @@ def create_sync( engine: PostgresEngine, embedding_service: Embeddings, table_name: str, + schema_name: str = "public", content_column: str = "content", embedding_column: str = "embedding", metadata_columns: List[str] = [], @@ -226,6 +233,7 @@ def create_sync( engine (PostgresEngine): Connection pool engine for managing connections to Cloud SQL for PostgreSQL database. embedding_service (Embeddings): Text embedding model to use. table_name (str): Name of an existing table or table to be created. + schema_name (str, optional): Database schema name of the table. Defaults to "public". content_column (str): Column that represent a Document's page_content. Defaults to "content". embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". metadata_columns (List[str]): Column(s) that represent a document's metadata. @@ -245,6 +253,7 @@ def create_sync( engine, embedding_service, table_name, + schema_name, content_column, embedding_column, metadata_columns, @@ -283,7 +292,7 @@ async def _aadd_embeddings( if len(self.metadata_columns) > 0 else "" ) - insert_stmt = f'INSERT INTO "{self.table_name}"({self.id_column}, {self.content_column}, {self.embedding_column}{metadata_col_names}' + insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"({self.id_column}, {self.content_column}, {self.embedding_column}{metadata_col_names}' values = {"id": id, "content": content, "embedding": str(embedding)} values_stmt = "VALUES (:id, :content, :embedding" @@ -369,7 +378,7 @@ async def adelete( return False id_list = ", ".join([f"'{id}'" for id in ids]) - query = f'DELETE FROM "{self.table_name}" WHERE {self.id_column} in ({id_list})' + query = f'DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE {self.id_column} in ({id_list})' await self.engine._aexecute(query) return True @@ -388,6 +397,7 @@ async def afrom_texts( # type: ignore[override] embedding: Embeddings, engine: PostgresEngine, table_name: str, + schema_name: str = "public", metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, content_column: str = "content", @@ -404,6 +414,7 @@ async def afrom_texts( # type: ignore[override] embedding (Embeddings): Text embedding model to use. engine (PostgresEngine): Connection pool engine for managing connections to Postgres database. table_name (str): Name of the existing table or the table to be created. + schema_name (str, optional): Database schema name of the table. Defaults to "public". metadatas (Optional[List[dict]]): List of metadatas to add to table records. ids: (Optional[List[str]]): List of IDs to add to table records. content_column (str): Column that represent a Document’s page_content. Defaults to "content". @@ -420,6 +431,7 @@ async def afrom_texts( # type: ignore[override] engine, embedding, table_name, + schema_name, content_column, embedding_column, metadata_columns, @@ -437,6 +449,7 @@ async def afrom_documents( # type: ignore[override] embedding: Embeddings, engine: PostgresEngine, table_name: str, + schema_name: str = "public", ids: Optional[List[str]] = None, content_column: str = "content", embedding_column: str = "embedding", @@ -453,6 +466,7 @@ async def afrom_documents( # type: ignore[override] embedding (Embeddings): Text embedding model to use. engine (PostgresEngine): Connection pool engine for managing connections to Postgres database. table_name (str): Name of the existing table or the table to be created. + schema_name (str, optional): Database schema name of the table. Defaults to "public". metadatas (Optional[List[dict]]): List of metadatas to add to table records. ids: (Optional[List[str]]): List of IDs to add to table records. content_column (str): Column that represent a Document’s page_content. Defaults to "content". @@ -469,6 +483,7 @@ async def afrom_documents( # type: ignore[override] engine, embedding, table_name, + schema_name, content_column, embedding_column, metadata_columns, @@ -488,6 +503,7 @@ def from_texts( # type: ignore[override] embedding: Embeddings, engine: PostgresEngine, table_name: str, + schema_name: str = "public", metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, content_column: str = "content", @@ -504,6 +520,7 @@ def from_texts( # type: ignore[override] embedding (Embeddings): Text embedding model to use. engine (PostgresEngine): Connection pool engine for managing connections to Postgres database. table_name (str): Name of the existing table or the table to be created. + schema_name (str, optional): Database schema name of the table. Defaults to "public". metadatas (Optional[List[dict]]): List of metadatas to add to table records. ids: (Optional[List[str]]): List of IDs to add to table records. content_column (str): Column that represent a Document’s page_content. Defaults to "content". @@ -521,6 +538,7 @@ def from_texts( # type: ignore[override] embedding, engine, table_name, + schema_name, metadatas=metadatas, content_column=content_column, embedding_column=embedding_column, @@ -540,6 +558,7 @@ def from_documents( # type: ignore[override] embedding: Embeddings, engine: PostgresEngine, table_name: str, + schema_name: str = "public", ids: Optional[List[str]] = None, content_column: str = "content", embedding_column: str = "embedding", @@ -556,6 +575,7 @@ def from_documents( # type: ignore[override] embedding (Embeddings): Text embedding model to use. engine (PostgresEngine): Connection pool engine for managing connections to Postgres database. table_name (str): Name of the existing table or the table to be created. + schema_name (str, optional): Database schema name of the table. Defaults to "public". metadatas (Optional[List[dict]]): List of metadatas to add to table records. ids: (Optional[List[str]]): List of IDs to add to table records. content_column (str): Column that represent a Document’s page_content. Defaults to "content". @@ -573,6 +593,7 @@ def from_documents( # type: ignore[override] embedding, engine, table_name, + schema_name, content_column=content_column, embedding_column=embedding_column, metadata_columns=metadata_columns, @@ -597,7 +618,7 @@ async def __query_collection( search_function = self.distance_strategy.search_function filter = f"WHERE {filter}" if filter else "" - stmt = f"SELECT *, {search_function}({self.embedding_column}, '{embedding}') as distance FROM \"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};" + stmt = f"SELECT *, {search_function}({self.embedding_column}, '{embedding}') as distance FROM \"{self.schema_name}\".\"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};" if self.index_query_options: query_options_stmt = f"SET LOCAL {self.index_query_options.to_string()};" results = await self.engine._afetch_with_query_options( @@ -908,7 +929,7 @@ async def aapply_vector_index( if index.name == None: index.name = self.table_name + DEFAULT_INDEX_NAME_SUFFIX name = index.name - stmt = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} {name} ON "{self.table_name}" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};' + stmt = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} {name} ON "{self.schema_name}"."{self.table_name}" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};' if concurrently: await self.engine._aexecute_outside_tx(stmt) else: @@ -938,7 +959,7 @@ async def is_valid_index( query = f""" SELECT tablename, indexname FROM pg_indexes - WHERE tablename = '{self.table_name}' AND indexname = '{index_name}'; + WHERE tablename = '{self.table_name}' AND schemaname = '{self.schema_name}' AND indexname = '{index_name}'; """ results = await self.engine._afetch(query) return bool(len(results) == 1) diff --git a/src/langchain_google_cloud_sql_pg/version.py b/src/langchain_google_cloud_sql_pg/version.py index 74efebbe..ba03825a 100644 --- a/src/langchain_google_cloud_sql_pg/version.py +++ b/src/langchain_google_cloud_sql_pg/version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "0.8.0" +__version__ = "0.9.0"