Skip to content

Commit 3c51f39

Browse files
committed
Added support for halfvec and sparsevec types to Psycopg 3
1 parent 5ba96ff commit 3c51f39

File tree

9 files changed

+258
-4
lines changed

9 files changed

+258
-4
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
dev-files: true
2020
- run: |
2121
cd /tmp
22-
git clone --branch v0.6.2 https://github.com/pgvector/pgvector.git
22+
git clone --branch v0.7.0 https://github.com/pgvector/pgvector.git
2323
cd pgvector
2424
make
2525
sudo make install

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.2.6 (unreleased)
2+
3+
- Added support for `halfvec` and `sparsevec` types to Psycopg 3
4+
15
## 0.2.5 (2024-02-07)
26

37
- Added literal binds support for SQLAlchemy

pgvector/psycopg/__init__.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import psycopg
22
from psycopg.types import TypeInfo
3-
from .vector import *
3+
from .halfvec import register_halfvec_info
4+
from .sparsevec import register_sparsevec_info
5+
from .vector import register_vector_info
6+
from ..utils import HalfVec, SparseVec
47

58
# TODO remove in 0.3.0
9+
from .vector import *
610
from ..utils import from_db, from_db_binary, to_db, to_db_binary
711

812
__all__ = ['register_vector']
@@ -12,7 +16,23 @@ def register_vector(context):
1216
info = TypeInfo.fetch(context, 'vector')
1317
register_vector_info(context, info)
1418

19+
info = TypeInfo.fetch(context, 'halfvec')
20+
if info is not None:
21+
register_halfvec_info(context, info)
22+
23+
info = TypeInfo.fetch(context, 'sparsevec')
24+
if info is not None:
25+
register_sparsevec_info(context, info)
26+
1527

1628
async def register_vector_async(context):
    """Async variant: register the pgvector types on ``context``.

    ``vector`` is fetched and registered unconditionally; ``halfvec`` and
    ``sparsevec`` are registered only when their type info is found.
    """
    vector_info = await TypeInfo.fetch(context, 'vector')
    register_vector_info(context, vector_info)

    # Types that may be absent; skipped when TypeInfo.fetch returns None.
    optional_types = (
        ('halfvec', register_halfvec_info),
        ('sparsevec', register_sparsevec_info),
    )
    for type_name, register_info in optional_types:
        found = await TypeInfo.fetch(context, type_name)
        if found is not None:
            register_info(context, found)

pgvector/psycopg/halfvec.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from psycopg.adapt import Loader, Dumper
2+
from psycopg.pq import Format
3+
from ..utils import HalfVec
4+
5+
6+
class HalfVecDumper(Dumper):
    """Text-format dumper: sends ``obj.to_db()`` encoded as UTF-8."""

    format = Format.TEXT

    def dump(self, obj):
        text = obj.to_db()
        return text.encode('utf8')
12+
13+
14+
class HalfVecBinaryDumper(HalfVecDumper):
    """Binary-format dumper: sends the bytes returned by ``obj.to_db_binary()``."""

    format = Format.BINARY

    def dump(self, obj):
        payload = obj.to_db_binary()
        return payload
20+
21+
22+
class HalfVecLoader(Loader):
    """Text-format loader: decodes UTF-8 data and parses it with ``HalfVec.from_db``."""

    format = Format.TEXT

    def load(self, data):
        if data is None:
            return None
        # data may arrive as a memoryview; copy to bytes so it can be decoded
        raw = bytes(data) if isinstance(data, memoryview) else data
        text = raw.decode('utf8')
        return HalfVec.from_db(text)
32+
33+
34+
class HalfVecBinaryLoader(HalfVecLoader):
    """Binary-format loader: parses raw bytes with ``HalfVec.from_db_binary``."""

    format = Format.BINARY

    def load(self, data):
        if data is None:
            return None
        # data may arrive as a memoryview; hand plain bytes to the parser
        raw = bytes(data) if isinstance(data, memoryview) else data
        return HalfVec.from_db_binary(raw)
44+
45+
46+
def register_halfvec_info(context, info):
    """Register ``info`` on ``context`` and install the HalfVec dumpers/loaders."""
    info.register(context)

    # add oid to anonymous class for set_types
    dumpers = [
        type('', (base,), {'oid': info.oid})
        for base in (HalfVecDumper, HalfVecBinaryDumper)
    ]

    registry = context.adapters
    # Registration order matches the original: text first, then binary.
    for dumper in dumpers:
        registry.register_dumper(HalfVec, dumper)
    for loader in (HalfVecLoader, HalfVecBinaryLoader):
        registry.register_loader(info.oid, loader)

pgvector/psycopg/sparsevec.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from psycopg.adapt import Loader, Dumper
2+
from psycopg.pq import Format
3+
from ..utils import SparseVec
4+
5+
6+
class SparseVecDumper(Dumper):
    """Text-format dumper: sends ``obj.to_db()`` encoded as UTF-8."""

    format = Format.TEXT

    def dump(self, obj):
        text = obj.to_db()
        return text.encode('utf8')
