Skip to content

Commit b1d2cdb

Browse files
committed
Improved SparseVector code [skip ci]
1 parent 9f5f4eb commit b1d2cdb

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

pgvector/utils/sparsevec.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ def to_db_value(value):
1313

1414
class SparseVector:
1515
def __init__(self, dim, indices, values):
16-
self._dim = dim
17-
self._indices = indices
18-
self._values = values
16+
# TODO improve
17+
self._dim = int(dim)
18+
self._indices = [int(i) for i in indices]
19+
self._values = [float(v) for v in values]
1920

2021
def __repr__(self):
2122
return f'SparseVector({self._dim}, {self._indices}, {self._values})'
@@ -25,7 +26,7 @@ def from_dense(value):
2526
value = value.tolist()
2627
dim = len(value)
2728
indices = [i for i, v in enumerate(value) if v != 0]
28-
values = [value[i] for i in indices]
29+
values = [float(value[i]) for i in indices]
2930
return SparseVector(dim, indices, values)
3031

3132
def dim(self):

tests/test_sparse_vector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@ def test_from_dense(self):
99
assert SparseVector.from_dense([1, 2, 3]).to_numpy().tolist() == [1, 2, 3]
1010

1111
def test_repr(self):
12-
assert repr(SparseVector.from_dense([1, 2, 3])) == 'SparseVector(3, [0, 1, 2], [1, 2, 3])'
13-
assert str(SparseVector.from_dense([1, 2, 3])) == 'SparseVector(3, [0, 1, 2], [1, 2, 3])'
12+
assert repr(SparseVector.from_dense([1, 2, 3])) == 'SparseVector(3, [0, 1, 2], [1.0, 2.0, 3.0])'
13+
assert str(SparseVector.from_dense([1, 2, 3])) == 'SparseVector(3, [0, 1, 2], [1.0, 2.0, 3.0])'

0 commit comments

Comments
 (0)