Skip to content

Commit 50def9c

Browse files
committed
Simplified SparseVector construction
1 parent aba997c commit 50def9c

10 files changed

Lines changed: 100 additions & 78 deletions

File tree

examples/sparse_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ def fetch_embeddings(input):
4545
]
4646
embeddings = fetch_embeddings(input)
4747
for content, embedding in zip(input, embeddings):
48-
conn.execute('INSERT INTO documents (content, embedding) VALUES (%s, %s)', (content, SparseVector.from_dense(embedding)))
48+
conn.execute('INSERT INTO documents (content, embedding) VALUES (%s, %s)', (content, SparseVector(embedding)))
4949

5050
query = 'forest'
5151
query_embedding = fetch_embeddings([query])[0]
52-
result = conn.execute('SELECT content FROM documents ORDER BY embedding <#> %s LIMIT 5', (SparseVector.from_dense(query_embedding),)).fetchall()
52+
result = conn.execute('SELECT content FROM documents ORDER BY embedding <#> %s LIMIT 5', (SparseVector(query_embedding),)).fetchall()
5353
for row in result:
5454
print(row[0])

pgvector/utils/sparsevec.py

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,48 +3,54 @@
33

44

55
class SparseVector:
6-
def __init__(self, dim, indices, values):
7-
# TODO improve
8-
self._dim = int(dim)
9-
self._indices = [int(i) for i in indices]
10-
self._values = [float(v) for v in values]
6+
def __init__(self, value, dimensions=None):
7+
if value.__class__.__module__ == 'scipy.sparse._arrays':
8+
if dimensions is not None:
9+
raise ValueError('dimensions not allowed')
10+
11+
self._from_sparse(value)
12+
elif isinstance(value, dict):
13+
self._from_dict(value, dimensions)
14+
else:
15+
if dimensions is not None:
16+
raise ValueError('dimensions not allowed')
17+
18+
self._from_dense(value)
1119

1220
def __repr__(self):
13-
return f'SparseVector({self._dim}, {self._indices}, {self._values})'
21+
return f'SparseVector({self.to_dict()}, {self.dim()})'
22+
23+
def _from_dict(self, d, dim):
24+
if dim is None:
25+
raise ValueError('dimensions required')
1426

15-
@classmethod
16-
def from_dict(cls, d, dim):
1727
elements = [(i, v) for i, v in d.items()]
1828
elements.sort()
19-
indices = [int(v[0]) for v in elements]
20-
values = [float(v[1]) for v in elements]
21-
return cls(dim, indices, values)
29+
self._dim = int(dim)
30+
self._indices = [int(v[0]) for v in elements]
31+
self._values = [float(v[1]) for v in elements]
2232

23-
@classmethod
24-
def from_sparse(cls, value):
33+
def _from_sparse(self, value):
2534
value = value.tocoo()
2635

2736
if value.ndim == 1:
28-
dim = value.shape[0]
37+
self._dim = value.shape[0]
2938
elif value.ndim == 2 and value.shape[0] == 1:
30-
dim = value.shape[1]
39+
self._dim = value.shape[1]
3140
else:
3241
raise ValueError('expected ndim to be 1')
3342

3443
if hasattr(value, 'coords'):
3544
# scipy 1.13+
36-
indices = value.coords[0].tolist()
45+
self._indices = value.coords[0].tolist()
3746
else:
38-
indices = value.col.tolist()
39-
values = value.data.tolist()
40-
return cls(dim, indices, values)
47+
self._indices = value.col.tolist()
48+
self._values = value.data.tolist()
4149

42-
@classmethod
43-
def from_dense(cls, value):
44-
dim = len(value)
45-
indices = [i for i, v in enumerate(value) if v != 0]
46-
values = [float(value[i]) for i in indices]
47-
return cls(dim, indices, values)
50+
def _from_dense(self, value):
51+
self._dim = len(value)
52+
self._indices = [i for i, v in enumerate(value) if v != 0]
53+
self._values = [float(value[i]) for i in self._indices]
4854

