Skip to content

Commit eb65401

Browse files
committed
Added ColBERT example for approximate search - #123 [skip ci]
1 parent 1901b9c commit eb65401

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

examples/colbert/approximate.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# approach from section 3.6 in https://arxiv.org/abs/2004.12832
2+
3+
from colbert.infra import ColBERTConfig
4+
from colbert.modeling.checkpoint import Checkpoint
5+
from pgvector.psycopg import register_vector
6+
import psycopg
7+
8+
# Open an autocommit connection to the example database.
conn = psycopg.connect(dbname='pgvector_example', autocommit=True)

# Create the extension first, then register the vector type on this connection.
conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
register_vector(conn)

# Rebuild both tables from scratch: `documents` holds the text, and
# `document_embeddings` holds 128-dimensional vectors tagged with a document_id.
for statement in (
    'DROP TABLE IF EXISTS documents',
    'DROP TABLE IF EXISTS document_embeddings',
    'CREATE TABLE documents (id bigserial PRIMARY KEY, content text)',
    'CREATE TABLE document_embeddings (id bigserial PRIMARY KEY, document_id bigint, embedding vector(128))',
):
    conn.execute(statement)
17+
# Define max_sim(document, query): for each query vector, take the highest
# cosine similarity (1 - cosine distance) against any document vector, then
# sum those per-query maxima into a single score.
# The inner subquery is given an explicit alias because PostgreSQL releases
# before 16 reject an unaliased subquery in FROM.
conn.execute("""
CREATE OR REPLACE FUNCTION max_sim(document vector[], query vector[]) RETURNS double precision AS $$
    WITH queries AS (
        SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query) AS unnested
    ),
    documents AS (
        SELECT unnest(document) AS document
    ),
    similarities AS (
        SELECT query_number, 1 - (document <=> query) AS similarity FROM queries CROSS JOIN documents
    ),
    max_similarities AS (
        SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
    )
    SELECT SUM(max_similarity) FROM max_similarities
$$ LANGUAGE SQL
""")
34+
35+
# Load the ColBERT v2 checkpoint with explicit document/query length limits.
config = ColBERTConfig(doc_maxlen=220, query_maxlen=32)
checkpoint = Checkpoint('colbert-ir/colbertv2.0', colbert_config=config, verbose=0)

# Sample corpus to index.
documents = [
    'The dog is barking',
    'The cat is purring',
    'The bear is growling',
]

# docFromText yields one collection of embeddings per input document.
doc_embeddings = checkpoint.docFromText(documents, keep_dims=False)
for content, embeddings in zip(documents, doc_embeddings):
    # Store the document row and all of its embedding rows in one transaction.
    with conn.transaction():
        document_id = conn.execute(
            'INSERT INTO documents (content) VALUES (%s) RETURNING id', (content,)
        ).fetchone()[0]

        # Flattened (document_id, embedding) pairs for a single multi-row INSERT.
        params = []
        for embedding in embeddings:
            params.extend([document_id, embedding.numpy()])

        # Only placeholder text is interpolated into the SQL; the values
        # themselves are bound as parameters.
        values = ', '.join(['(%s, %s)' for _ in embeddings])
        conn.execute(f'INSERT INTO document_embeddings (document_id, embedding) VALUES {values}', params)
52+
53+
# Index document_id for per-document lookups, and add an HNSW index with the
# cosine operator class on the embedding column.
for index_sql in (
    'CREATE INDEX ON document_embeddings (document_id)',
    'CREATE INDEX ON document_embeddings USING hnsw (embedding vector_cosine_ops)',
):
    conn.execute(index_sql)
55+
56+
# Two-stage search: a nearest-neighbour lookup per query embedding picks the
# candidate documents, which are then scored with max_sim and ranked.
query = 'puppy'
query_embeddings = [e.numpy() for e in checkpoint.queryFromText([query])[0]]

# One candidate subquery per query embedding, combined with UNION ALL.
candidate_sql = '(SELECT document_id FROM document_embeddings ORDER BY embedding <=> %s LIMIT 5)'
approximate_stage = ' UNION ALL '.join([candidate_sql] * len(query_embeddings))

sql = f"""
WITH approximate_stage AS (
    {approximate_stage}
),
embeddings AS (
    SELECT document_id, array_agg(embedding) AS embeddings FROM document_embeddings
    WHERE document_id IN (SELECT DISTINCT document_id FROM approximate_stage)
    GROUP BY document_id
)
SELECT content, max_sim(embeddings, %s) AS max_sim FROM documents
INNER JOIN embeddings ON embeddings.document_id = documents.id
ORDER BY max_sim DESC LIMIT 10
"""

# Each query embedding fills one candidate-subquery placeholder; the full list
# of query embeddings is the final parameter, passed to max_sim.
params = [*query_embeddings, query_embeddings]
for row in conn.execute(sql, params).fetchall():
    print(row)

0 commit comments

Comments
 (0)