Skip to content

Commit 62c69fc

Browse files
L1 distance support for sqlalchemy and sqlmodel (#69)
1 parent 3c51f39 commit 62c69fc

File tree

3 files changed

+21
-0
lines changed

3 files changed

+21
-0
lines changed

pgvector/sqlalchemy/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ class comparator_factory(UserDefinedType.Comparator):
3939
def l2_distance(self, other):
4040
return self.op('<->', return_type=Float)(other)
4141

42+
def l1_distance(self, other):
43+
return self.op('<+>', return_type=Float)(other)
44+
4245
def max_inner_product(self, other):
4346
return self.op('<#>', return_type=Float)(other)
4447

tests/test_sqlalchemy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,18 @@ def test_l2_distance_orm(self):
119119
items = session.scalars(select(Item).order_by(Item.embedding.l2_distance([1, 1, 1])))
120120
assert [v.id for v in items] == [1, 3, 2]
121121

122+
def test_l1_distance(self):
123+
create_items()
124+
with Session(engine) as session:
125+
items = session.query(Item).order_by(Item.embedding.l1_distance([1, 1, 2])).all()
126+
assert [v.id for v in items] == [3, 1, 2]
127+
128+
def test_l1_distance_orm(self):
129+
create_items()
130+
with Session(engine) as session:
131+
items = session.scalars(select(Item).order_by(Item.embedding.l1_distance([1, 1, 2])))
132+
assert [v.id for v in items] == [3, 1, 2]
133+
122134
def test_max_inner_product(self):
123135
create_items()
124136
with Session(engine) as session:

tests/test_sqlmodel.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ def test_l2_distance(self):
7979
items = session.exec(select(Item).order_by(Item.embedding.l2_distance([1, 1, 1])))
8080
assert [v.id for v in items] == [1, 3, 2]
8181

82+
def test_l1_distance(self):
83+
create_items()
84+
with Session(engine) as session:
85+
items = session.exec(select(Item).order_by(Item.embedding.l1_distance([1, 1, 1])))
86+
assert [v.id for v in items] == [1, 3, 2]
87+
8288
def test_max_inner_product(self):
8389
create_items()
8490
with Session(engine) as session:

0 commit comments

Comments
 (0)