12+
13+
14+
class SparseVecBinaryDumper(SparseVecDumper):
    """Binary-format dumper: sends the bytes returned by ``obj.to_db_binary()``."""

    format = Format.BINARY

    def dump(self, obj):
        payload = obj.to_db_binary()
        return payload
20+
21+
22+
class SparseVecLoader(Loader):
    """Text-format loader: decodes UTF-8 data and parses it with ``SparseVec.from_db``."""

    format = Format.TEXT

    def load(self, data):
        if data is None:
            return None
        # data may arrive as a memoryview; copy to bytes so it can be decoded
        raw = bytes(data) if isinstance(data, memoryview) else data
        text = raw.decode('utf8')
        return SparseVec.from_db(text)
32+
33+
34+
class SparseVecBinaryLoader(SparseVecLoader):
    """Binary-format loader: parses raw bytes with ``SparseVec.from_db_binary``."""

    format = Format.BINARY

    def load(self, data):
        if data is None:
            return None
        # data may arrive as a memoryview; hand plain bytes to the parser
        raw = bytes(data) if isinstance(data, memoryview) else data
        return SparseVec.from_db_binary(raw)
44+
45+
46+
def register_sparsevec_info(context, info):
    """Register ``info`` on ``context`` and install the SparseVec dumpers/loaders."""
    info.register(context)

    # add oid to anonymous class for set_types
    dumpers = [
        type('', (base,), {'oid': info.oid})
        for base in (SparseVecDumper, SparseVecBinaryDumper)
    ]

    registry = context.adapters
    # Registration order matches the original: text first, then binary.
    for dumper in dumpers:
        registry.register_dumper(SparseVec, dumper)
    for loader in (SparseVecLoader, SparseVecBinaryLoader):
        registry.register_loader(info.oid, loader)

pgvector/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1+
from .halfvec import *
2+
from .sparsevec import *
13
from .vector import *

pgvector/utils/halfvec.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from struct import pack, unpack_from
2+
3+
4+
class HalfVec:
    """Wrapper around a list/tuple of numbers sent to and read from a ``halfvec`` column.

    Text form:   ``[v1,v2,...]`` with each element rendered via ``str(float(v))``.
    Binary form: big-endian ``uint16`` element count, a ``uint16`` written as 0,
                 then one half-precision float (struct code ``e``) per element.
    """

    def __init__(self, value):
        """Store ``value`` as-is.

        Raises:
            ValueError: if ``value`` is not a list or tuple.
        """
        # TODO support np.array
        if not isinstance(value, (list, tuple)):
            raise ValueError('expected list or tuple')

        self.value = value

    def to_list(self):
        """Return the elements as a new list."""
        return list(self.value)

    def to_db(self):
        """Return the text representation, e.g. ``'[1.5,2.0,3.0]'``."""
        return '[' + ','.join([str(float(v)) for v in self.value]) + ']'

    def to_db_binary(self):
        """Return the binary representation as ``bytes``."""
        count = len(self.value)
        return pack(f'>HH{count}e', count, 0, *self.value)

    # The two parsers below were plain functions in the class body, so they
    # only worked when looked up on the class. As classmethods they work via
    # the class and via an instance; HalfVec.from_db(text) is unchanged.
    @classmethod
    def from_db(cls, value):
        """Parse the text representation (``'[1.5,2,3]'``) into a HalfVec."""
        return cls([float(v) for v in value[1:-1].split(',')])

    @classmethod
    def from_db_binary(cls, value):
        """Parse the binary representation into a HalfVec (elements kept as a tuple)."""
        # The second header field is read but not used.
        dim, _ = unpack_from('>HH', value)
        return cls(unpack_from(f'>{dim}e', value, 4))

    def __repr__(self):
        return f'HalfVec({self.value})'

