Skip to content
Prev Previous commit
Next Next commit
fix: improve query execution logic in postgres.py
Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com>
  • Loading branch information
YassinNouh21 committed Apr 9, 2025
commit 246d6a671848a5ae7d15ca1e6c58eb2c9821c8aa
Original file line number Diff line number Diff line change
Expand Up @@ -533,20 +533,12 @@ def retrieve_online_documents_v2(
and feature.name in requested_features
]

for feature in table.features:
if (
feature.dtype.to_value_type().value == 2
and feature.name in requested_features
): # 2 is STRING
string_fields.append(feature.name)

table_name = _table_id(config.project, table)

with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur:
# Case 1: Hybrid Search (vector + text)
if embedding is not None and query_string is not None and string_fields:
# Case 1: Hybrid Search (vector + text)
tsquery_str = " & ".join(query_string.split())

query = sql.SQL(
"""
SELECT
Expand All @@ -568,12 +560,10 @@ def retrieve_online_documents_v2(
table_name=sql.Identifier(table_name),
top_k=sql.Literal(top_k),
)
params = (embedding, tsquery_str, string_fields, tsquery_str)

cur.execute(query, (embedding, tsquery_str, string_fields, tsquery_str))
rows = cur.fetchall()

# Case 2: Vector Search Only
elif embedding is not None:
# Case 2: Vector Search Only
query = sql.SQL(
"""
SELECT
Expand All @@ -594,12 +584,10 @@ def retrieve_online_documents_v2(
table_name=sql.Identifier(table_name),
top_k=sql.Literal(top_k),
)
params = (embedding,)

cur.execute(query, (embedding,))
rows = cur.fetchall()

# Case 3: Text Search Only
elif query_string is not None and string_fields:
# Case 3: Text Search Only
tsquery_str = " & ".join(query_string.split())
query = sql.SQL(
"""
Expand Down Expand Up @@ -628,17 +616,16 @@ def retrieve_online_documents_v2(
table_name=sql.Identifier(table_name),
top_k=sql.Literal(top_k),
)

cur.execute(
query, (tsquery_str, string_fields, tsquery_str, requested_features)
)
rows = cur.fetchall()
params = (tsquery_str, string_fields, tsquery_str, requested_features)

else:
raise ValueError(
"Either vector_enabled must be True for embedding search or string fields must be available for query_string search"
)

cur.execute(query, params)
rows = cur.fetchall()

# Group by entity_key to build feature records
entities_dict: Dict[str, Dict[str, Any]] = defaultdict(
lambda: {
Expand Down