# approach from section 3.6 in https://arxiv.org/abs/2004.12832

from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
from pgvector.psycopg import register_vector
import psycopg
# Connect with autocommit so DDL below takes effect immediately.
conn = psycopg.connect(dbname='pgvector_example', autocommit=True)

conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
register_vector(conn)

# Rebuild the schema from scratch on every run (demo script).
conn.execute('DROP TABLE IF EXISTS documents')
conn.execute('DROP TABLE IF EXISTS document_embeddings')
conn.execute('CREATE TABLE documents (id bigserial PRIMARY KEY, content text)')
# One row per token embedding; ColBERT produces 128-dim vectors per token.
conn.execute('CREATE TABLE document_embeddings (id bigserial PRIMARY KEY, document_id bigint, embedding vector(128))')
# MaxSim late-interaction scoring: for each query token embedding, take its
# best cosine similarity over all document token embeddings, then sum.
conn.execute("""
CREATE OR REPLACE FUNCTION max_sim(document vector[], query vector[]) RETURNS double precision AS $$
    WITH queries AS (
        SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query)
    ),
    documents AS (
        SELECT unnest(document) AS document
    ),
    similarities AS (
        SELECT query_number, 1 - (document <=> query) AS similarity FROM queries CROSS JOIN documents
    ),
    max_similarities AS (
        SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
    )
    SELECT SUM(max_similarity) FROM max_similarities
$$ LANGUAGE SQL
""")

config = ColBERTConfig(doc_maxlen=220, query_maxlen=32)
checkpoint = Checkpoint('colbert-ir/colbertv2.0', colbert_config=config, verbose=0)

# Renamed from `input` to avoid shadowing the builtin.
contents = [
    'The dog is barking',
    'The cat is purring',
    'The bear is growling'
]
# keep_dims=False yields one variable-length sequence of token embeddings per document.
doc_embeddings = checkpoint.docFromText(contents, keep_dims=False)
for content, embeddings in zip(contents, doc_embeddings):
    # Insert the document and all of its token embeddings atomically.
    with conn.transaction():
        result = conn.execute('INSERT INTO documents (content) VALUES (%s) RETURNING id', (content,)).fetchone()
        # Build a single multi-row VALUES insert: (document_id, embedding) per token.
        params = []
        for embedding in embeddings:
            params.extend([result[0], embedding.numpy()])
        values = ', '.join(['(%s, %s)' for _ in embeddings])
        conn.execute(f'INSERT INTO document_embeddings (document_id, embedding) VALUES {values}', params)

# Index after bulk loading (building an HNSW index on already-loaded data is
# generally faster than maintaining it during inserts — TODO confirm for your setup).
conn.execute('CREATE INDEX ON document_embeddings (document_id)')
conn.execute('CREATE INDEX ON document_embeddings USING hnsw (embedding vector_cosine_ops)')

query = 'puppy'
query_embeddings = [e.numpy() for e in checkpoint.queryFromText([query])[0]]
# Stage 1 (approximate): for each query token embedding, fetch its 5 nearest
# document embeddings via the HNSW index; union the candidate document ids.
approximate_stage = ' UNION ALL '.join(['(SELECT document_id FROM document_embeddings ORDER BY embedding <=> %s LIMIT 5)' for _ in query_embeddings])
# Stage 2 (exact re-rank): aggregate each candidate document's embeddings and
# score with max_sim against the full set of query embeddings.
sql = f"""
WITH approximate_stage AS (
    {approximate_stage}
),
embeddings AS (
    SELECT document_id, array_agg(embedding) AS embeddings FROM document_embeddings
    WHERE document_id IN (SELECT DISTINCT document_id FROM approximate_stage)
    GROUP BY document_id
)
SELECT content, max_sim(embeddings, %s) AS max_sim FROM documents
INNER JOIN embeddings ON embeddings.document_id = documents.id
ORDER BY max_sim DESC LIMIT 10
"""
# One parameter per stage-1 subquery, then the whole list once as the
# vector[] argument to max_sim.
params = [*query_embeddings, query_embeddings]
result = conn.execute(sql, params).fetchall()
for row in result:
    print(row)