pgvector/utils/sparsevec.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from struct import pack, unpack_from
2+
3+
4+
class SparseVec:
    """Sparse vector stored as a dimension plus parallel index/value sequences.

    ``indices`` are 0-based positions of the stored elements; ``values`` holds
    the element at each of those positions.

    Text form:   ``{i1:v1,i2:v2,...}/dim`` where the written indices are 1-based.
    Binary form: three big-endian ``int32`` fields (dim, number of stored
                 elements, and a field written as 0), then the 0-based indices
                 as ``int32`` and the values as ``float32``.
    """

    def __init__(self, dim, indices, values):
        self.dim = dim
        self.indices = indices
        self.values = values

    # The three constructors below were plain functions in the class body, so
    # they only worked when looked up on the class. As classmethods they work
    # via the class and via an instance; SparseVec.from_dense(x) is unchanged.
    @classmethod
    def from_dense(cls, value):
        """Build a SparseVec from a dense sequence, keeping only elements ``!= 0``."""
        dim = len(value)
        indices = [i for i, v in enumerate(value) if v != 0]
        values = [value[i] for i in indices]
        return cls(dim, indices, values)

    def to_dense(self):
        """Return a dense list of length ``dim`` with 0 at unstored positions."""
        vec = [0] * self.dim
        for i, v in zip(self.indices, self.values):
            vec[i] = v
        return vec

    def to_db(self):
        """Return the text representation, e.g. ``'{2:1.5,4:2}/6'``."""
        pairs = ','.join(f'{i + 1}:{v}' for i, v in zip(self.indices, self.values))
        return '{' + pairs + '}/' + str(self.dim)

    def to_db_binary(self):
        """Return the binary representation as ``bytes``."""
        nnz = len(self.indices)
        return pack(f'>iii{nnz}i{nnz}f', self.dim, nnz, 0, *self.indices, *self.values)

    @classmethod
    def from_db(cls, value):
        """Parse the text representation into a SparseVec.

        An all-zero vector is written as ``'{}/dim'`` and parses to empty
        ``indices`` and ``values``.
        """
        elements, dim = value.split('/')
        indices = []
        values = []
        body = elements[1:-1]
        # ''.split(',') yields [''], which cannot be unpacked as "index:value",
        # so an empty element list must skip the loop entirely.
        if body:
            for e in body.split(','):
                i, v = e.split(':')
                indices.append(int(i) - 1)
                values.append(float(v))
        return cls(int(dim), indices, values)

    @classmethod
    def from_db_binary(cls, value):
        """Parse the binary representation into a SparseVec."""
        # The third header field is read but not used.
        dim, nnz, _ = unpack_from('>iii', value)
        # Converted to lists so both parsers yield the same attribute types.
        indices = list(unpack_from(f'>{nnz}i', value, 12))
        values = list(unpack_from(f'>{nnz}f', value, 12 + nnz * 4))
        return cls(int(dim), indices, values)

    def __repr__(self):
        return f'SparseVec({self.dim}, {self.indices}, {self.values})'

tests/test_psycopg.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from pgvector.psycopg import register_vector, register_vector_async
2+
from pgvector.psycopg import register_vector, register_vector_async, HalfVec, SparseVec
33
import psycopg
44
import pytest
55

@@ -27,7 +27,7 @@ def test_works(self):
2727

2828
def test_binary_format(self):
2929
embedding = np.array([1.5, 2, 3])
30-
res = conn.execute('SELECT %b::vector', (embedding,)).fetchone()[0]
30+
res = conn.execute('SELECT %b::vector', (embedding,), binary=True).fetchone()[0]
3131
assert np.array_equal(res, embedding)
3232

3333
def test_text_format(self):
@@ -71,6 +71,45 @@ def test_binary_copy_set_types(self):
7171
copy.set_types(['int8', 'vector'])
7272
copy.write_row([1, embedding])
7373

74+
def test_halfvec(self):
    # Insert a HalfVec into a halfvec(3) column and read it back.
    conn.execute('DROP TABLE IF EXISTS half_items')
    conn.execute('CREATE TABLE half_items (id bigserial PRIMARY KEY, embedding halfvec(3))')

    embedding = HalfVec([1.5, 2, 3])
    conn.execute('INSERT INTO half_items (embedding) VALUES (%s)', (embedding,))

    res = conn.execute('SELECT * FROM half_items ORDER BY id').fetchall()
    # The fetched rows were never checked before; assert the round-trip the
    # same way test_sparsevec does (column 1 is the embedding).
    assert res[0][1].to_list() == [1.5, 2, 3]
82+
83+
def test_halfvec_binary_format(self):
84+
embedding = HalfVec([1.5, 2, 3])
85+
res = conn.execute('SELECT %b::halfvec', (embedding,), binary=True).fetchone()[0]
86+
assert res.to_list() == [1.5, 2, 3]
87+
88+
def test_halfvec_text_format(self):
89+
embedding = HalfVec([1.5, 2, 3])
90+
res = conn.execute('SELECT %t::halfvec', (embedding,)).fetchone()[0]
91+
assert res.to_list() == [1.5, 2, 3]
92+
93+
def test_sparsevec(self):
94+
conn.execute('DROP TABLE IF EXISTS sparse_items')
95+
conn.execute('CREATE TABLE sparse_items (id bigserial PRIMARY KEY, embedding sparsevec(6))')
96+
97+
embedding = SparseVec.from_dense([0, 1.5, 0, 2, 0, 3])
98+
conn.execute('INSERT INTO sparse_items (embedding) VALUES (%s)', (embedding,))
99+
100+
res = conn.execute('SELECT * FROM sparse_items ORDER BY id').fetchall()
101+
assert res[0][1].to_dense() == [0, 1.5, 0, 2, 0, 3]
102+
103+
def test_sparsevec_binary_format(self):
104+
embedding = SparseVec.from_dense([1.5, 2, 3])
105+
res = conn.execute('SELECT %b::sparsevec', (embedding,), binary=True).fetchone()[0]
106+
assert res.to_dense() == [1.5, 2, 3]
107+
108+
def test_sparsevec_text_format(self):
109+
embedding = SparseVec.from_dense([1.5, 2, 3])
110+
res = conn.execute('SELECT %t::sparsevec', (embedding,)).fetchone()[0]
111+
assert res.to_dense() == [1.5, 2, 3]
112+
74113
def test_bit(self):
75114
res = conn.execute('SELECT %s::bit(3)', ('101',)).fetchone()[0]
76115
assert res == '101'

0 commit comments

Comments
 (0)