Skip to content

Commit 0e26fe1

Browse files
committed
Converted OpenAI example to psycopg [skip ci]
1 parent ae3085a commit 0e26fe1

1 file changed

Lines changed: 12 additions & 25 deletions

File tree

examples/openai_embeddings.py

Lines changed: 12 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1,27 +1,15 @@
11
import openai
2-
from pgvector.sqlalchemy import Vector
2+
from pgvector.psycopg import register_vector
3+
import psycopg
34
from sentence_transformers import SentenceTransformer
4-
from sqlalchemy import create_engine, insert, select, text, Integer, String, Text
5-
from sqlalchemy.orm import declarative_base, mapped_column, Session
65

7-
engine = create_engine('postgresql+psycopg://localhost/pgvector_example')
8-
with engine.connect() as conn:
9-
conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
10-
conn.commit()
6+
conn = psycopg.connect(dbname='pgvector_example', autocommit=True)
117

12-
Base = declarative_base()
8+
conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
9+
register_vector(conn)
1310

14-
15-
class Document(Base):
16-
__tablename__ = 'document'
17-
18-
id = mapped_column(Integer, primary_key=True)
19-
content = mapped_column(Text)
20-
embedding = mapped_column(Vector(1536))
21-
22-
23-
Base.metadata.drop_all(engine)
24-
Base.metadata.create_all(engine)
11+
conn.execute('DROP TABLE IF EXISTS document')
12+
conn.execute('CREATE TABLE document (id bigserial PRIMARY KEY, content text, embedding vector(1536))')
2513

2614
input = [
2715
'The dog is barking',
@@ -30,12 +18,11 @@ class Document(Base):
3018
]
3119

3220
embeddings = [v['embedding'] for v in openai.Embedding.create(input=input, model='text-embedding-ada-002')['data']]
33-
documents = [dict(content=input[i], embedding=embedding) for i, embedding in enumerate(embeddings)]
3421

35-
session = Session(engine)
36-
session.execute(insert(Document), documents)
22+
for content, embedding in zip(input, embeddings):
23+
conn.execute('INSERT INTO document (content, embedding) VALUES (%s, %s)', (content, embedding))
3724

38-
doc = session.get(Document, 1)
39-
neighbors = session.scalars(select(Document).filter(Document.id != doc.id).order_by(Document.embedding.max_inner_product(doc.embedding)).limit(5))
25+
document_id = 2
26+
# Fetch the content of up to five rows nearest (pgvector `<=>` cosine distance) to the
# row whose id is `document_id`, excluding that row itself. The table name must be
# `document`, matching the DROP/CREATE/INSERT statements earlier in this script.
neighbors = conn.execute('SELECT content FROM document WHERE id != %(id)s ORDER BY embedding <=> (SELECT embedding FROM document WHERE id = %(id)s) LIMIT 5', {'id': document_id}).fetchall()
4027
for neighbor in neighbors:
41-
print(neighbor.content)
28+
print(neighbor[0])

0 commit comments

Comments
 (0)