4955
def dim(self):
5056
return self._dim
@@ -86,21 +92,30 @@ def from_text(cls, value):
8692
i, v = e.split(':', 2)
8793
indices.append(int(i) - 1)
8894
values.append(float(v))
89-
return cls(int(dim), indices, values)
95+
return cls._from_parts(int(dim), indices, values)
9096

9197
@classmethod
9298
def from_binary(cls, value):
9399
dim, nnz, unused = unpack_from('>iii', value)
94100
indices = unpack_from(f'>{nnz}i', value, 12)
95101
values = unpack_from(f'>{nnz}f', value, 12 + nnz * 4)
96-
return cls(int(dim), indices, values)
102+
return cls._from_parts(int(dim), indices, values)
103+
104+
@classmethod
105+
def _from_parts(cls, dim, indices, values):
106+
vec = cls.__new__(cls)
107+
vec._dim = dim
108+
vec._indices = indices
109+
vec._values = values
110+
return vec
97111

98112
@classmethod
99113
def _to_db(cls, value, dim=None):
100114
if value is None:
101115
return value
102116

103-
value = cls._to_db_value(value)
117+
if not isinstance(value, cls):
118+
value = cls(value)
104119

105120
if dim is not None and value.dim() != dim:
106121
raise ValueError('expected %d dimensions, not %d' % (dim, value.dim()))
@@ -112,19 +127,11 @@ def _to_db_binary(cls, value):
112127
if value is None:
113128
return value
114129

115-
value = cls._to_db_value(value)
130+
if not isinstance(value, cls):
131+
value = cls(value)
116132

117133
return value.to_binary()
118134

119-
@classmethod
120-
def _to_db_value(cls, value):
121-
if isinstance(value, cls):
122-
return value
123-
elif isinstance(value, (list, np.ndarray)):
124-
return cls.from_dense(value)
125-
else:
126-
raise ValueError('expected sparsevec')
127-
128135
@classmethod
129136
def _from_db(cls, value):
130137
if value is None or isinstance(value, cls):

tests/test_asyncpg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ async def test_sparsevec(self):
8282

8383
await register_vector(conn)
8484

85-
embedding = SparseVector.from_dense([1.5, 2, 3])
85+
embedding = SparseVector([1.5, 2, 3])
8686
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
8787

8888
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")

tests/test_django.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ class Migration(migrations.Migration):
8888

8989

9090
def create_items():
91-
Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector.from_dense([1, 1, 1])).save()
92-
Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector.from_dense([2, 2, 2])).save()
93-
Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector.from_dense([1, 1, 2])).save()
91+
Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector([1, 1, 1])).save()
92+
Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector([2, 2, 2])).save()
93+
Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector([1, 1, 2])).save()
9494

9595

9696
class VectorForm(ModelForm):
@@ -208,34 +208,34 @@ def test_bit_jaccard_distance(self):
208208
# assert [v.distance for v in items] == [0, 1/3, 1]
209209

210210
def test_sparsevec(self):
211-
Item(id=1, sparse_embedding=SparseVector.from_dense([1, 2, 3])).save()
211+
Item(id=1, sparse_embedding=SparseVector([1, 2, 3])).save()
212212
item = Item.objects.get(pk=1)
213213
assert item.sparse_embedding.to_list() == [1, 2, 3]
214214

215215
def test_sparsevec_l2_distance(self):
216216
create_items()
217-
distance = L2Distance('sparse_embedding', SparseVector.from_dense([1, 1, 1]))
217+
distance = L2Distance('sparse_embedding', SparseVector([1, 1, 1]))
218218
items = Item.objects.annotate(distance=distance).order_by(distance)
219219
assert [v.id for v in items] == [1, 3, 2]
220220
assert [v.distance for v in items] == [0, 1, sqrt(3)]
221221

222222
def test_sparsevec_max_inner_product(self):
223223
create_items()
224-
distance = MaxInnerProduct('sparse_embedding', SparseVector.from_dense([1, 1, 1]))
224+
distance = MaxInnerProduct('sparse_embedding', SparseVector([1, 1, 1]))
225225
items = Item.objects.annotate(distance=distance).order_by(distance)
226226
assert [v.id for v in items] == [2, 3, 1]
227227
assert [v.distance for v in items] == [-6, -4, -3]
228228

