Skip to content

Commit cabb4e0

Browse files
committed
Added experimental support for psycopg3
1 parent cbee348 commit cabb4e0

6 files changed

Lines changed: 120 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## 0.1.3 (unreleased)
22

33
- Added support for asyncpg
4+
- Added experimental support for psycopg3
45

56
## 0.1.2 (2021-06-13)
67

README.md

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
Great for online recommendations :tada:
66

7-
Supports [Django](https://github.com/django/django), [SQLAlchemy](https://github.com/sqlalchemy/sqlalchemy), [Psycopg 2](https://github.com/psycopg/psycopg2), and [asyncpg](https://github.com/MagicStack/asyncpg)
7+
Supports [Django](https://github.com/django/django), [SQLAlchemy](https://github.com/sqlalchemy/sqlalchemy), [Psycopg 2](https://github.com/psycopg/psycopg2), [Psycopg 3](https://github.com/psycopg/psycopg3), and [asyncpg](https://github.com/MagicStack/asyncpg)
88

99
[![Build Status](https://github.com/ankane/pgvector-python/workflows/build/badge.svg?branch=master)](https://github.com/ankane/pgvector-python/actions)
1010

@@ -21,6 +21,7 @@ And follow the instructions for your database library:
2121
- [Django](#django)
2222
- [SQLAlchemy](#sqlalchemy)
2323
- [Psycopg 2](#psycopg-2)
24+
- [Psycopg 3](#psycopg-3) [experimental]
2425
- [asyncpg](#asyncpg)
2526

2627
Or check out some examples:
@@ -151,6 +152,29 @@ cur.execute('SELECT * FROM item ORDER BY factors <-> %s LIMIT 5', (factors,))
151152
cur.fetchall()
152153
```
153154

155+
## psycopg3
156+
157+
Register the vector type with your connection or cursor
158+
159+
```python
160+
from pgvector.psycopg3 import register_vector
161+
162+
register_vector(conn)
163+
```
164+
165+
Insert a vector
166+
167+
```python
168+
factors = np.array([1, 2, 3])
169+
cur.execute('INSERT INTO item (factors) VALUES (%s)', (factors,))
170+
```
171+
172+
Get the nearest neighbors to a vector
173+
174+
```python
175+
cur.execute('SELECT * FROM item ORDER BY factors <-> %s LIMIT 5', (factors,)).fetchall()
176+
```
177+
154178
## asyncpg
155179

156180
Register the vector type with your connection

pgvector/psycopg3/__init__.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import numpy as np
2+
from psycopg3.adapt import Loader, Dumper
3+
from psycopg3.pq import Format
4+
from ..utils import from_db, from_db_binary, to_db, to_db_binary
5+
6+
__all__ = ['register_vector']
7+
8+
9+
class VectorDumper(Dumper):
10+
11+
format = Format.TEXT
12+
13+
def dump(self, obj):
14+
return to_db(obj).encode("utf8")
15+
16+
17+
class VectorBinaryDumper(VectorDumper):
18+
19+
format = Format.BINARY
20+
21+
def dump(self, obj):
22+
return to_db_binary(obj)
23+
24+
25+
class VectorLoader(Loader):
26+
27+
format = Format.TEXT
28+
29+
def load(self, data):
30+
if isinstance(data, memoryview):
31+
data = bytes(data)
32+
return from_db(data.decode("utf8"))
33+
34+
35+
class VectorBinaryLoader(VectorLoader):
36+
37+
format = Format.BINARY
38+
39+
def load(self, data):
40+
if isinstance(data, memoryview):
41+
data = bytes(data)
42+
return from_db_binary(data)
43+
44+
45+
def register_vector(ctx):
46+
cur = ctx.cursor() if hasattr(ctx, 'cursor') else ctx
47+
48+
try:
49+
cur.execute('SELECT NULL::vector')
50+
oid = cur.description[0][1]
51+
except psycopg3.errors.UndefinedObject:
52+
raise psycopg3.ProgrammingError('vector type not found in the database')
53+
54+
VectorDumper.register('numpy.ndarray', ctx)
55+
VectorBinaryDumper.register('numpy.ndarray', ctx)
56+
VectorLoader.register(oid, ctx)
57+
VectorBinaryLoader.register(oid, ctx)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ psycopg2
55
pytest
66
pytest-asyncio
77
SQLAlchemy
8+
git+https://github.com/psycopg/psycopg3.git#subdirectory=psycopg3

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
'pgvector.asyncpg',
1313
'pgvector.django',
1414
'pgvector.psycopg2',
15+
'pgvector.psycopg3',
1516
'pgvector.sqlalchemy',
1617
'pgvector.utils'
1718
],

tests/test_psycopg3.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import numpy as np
2+
from pgvector.psycopg3 import register_vector
3+
import psycopg3
4+
5+
6+
class TestPsycopg3(object):
7+
def test_works(self):
8+
conn = psycopg3.connect('dbname=pgvector_python_test')
9+
conn.autocommit = True
10+
11+
cur = conn.cursor()
12+
cur.execute('CREATE EXTENSION IF NOT EXISTS vector')
13+
cur.execute('DROP TABLE IF EXISTS item')
14+
cur.execute('CREATE TABLE item (id bigserial primary key, factors vector(3))')
15+
16+
register_vector(cur)
17+
18+
factors = np.array([1.5, 2, 3])
19+
cur.execute("INSERT INTO item (factors) VALUES (%s), (NULL)", (factors,))
20+
21+
cur.execute("SELECT * FROM item ORDER BY id")
22+
res = cur.fetchall()
23+
assert res[0][0] == 1
24+
assert res[1][0] == 2
25+
assert np.array_equal(res[0][1], factors)
26+
assert res[0][1].dtype == np.float32
27+
assert res[1][1] is None
28+
29+
# binary format
30+
binary_res = cur.execute("SELECT %b::vector", (factors,)).fetchone()[0]
31+
assert np.array_equal(binary_res, factors)
32+
33+
# text format
34+
text_res = cur.execute("SELECT %t::vector", (factors,)).fetchone()[0]
35+
assert np.array_equal(text_res, factors)

0 commit comments

Comments
 (0)