Skip to content

Commit b4b6c6b

Browse files
committed
Added support for halfvec type to SQLAlchemy and SQLModel [skip ci]
1 parent 70c1f58 commit b4b6c6b

File tree

5 files changed

+62
-4
lines changed

5 files changed

+62
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## 0.2.6 (unreleased)
22

33
- Added support for `halfvec` and `sparsevec` types to Django
4+
- Added support for `halfvec` type to SQLAlchemy and SQLModel
45
- Added support for `halfvec` and `sparsevec` types to Psycopg 3
56
- Added support for `halfvec` and `sparsevec` types to Psycopg 2
67
- Added support for `halfvec` and `sparsevec` types to asyncpg

pgvector/sqlalchemy/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .halfvec import Halfvec
12
from .vector import Vector
23

3-
__all__ = ['Vector']
4+
__all__ = ['Vector', 'Halfvec']

pgvector/sqlalchemy/halfvec.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from sqlalchemy.dialects.postgresql.base import ischema_names
2+
from sqlalchemy.types import UserDefinedType, Float, String
3+
from ..utils import HalfVec
4+
5+
6+
class Halfvec(UserDefinedType):
    """SQLAlchemy column type that renders as ``HALFVEC`` / ``HALFVEC(dim)``.

    Values are converted to their database text form with ``HalfVec.to_db``
    and converted back with ``HalfVec.from_db``.

    :param dim: optional number of dimensions. When given, it is included in
        the column spec and forwarded to ``HalfVec.to_db`` for every bound or
        literal value.
    """

    cache_ok = True
    # Shared String type, used only to obtain the dialect's string literal
    # processor in literal_processor().
    _string = String()

    def __init__(self, dim=None):
        # Zero-argument super() walks the full MRO starting at the parent
        # class; the previous super(UserDefinedType, self) form skipped
        # UserDefinedType itself.
        super().__init__()
        self.dim = dim

    def get_col_spec(self, **kw):
        """Return the DDL type name, with the dimension when one was given."""
        if self.dim is None:
            return 'HALFVEC'
        return 'HALFVEC(%d)' % self.dim

    def bind_processor(self, dialect):
        """Return a callable converting a bound value via ``HalfVec.to_db``."""
        def process(value):
            return HalfVec.to_db(value, self.dim)
        return process

    def literal_processor(self, dialect):
        """Return a callable rendering a value as an inline string literal.

        The value is first converted with ``HalfVec.to_db`` and the result is
        then passed through the dialect's string literal processor.
        """
        string_literal_processor = self._string._cached_literal_processor(dialect)

        def process(value):
            return string_literal_processor(HalfVec.to_db(value, self.dim))
        return process

    def result_processor(self, dialect, coltype):
        """Return a callable converting a fetched value via ``HalfVec.from_db``."""
        def process(value):
            return HalfVec.from_db(value)
        return process

    class comparator_factory(UserDefinedType.Comparator):
        """Distance operators exposed on ``Halfvec`` column expressions.

        Every method builds a binary operator expression whose return type
        is ``Float``.
        """

        def l2_distance(self, other):
            """Build a ``<->`` expression against ``other``."""
            return self.op('<->', return_type=Float)(other)

        def max_inner_product(self, other):
            """Build a ``<#>`` expression against ``other``."""
            return self.op('<#>', return_type=Float)(other)

        def cosine_distance(self, other):
            """Build a ``<=>`` expression against ``other``."""
            return self.op('<=>', return_type=Float)(other)

        def l1_distance(self, other):
            """Build a ``<+>`` expression against ``other``."""
            return self.op('<+>', return_type=Float)(other)
48+
49+
50+
# for reflection
51+
ischema_names['halfvec'] = Halfvec

pgvector/utils/halfvec.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,15 @@ def __init__(self, value):
1212
def to_list(self):
1313
return list(self.value)
1414

15-
def to_db(value):
15+
def to_db(value, dim=None):
    """Serialize a vector to its bracketed, comma-separated text form.

    ``None`` is returned unchanged. A ``HalfVec`` is unwrapped to its
    ``value`` attribute first. When ``dim`` is given and the number of
    elements differs from it, ``ValueError`` is raised.
    """
    if value is None:
        return None

    items = value.value if isinstance(value, HalfVec) else value

    if dim is not None:
        count = len(items)
        if count != dim:
            raise ValueError('expected %d dimensions, not %d' % (dim, count))

    # Each element is coerced to float so ints render as e.g. "1.0".
    return '[%s]' % ','.join(str(float(item)) for item in items)
2125

2226
def to_db_binary(value):

tests/test_sqlalchemy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from pgvector.sqlalchemy import Vector
2+
from pgvector.sqlalchemy import Vector, Halfvec
33
import pytest
44
from sqlalchemy import create_engine, insert, inspect, select, text, MetaData, Table, Column, Index, Integer
55
from sqlalchemy.exc import StatementError
@@ -20,6 +20,7 @@ class Item(Base):
2020

2121
id = mapped_column(Integer, primary_key=True)
2222
embedding = mapped_column(Vector(3))
23+
half_embedding = mapped_column(Halfvec(3))
2324

2425

2526
Base.metadata.drop_all(engine)
@@ -43,7 +44,7 @@ def create_items():
4344
]
4445
session = Session(engine)
4546
for i, v in enumerate(vectors):
46-
session.add(Item(id=i + 1, embedding=v))
47+
session.add(Item(id=i + 1, embedding=v, half_embedding=v))
4748
session.commit()
4849

4950

0 commit comments

Comments
 (0)