229229
def test_sparsevec_cosine_distance(self):
230230
create_items()
231-
distance = CosineDistance('sparse_embedding', SparseVector.from_dense([1, 1, 1]))
231+
distance = CosineDistance('sparse_embedding', SparseVector([1, 1, 1]))
232232
items = Item.objects.annotate(distance=distance).order_by(distance)
233233
assert [v.id for v in items] == [1, 2, 3]
234234
assert [v.distance for v in items] == [0, 0, 0.05719095841793653]
235235

236236
def test_sparsevec_l1_distance(self):
237237
create_items()
238-
distance = L1Distance('sparse_embedding', SparseVector.from_dense([1, 1, 1]))
238+
distance = L1Distance('sparse_embedding', SparseVector([1, 1, 1]))
239239
items = Item.objects.annotate(distance=distance).order_by(distance)
240240
assert [v.id for v in items] == [1, 3, 2]
241241
assert [v.distance for v in items] == [0, 1, 3]
@@ -402,7 +402,7 @@ def test_sparesevec_form_save_missing(self):
402402
assert Item.objects.get(pk=1).sparse_embedding is None
403403

404404
def test_clean(self):
405-
item = Item(id=1, embedding=[1, 2, 3], half_embedding=[1, 2, 3], binary_embedding='101', sparse_embedding=SparseVector.from_dense([1, 2, 3]))
405+
item = Item(id=1, embedding=[1, 2, 3], half_embedding=[1, 2, 3], binary_embedding='101', sparse_embedding=SparseVector([1, 2, 3]))
406406
item.full_clean()
407407

408408
def test_get_or_create(self):

tests/test_peewee.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ class Meta:
3030

3131

3232
def create_items():
33-
Item.create(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector.from_dense([1, 1, 1]))
34-
Item.create(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector.from_dense([2, 2, 2]))
35-
Item.create(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector.from_dense([1, 1, 2]))
33+
Item.create(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector([1, 1, 1]))
34+
Item.create(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector([2, 2, 2]))
35+
Item.create(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector([1, 1, 2]))
3636

3737

3838
class TestPeewee:
@@ -132,7 +132,7 @@ def test_sparsevec(self):
132132

133133
def test_sparsevec_l2_distance(self):
134134
create_items()
135-
distance = Item.sparse_embedding.l2_distance(SparseVector.from_dense([1, 1, 1]))
135+
distance = Item.sparse_embedding.l2_distance(SparseVector([1, 1, 1]))
136136
items = Item.select(Item.id, distance.alias('distance')).order_by(distance).limit(5)
137137
assert [v.id for v in items] == [1, 3, 2]
138138
assert [v.distance for v in items] == [0, 1, sqrt(3)]

tests/test_psycopg.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,20 @@ def test_bit_text_format(self):
100100
assert repr(Bit(res)) == 'Bit(010100001)'
101101

102102
def test_sparsevec(self):
103-
embedding = SparseVector.from_dense([1.5, 2, 3])
103+
embedding = SparseVector([1.5, 2, 3])
104104
conn.execute('INSERT INTO psycopg_items (sparse_embedding) VALUES (%s)', (embedding,))
105105

106106
res = conn.execute('SELECT sparse_embedding FROM psycopg_items ORDER BY id').fetchone()[0]
107107
assert res.to_list() == [1.5, 2, 3]
108108

109109
def test_sparsevec_binary_format(self):
110-
embedding = SparseVector.from_dense([1.5, 0, 2, 0, 3, 0])
110+
embedding = SparseVector([1.5, 0, 2, 0, 3, 0])
111111
res = conn.execute('SELECT %b::sparsevec', (embedding,), binary=True).fetchone()[0]
112112
assert res.to_list() == [1.5, 0, 2, 0, 3, 0]
113113
assert np.array_equal(res.to_numpy(), np.array([1.5, 0, 2, 0, 3, 0]))
114114

