Skip to content

Commit 866f038

Browse files
ankaneakoshelev-fhl
andcommitted
Added register_vector_async for psycopg3
Co-authored-by: Alexander Koshelev <a.koshelev@fhl.world>
1 parent 007467b commit 866f038

File tree

4 files changed

+40
-1
lines changed

4 files changed

+40
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
## 0.1.7 (unreleased)
22

3+
- Added `register_vector_async` for psycopg3
34
- Fixed `set_types` for psycopg3
45

56
## 0.1.6 (2022-05-22)

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,14 @@ from pgvector.psycopg import register_vector
180180
register_vector(conn)
181181
```
182182

183+
For [async connections](https://www.psycopg.org/psycopg3/docs/advanced/async.html), use [unreleased]
184+
185+
```python
186+
from pgvector.psycopg import register_vector_async
187+
188+
await register_vector_async(conn)
189+
```
190+
183191
Insert a vector
184192

185193
```python

pgvector/psycopg/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@ def load(self, data):
4545

4646
def register_vector(context):
4747
info = TypeInfo.fetch(context, 'vector')
48+
register_vector_info(context, info)
49+
50+
51+
async def register_vector_async(context):
52+
info = await TypeInfo.fetch(context, 'vector')
53+
register_vector_info(context, info)
54+
55+
56+
def register_vector_info(context, info):
4857
if info is None:
4958
raise psycopg.ProgrammingError('vector type not found in the database')
5059
info.register(context)

tests/test_psycopg.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
2-
from pgvector.psycopg import register_vector
2+
from pgvector.psycopg import register_vector, register_vector_async
33
import psycopg
4+
import pytest
45

56
conn = psycopg.connect(dbname='pgvector_python_test')
67
conn.autocommit = True
@@ -70,3 +71,23 @@ def test_binary_copy_set_types(self):
7071
with cur.copy("COPY item (id, embedding) FROM STDIN WITH (FORMAT BINARY)") as copy:
7172
copy.set_types(['int8', 'vector'])
7273
copy.write_row([1, embedding])
74+
75+
@pytest.mark.asyncio
76+
async def test_async(self):
77+
conn = await psycopg.AsyncConnection.connect(dbname='pgvector_python_test', autocommit=True)
78+
79+
await conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
80+
await conn.execute('DROP TABLE IF EXISTS item')
81+
await conn.execute('CREATE TABLE item (id bigserial primary key, embedding vector(3))')
82+
83+
await register_vector_async(conn)
84+
85+
embedding = np.array([1.5, 2, 3])
86+
await conn.execute('INSERT INTO item (embedding) VALUES (%s), (NULL)', (embedding,))
87+
88+
async with conn.cursor() as cur:
89+
await cur.execute('SELECT * FROM item ORDER BY id')
90+
res = await cur.fetchall()
91+
assert np.array_equal(res[0][1], embedding)
92+
assert res[0][1].dtype == np.float32
93+
assert res[1][1] is None

0 commit comments

Comments
 (0)