|
1 | | -from sqlalchemy.dialects.postgresql.base import ischema_names |
2 | | -from sqlalchemy.types import UserDefinedType, Float, String |
3 | | -from ..utils import from_db, to_db |
| 1 | +from .vector import Vector |
4 | 2 |
|
5 | 3 | __all__ = ['Vector'] |
6 | | - |
7 | | - |
8 | | -class Vector(UserDefinedType): |
9 | | - cache_ok = True |
10 | | - _string = String() |
11 | | - |
12 | | - def __init__(self, dim=None): |
13 | | - super(UserDefinedType, self).__init__() |
14 | | - self.dim = dim |
15 | | - |
16 | | - def get_col_spec(self, **kw): |
17 | | - if self.dim is None: |
18 | | - return 'VECTOR' |
19 | | - return 'VECTOR(%d)' % self.dim |
20 | | - |
21 | | - def bind_processor(self, dialect): |
22 | | - def process(value): |
23 | | - return to_db(value, self.dim) |
24 | | - return process |
25 | | - |
26 | | - def literal_processor(self, dialect): |
27 | | - string_literal_processor = self._string._cached_literal_processor(dialect) |
28 | | - |
29 | | - def process(value): |
30 | | - return string_literal_processor(to_db(value, self.dim)) |
31 | | - return process |
32 | | - |
33 | | - def result_processor(self, dialect, coltype): |
34 | | - def process(value): |
35 | | - return from_db(value) |
36 | | - return process |
37 | | - |
38 | | - class comparator_factory(UserDefinedType.Comparator): |
39 | | - def l2_distance(self, other): |
40 | | - return self.op('<->', return_type=Float)(other) |
41 | | - |
42 | | - def max_inner_product(self, other): |
43 | | - return self.op('<#>', return_type=Float)(other) |
44 | | - |
45 | | - def cosine_distance(self, other): |
46 | | - return self.op('<=>', return_type=Float)(other) |
47 | | - |
48 | | - def l1_distance(self, other): |
49 | | - return self.op('<+>', return_type=Float)(other) |
50 | | - |
51 | | - |
52 | | -# for reflection |
53 | | -ischema_names['vector'] = Vector |
0 commit comments