115115
def test_sparsevec_text_format(self):
116-
embedding = SparseVector.from_dense([1.5, 0, 2, 0, 3, 0])
116+
embedding = SparseVector([1.5, 0, 2, 0, 3, 0])
117117
res = conn.execute('SELECT %t::sparsevec', (embedding,)).fetchone()[0]
118118
assert res.to_list() == [1.5, 0, 2, 0, 3, 0]
119119
assert np.array_equal(res.to_numpy(), np.array([1.5, 0, 2, 0, 3, 0]))
@@ -122,20 +122,20 @@ def test_text_copy(self):
122122
embedding = np.array([1.5, 2, 3])
123123
cur = conn.cursor()
124124
with cur.copy("COPY psycopg_items (embedding, half_embedding, binary_embedding, sparse_embedding) FROM STDIN") as copy:
125-
copy.write_row([embedding, HalfVector(embedding), '101', SparseVector.from_dense(embedding)])
125+
copy.write_row([embedding, HalfVector(embedding), '101', SparseVector(embedding)])
126126

127127
def test_binary_copy(self):
128128
embedding = np.array([1.5, 2, 3])
129129
cur = conn.cursor()
130130
with cur.copy("COPY psycopg_items (embedding, half_embedding, binary_embedding, sparse_embedding) FROM STDIN WITH (FORMAT BINARY)") as copy:
131-
copy.write_row([embedding, HalfVector(embedding), Bit('101'), SparseVector.from_dense(embedding)])
131+
copy.write_row([embedding, HalfVector(embedding), Bit('101'), SparseVector(embedding)])
132132

133133
def test_binary_copy_set_types(self):
134134
embedding = np.array([1.5, 2, 3])
135135
cur = conn.cursor()
136136
with cur.copy("COPY psycopg_items (id, embedding, half_embedding, binary_embedding, sparse_embedding) FROM STDIN WITH (FORMAT BINARY)") as copy:
137137
copy.set_types(['int8', 'vector', 'halfvec', 'bit', 'sparsevec'])
138-
copy.write_row([1, embedding, HalfVector(embedding), Bit('101'), SparseVector.from_dense(embedding)])
138+
copy.write_row([1, embedding, HalfVector(embedding), Bit('101'), SparseVector(embedding)])
139139

140140
@pytest.mark.asyncio
141141
async def test_async(self):

tests/test_psycopg2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_bit(self):
4646
assert res[1][0] is None
4747

4848
def test_sparsevec(self):
49-
embedding = SparseVector.from_dense([1.5, 2, 3])
49+
embedding = SparseVector([1.5, 2, 3])
5050
cur.execute('INSERT INTO psycopg2_items (sparse_embedding) VALUES (%s), (NULL)', (embedding,))
5151

5252
cur.execute('SELECT sparse_embedding FROM psycopg2_items ORDER BY id')

tests/test_sparse_vector.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,42 @@
66

77
class TestSparseVector:
88
def test_from_dense(self):
9-
assert SparseVector.from_dense([1, 0, 2, 0, 3, 0]).to_list() == [1, 0, 2, 0, 3, 0]
10-
assert SparseVector.from_dense([1, 0, 2, 0, 3, 0]).to_numpy().tolist() == [1, 0, 2, 0, 3, 0]
11-
assert SparseVector.from_dense(np.array([1, 0, 2, 0, 3, 0])).to_list() == [1, 0, 2, 0, 3, 0]
9+
assert SparseVector([1, 0, 2, 0, 3, 0]).to_list() == [1, 0, 2, 0, 3, 0]
10+
assert SparseVector([1, 0, 2, 0, 3, 0]).to_numpy().tolist() == [1, 0, 2, 0, 3, 0]
11+
assert SparseVector(np.array([1, 0, 2, 0, 3, 0])).to_list() == [1, 0, 2, 0, 3, 0]
12+
13+
def test_from_dense_dimensions(self):
14+
with pytest.raises(ValueError) as error:
15+
SparseVector([1, 0, 2, 0, 3, 0], 6)
16+
assert str(error.value) == 'dimensions not allowed'
1217

