Skip to content

Commit b43b58f

Browse files
ankane"domenico.cinque"
andcommitted
Use classmethod decorator - closes pgvector#72
Co-authored-by: "domenico.cinque" <domenico.cinque@immobiliare.it>
1 parent f1ce5f3 commit b43b58f

4 files changed

Lines changed: 60 additions & 43 deletions

File tree

pgvector/utils/bit.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,16 @@ def from_binary(value):
3737
buf = np.frombuffer(value, dtype=np.uint8, offset=4)
3838
return Bit(np.unpackbits(buf, count=count).astype(bool))
3939

40-
def _to_db(value):
41-
if not isinstance(value, Bit):
40+
@classmethod
41+
def _to_db(cls, value):
42+
if not isinstance(value, cls):
4243
raise ValueError('expected bit')
4344

4445
return value.to_text()
4546

46-
def _to_db_binary(value):
47-
if not isinstance(value, Bit):
47+
@classmethod
48+
def _to_db_binary(cls, value):
49+
if not isinstance(value, cls):
4850
raise ValueError('expected bit')
4951

5052
return value.to_binary()

pgvector/utils/halfvec.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,42 +31,48 @@ def to_text(self):
3131
def to_binary(self):
3232
return pack('>HH', self.dim(), 0) + self._value.tobytes()
3333

34-
def from_text(value):
35-
return HalfVector([float(v) for v in value[1:-1].split(',')])
34+
@classmethod
35+
def from_text(cls, value):
36+
return cls([float(v) for v in value[1:-1].split(',')])
3637

37-
def from_binary(value):
38+
@classmethod
39+
def from_binary(cls, value):
3840
dim, unused = unpack_from('>HH', value)
39-
return HalfVector(np.frombuffer(value, dtype='>f2', count=dim, offset=4))
41+
return cls(np.frombuffer(value, dtype='>f2', count=dim, offset=4))
4042

41-
def _to_db(value, dim=None):
43+
@classmethod
44+
def _to_db(cls, value, dim=None):
4245
if value is None:
4346
return value
4447

45-
if not isinstance(value, HalfVector):
46-
value = HalfVector(value)
48+
if not isinstance(value, cls):
49+
value = cls(value)
4750

4851
if dim is not None and value.dim() != dim:
4952
raise ValueError('expected %d dimensions, not %d' % (dim, value.dim()))
5053

5154
return value.to_text()
5255

53-
def _to_db_binary(value):
56+
@classmethod
57+
def _to_db_binary(cls, value):
5458
if value is None:
5559
return value
5660

57-
if not isinstance(value, HalfVector):
58-
value = HalfVector(value)
61+
if not isinstance(value, cls):
62+
value = cls(value)
5963

6064
return value.to_binary()
6165

62-
def _from_db(value):
63-
if value is None or isinstance(value, HalfVector):
66+
@classmethod
67+
def _from_db(cls, value):
68+
if value is None or isinstance(value, cls):
6469
return value
6570

66-
return HalfVector.from_text(value)
71+
return cls.from_text(value)
6772

68-
def _from_db_binary(value):
69-
if value is None or isinstance(value, HalfVector):
73+
@classmethod
74+
def _from_db_binary(cls, value):
75+
if value is None or isinstance(value, cls):
7076
return value
7177

72-
return HalfVector.from_binary(value)
78+
return cls.from_binary(value)

pgvector/utils/sparsevec.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,41 +56,46 @@ def from_binary(value):
5656
values = unpack_from(f'>{nnz}f', value, 12 + nnz * 4)
5757
return SparseVector(int(dim), indices, values)
5858

59-
def _to_db(value, dim=None):
59+
@classmethod
60+
def _to_db(cls, value, dim=None):
6061
if value is None:
6162
return value
6263

63-
value = __class__._to_db_value(value)
64+
value = cls._to_db_value(value)
6465

6566
if dim is not None and value.dim() != dim:
6667
raise ValueError('expected %d dimensions, not %d' % (dim, value.dim()))
6768

6869
return value.to_text()
6970

70-
def _to_db_binary(value):
71+
@classmethod
72+
def _to_db_binary(cls, value):
7173
if value is None:
7274
return value
7375

74-
value = __class__._to_db_value(value)
76+
value = cls._to_db_value(value)
7577

7678
return value.to_binary()
7779

78-
def _to_db_value(value):
79-
if isinstance(value, SparseVector):
80+
@classmethod
81+
def _to_db_value(cls, value):
82+
if isinstance(value, cls):
8083
return value
8184
elif isinstance(value, (list, np.ndarray)):
82-
return SparseVector.from_dense(value)
85+
return cls.from_dense(value)
8386
else:
8487
raise ValueError('expected sparsevec')
8588

86-
def _from_db(value):
87-
if value is None or isinstance(value, SparseVector):
89+
@classmethod
90+
def _from_db(cls, value):
91+
if value is None or isinstance(value, cls):
8892
return value
8993

90-
return SparseVector.from_text(value)
94+
return cls.from_text(value)
9195

92-
def _from_db_binary(value):
93-
if value is None or isinstance(value, SparseVector):
96+
@classmethod
97+
def _from_db_binary(cls, value):
98+
if value is None or isinstance(value, cls):
9499
return value
95100

96-
return SparseVector.from_binary(value)
101+
return cls.from_binary(value)

pgvector/utils/vector.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,35 +38,39 @@ def from_binary(value):
3838
dim, unused = unpack_from('>HH', value)
3939
return Vector(np.frombuffer(value, dtype='>f4', count=dim, offset=4))
4040

41-
def _to_db(value, dim=None):
41+
@classmethod
42+
def _to_db(cls, value, dim=None):
4243
if value is None:
4344
return value
4445

45-
if not isinstance(value, Vector):
46-
value = Vector(value)
46+
if not isinstance(value, cls):
47+
value = cls(value)
4748

4849
if dim is not None and value.dim() != dim:
4950
raise ValueError('expected %d dimensions, not %d' % (dim, value.dim()))
5051

5152
return value.to_text()
5253

53-
def _to_db_binary(value):
54+
@classmethod
55+
def _to_db_binary(cls, value):
5456
if value is None:
5557
return value
5658

57-
if not isinstance(value, Vector):
58-
value = Vector(value)
59+
if not isinstance(value, cls):
60+
value = cls(value)
5961

6062
return value.to_binary()
6163

62-
def _from_db(value):
64+
@classmethod
65+
def _from_db(cls, value):
6366
if value is None or isinstance(value, np.ndarray):
6467
return value
6568

66-
return Vector.from_text(value).to_numpy().astype(np.float32)
69+
return cls.from_text(value).to_numpy().astype(np.float32)
6770

68-
def _from_db_binary(value):
71+
@classmethod
72+
def _from_db_binary(cls, value):
6973
if value is None or isinstance(value, np.ndarray):
7074
return value
7175

72-
return Vector.from_binary(value).to_numpy().astype(np.float32)
76+
return cls.from_binary(value).to_numpy().astype(np.float32)

0 commit comments

Comments
 (0)