11import openai
2- from pgvector .sqlalchemy import Vector
2+ from pgvector .psycopg import register_vector
3+ import psycopg
34from sentence_transformers import SentenceTransformer
4- from sqlalchemy import create_engine , insert , select , text , Integer , String , Text
5- from sqlalchemy .orm import declarative_base , mapped_column , Session
65
7- engine = create_engine ('postgresql+psycopg://localhost/pgvector_example' )
8- with engine .connect () as conn :
9- conn .execute (text ('CREATE EXTENSION IF NOT EXISTS vector' ))
10- conn .commit ()
6+ conn = psycopg .connect (dbname = 'pgvector_example' , autocommit = True )
117
12- Base = declarative_base ()
8+ conn .execute ('CREATE EXTENSION IF NOT EXISTS vector' )
9+ register_vector (conn )
1310
14-
15- class Document (Base ):
16- __tablename__ = 'document'
17-
18- id = mapped_column (Integer , primary_key = True )
19- content = mapped_column (Text )
20- embedding = mapped_column (Vector (1536 ))
21-
22-
23- Base .metadata .drop_all (engine )
24- Base .metadata .create_all (engine )
11+ conn .execute ('DROP TABLE IF EXISTS document' )
12+ conn .execute ('CREATE TABLE document (id bigserial PRIMARY KEY, content text, embedding vector(1536))' )
2513
2614input = [
2715 'The dog is barking' ,
@@ -30,12 +18,11 @@ class Document(Base):
3018]
3119
3220embeddings = [v ['embedding' ] for v in openai .Embedding .create (input = input , model = 'text-embedding-ada-002' )['data' ]]
33- documents = [dict (content = input [i ], embedding = embedding ) for i , embedding in enumerate (embeddings )]
3421
35- session = Session ( engine )
36- session .execute (insert ( Document ), documents )
22+ for content , embedding in zip ( input , embeddings ):
23+ conn .execute ('INSERT INTO document (content, embedding) VALUES (%s, %s)' , ( content , embedding ) )
3724
38- doc = session . get ( Document , 1 )
39- neighbors = session . scalars ( select ( Document ). filter ( Document . id != doc . id ). order_by ( Document . embedding . max_inner_product ( doc . embedding )). limit ( 5 ) )
25+ document_id = 2
26+ neighbors = conn . execute ( 'SELECT content FROM documents WHERE id != %( id)s ORDER BY embedding <=> (SELECT embedding FROM documents WHERE id = %(id)s) LIMIT 5' , { 'id' : document_id }). fetchall ( )
4027for neighbor in neighbors :
41- print (neighbor . content )
28+ print (neighbor [ 0 ] )
0 commit comments