Skip to content

Commit f1cf31e

Browse files
committed
Added support for sparsevec type to SQLAlchemy and SQLModel [skip ci]
1 parent b4b6c6b commit f1cf31e

File tree

5 files changed

+64
-6
lines changed

5 files changed

+64
-6
lines changed

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
## 0.2.6 (unreleased)
1+
## 0.3.0 (unreleased)
22

33
- Added support for `halfvec` and `sparsevec` types to Django
4-
- Added support for `halfvec` type to SQLAlchemy and SQLModel
4+
- Added support for `halfvec` and `sparsevec` types to SQLAlchemy and SQLModel
55
- Added support for `halfvec` and `sparsevec` types to Psycopg 3
66
- Added support for `halfvec` and `sparsevec` types to Psycopg 2
77
- Added support for `halfvec` and `sparsevec` types to asyncpg

pgvector/sqlalchemy/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from .halfvec import Halfvec
2+
from .sparsevec import Sparsevec
23
from .vector import Vector
4+
from ..utils import SparseVec
35

4-
__all__ = ['Vector', 'Halfvec']
6+
# SparseVec is imported above so callers can build values for Sparsevec
# columns; list it so that it is part of the declared public API as well.
__all__ = ['Vector', 'Halfvec', 'Sparsevec', 'SparseVec']

pgvector/sqlalchemy/sparsevec.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from sqlalchemy.dialects.postgresql.base import ischema_names
2+
from sqlalchemy.types import UserDefinedType, Float, String
3+
from ..utils import SparseVec
4+
5+
6+
class Sparsevec(UserDefinedType):
    """SQLAlchemy column type that maps to the ``sparsevec`` database type.

    Values are converted with ``SparseVec.to_db`` when sent to the database
    and with ``SparseVec.from_db`` when read back.
    """

    cache_ok = True
    # String type whose literal processor is reused to quote the text form.
    _string = String()

    def __init__(self, dim=None):
        """Create the type.

        Args:
            dim: optional number of dimensions; when set it is rendered in
                the column spec and passed to ``SparseVec.to_db``.
        """
        # Zero-argument super() walks the full MRO. The explicit
        # super(UserDefinedType, self) form starts the lookup *after*
        # UserDefinedType and would skip any initializer it defines.
        super().__init__()
        self.dim = dim

    def get_col_spec(self, **kw):
        """Return the DDL type name, with the dimension when one is set."""
        if self.dim is None:
            return 'SPARSEVEC'
        return 'SPARSEVEC(%d)' % self.dim

    def bind_processor(self, dialect):
        """Return a callable that converts bound values via ``SparseVec.to_db``."""
        def process(value):
            return SparseVec.to_db(value, self.dim)
        return process

    def literal_processor(self, dialect):
        """Return a callable that renders a value as an inline string literal."""
        string_literal_processor = self._string._cached_literal_processor(dialect)

        def process(value):
            # Serialize first, then let the String processor do the quoting.
            return string_literal_processor(SparseVec.to_db(value, self.dim))
        return process

    def result_processor(self, dialect, coltype):
        """Return a callable that converts fetched values via ``SparseVec.from_db``."""
        def process(value):
            return SparseVec.from_db(value)
        return process

    class comparator_factory(UserDefinedType.Comparator):
        """Comparison helpers; each emits a custom operator typed as ``Float``."""

        def l2_distance(self, other):
            """Build an expression using the ``<->`` operator."""
            return self.op('<->', return_type=Float)(other)

        def max_inner_product(self, other):
            """Build an expression using the ``<#>`` operator."""
            return self.op('<#>', return_type=Float)(other)

        def cosine_distance(self, other):
            """Build an expression using the ``<=>`` operator."""
            return self.op('<=>', return_type=Float)(other)

        def l1_distance(self, other):
            """Build an expression using the ``<+>`` operator."""
            return self.op('<+>', return_type=Float)(other)


# for reflection
ischema_names['sparsevec'] = Sparsevec

pgvector/utils/sparsevec.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,13 @@ def to_dense(self):
1919
vec[i] = v
2020
return vec
2121

22-
def to_db(value):
22+
def to_db(value, dim=None):
    """Serialize a sparse vector to the text form ``{index:value,...}/dim``.

    Args:
        value: object exposing ``dim``, ``indices`` and ``values``, or ``None``.
        dim: optional expected number of dimensions.

    Returns:
        The text representation, or ``None`` when ``value`` is ``None``.
        Each stored index is written as ``index + 1``.

    Raises:
        ValueError: if ``dim`` is given and differs from ``value.dim``.
    """
    if value is None:
        return value

    if dim is not None and value.dim != dim:
        # Report value.dim (the quantity compared above) rather than
        # len(value), which would require the value to implement __len__.
        raise ValueError('expected %d dimensions, not %d' % (dim, value.dim))

    return '{' + ','.join([f'{i + 1}:{v}' for i, v in zip(value.indices, value.values)]) + '}/' + str(value.dim)
2630

2731
def to_db_binary(value):

tests/test_sqlalchemy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from pgvector.sqlalchemy import Vector, Halfvec
2+
from pgvector.sqlalchemy import Vector, Halfvec, Sparsevec, SparseVec
33
import pytest
44
from sqlalchemy import create_engine, insert, inspect, select, text, MetaData, Table, Column, Index, Integer
55
from sqlalchemy.exc import StatementError
@@ -21,6 +21,7 @@ class Item(Base):
2121
id = mapped_column(Integer, primary_key=True)
2222
embedding = mapped_column(Vector(3))
2323
half_embedding = mapped_column(Halfvec(3))
24+
sparse_embedding = mapped_column(Sparsevec(3))
2425

2526

2627
Base.metadata.drop_all(engine)
@@ -44,7 +45,7 @@ def create_items():
4445
]
4546
session = Session(engine)
4647
for i, v in enumerate(vectors):
47-
session.add(Item(id=i + 1, embedding=v, half_embedding=v))
48+
session.add(Item(id=i + 1, embedding=v, half_embedding=v, sparse_embedding=SparseVec.from_dense(v)))
4849
session.commit()
4950

5051

0 commit comments

Comments
 (0)