Skip to content

Commit afb416f

Browse files
committed
Moved Vector to separate file for SQLAlchemy [skip ci]
1 parent bd9d7d6 commit afb416f

File tree

2 files changed

+52
-51
lines changed

2 files changed

+52
-51
lines changed

pgvector/sqlalchemy/__init__.py

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,3 @@
1-
from sqlalchemy.dialects.postgresql.base import ischema_names
2-
from sqlalchemy.types import UserDefinedType, Float, String
3-
from ..utils import from_db, to_db
1+
from .vector import Vector
42

53
__all__ = ['Vector']
6-
7-
8-
class Vector(UserDefinedType):
9-
cache_ok = True
10-
_string = String()
11-
12-
def __init__(self, dim=None):
13-
super(UserDefinedType, self).__init__()
14-
self.dim = dim
15-
16-
def get_col_spec(self, **kw):
17-
if self.dim is None:
18-
return 'VECTOR'
19-
return 'VECTOR(%d)' % self.dim
20-
21-
def bind_processor(self, dialect):
22-
def process(value):
23-
return to_db(value, self.dim)
24-
return process
25-
26-
def literal_processor(self, dialect):
27-
string_literal_processor = self._string._cached_literal_processor(dialect)
28-
29-
def process(value):
30-
return string_literal_processor(to_db(value, self.dim))
31-
return process
32-
33-
def result_processor(self, dialect, coltype):
34-
def process(value):
35-
return from_db(value)
36-
return process
37-
38-
class comparator_factory(UserDefinedType.Comparator):
39-
def l2_distance(self, other):
40-
return self.op('<->', return_type=Float)(other)
41-
42-
def max_inner_product(self, other):
43-
return self.op('<#>', return_type=Float)(other)
44-
45-
def cosine_distance(self, other):
46-
return self.op('<=>', return_type=Float)(other)
47-
48-
def l1_distance(self, other):
49-
return self.op('<+>', return_type=Float)(other)
50-
51-
52-
# for reflection
53-
ischema_names['vector'] = Vector

pgvector/sqlalchemy/vector.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from sqlalchemy.dialects.postgresql.base import ischema_names
2+
from sqlalchemy.types import UserDefinedType, Float, String
3+
from ..utils import from_db, to_db
4+
5+
6+
class Vector(UserDefinedType):
7+
cache_ok = True
8+
_string = String()
9+
10+
def __init__(self, dim=None):
11+
super(UserDefinedType, self).__init__()
12+
self.dim = dim
13+
14+
def get_col_spec(self, **kw):
15+
if self.dim is None:
16+
return 'VECTOR'
17+
return 'VECTOR(%d)' % self.dim
18+
19+
def bind_processor(self, dialect):
20+
def process(value):
21+
return to_db(value, self.dim)
22+
return process
23+
24+
def literal_processor(self, dialect):
25+
string_literal_processor = self._string._cached_literal_processor(dialect)
26+
27+
def process(value):
28+
return string_literal_processor(to_db(value, self.dim))
29+
return process
30+
31+
def result_processor(self, dialect, coltype):
32+
def process(value):
33+
return from_db(value)
34+
return process
35+
36+
class comparator_factory(UserDefinedType.Comparator):
37+
def l2_distance(self, other):
38+
return self.op('<->', return_type=Float)(other)
39+
40+
def max_inner_product(self, other):
41+
return self.op('<#>', return_type=Float)(other)
42+
43+
def cosine_distance(self, other):
44+
return self.op('<=>', return_type=Float)(other)
45+
46+
def l1_distance(self, other):
47+
return self.op('<+>', return_type=Float)(other)
48+
49+
50+
# for reflection
51+
ischema_names['vector'] = Vector

0 commit comments

Comments
 (0)