1318
def test_from_dict(self):
14-
assert SparseVector.from_dict({0: 1, 2: 2, 4: 3}, 6).to_list() == [1, 0, 2, 0, 3, 0]
19+
assert SparseVector({0: 1, 2: 2, 4: 3}, 6).to_list() == [1, 0, 2, 0, 3, 0]
20+
21+
def test_from_dict_no_dimensions(self):
22+
with pytest.raises(ValueError) as error:
23+
SparseVector({0: 1, 2: 2, 4: 3})
24+
assert str(error.value) == 'dimensions required'
1525

1626
def test_from_sparse(self):
1727
arr = coo_array(np.array([1, 0, 2, 0, 3, 0]))
18-
assert SparseVector.from_sparse(arr).to_list() == [1, 0, 2, 0, 3, 0]
19-
assert SparseVector.from_sparse(arr.todok()).to_list() == [1, 0, 2, 0, 3, 0]
28+
assert SparseVector(arr).to_list() == [1, 0, 2, 0, 3, 0]
29+
assert SparseVector(arr.todok()).to_list() == [1, 0, 2, 0, 3, 0]
30+
31+
def test_from_sparse_dimensions(self):
32+
with pytest.raises(ValueError) as error:
33+
SparseVector(coo_array(np.array([1, 0, 2, 0, 3, 0])), 6)
34+
assert str(error.value) == 'dimensions not allowed'
2035

2136
def test_repr(self):
22-
assert repr(SparseVector.from_dense([1, 0, 2, 0, 3, 0])) == 'SparseVector(6, [0, 2, 4], [1.0, 2.0, 3.0])'
23-
assert str(SparseVector.from_dense([1, 0, 2, 0, 3, 0])) == 'SparseVector(6, [0, 2, 4], [1.0, 2.0, 3.0])'
37+
assert repr(SparseVector([1, 0, 2, 0, 3, 0])) == 'SparseVector({0: 1.0, 2: 2.0, 4: 3.0}, 6)'
38+
assert str(SparseVector([1, 0, 2, 0, 3, 0])) == 'SparseVector({0: 1.0, 2: 2.0, 4: 3.0}, 6)'
2439

2540
def test_dim(self):
26-
assert SparseVector.from_dense([1, 0, 2, 0, 3, 0]).dim() == 6
41+
assert SparseVector([1, 0, 2, 0, 3, 0]).dim() == 6
2742

2843
def test_to_dict(self):
29-
assert SparseVector.from_dense([1, 0, 2, 0, 3, 0]).to_dict() == {0: 1, 2: 2, 4: 3}
44+
assert SparseVector([1, 0, 2, 0, 3, 0]).to_dict() == {0: 1, 2: 2, 4: 3}
3045

3146
def test_to_coo(self):
32-
assert SparseVector.from_dense([1, 0, 2, 0, 3, 0]).to_coo().toarray().tolist() == [[1, 0, 2, 0, 3, 0]]
47+
assert SparseVector([1, 0, 2, 0, 3, 0]).to_coo().toarray().tolist() == [[1, 0, 2, 0, 3, 0]]

tests/test_sqlalchemy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ class Item(Base):
4141

4242
def create_items():
4343
session = Session(engine)
44-
session.add(Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector.from_dense([1, 1, 1])))
45-
session.add(Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector.from_dense([2, 2, 2])))
46-
session.add(Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector.from_dense([1, 1, 2])))
44+
session.add(Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector([1, 1, 1])))
45+
session.add(Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector([2, 2, 2])))
46+
session.add(Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector([1, 1, 2])))
4747
session.commit()
4848

4949

tests/test_sqlmodel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ class Item(SQLModel, table=True):
3737

3838
def create_items():
3939
session = Session(engine)
40-
session.add(Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector.from_dense([1, 1, 1])))
41-
session.add(Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector.from_dense([2, 2, 2])))
42-
session.add(Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector.from_dense([1, 1, 2])))
40+
session.add(Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector([1, 1, 1])))
41+
session.add(Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector([2, 2, 2])))
42+
session.add(Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector([1, 1, 2])))
4343
session.commit()
4444

4545

0 commit comments

Comments
 (0)