Skip to content

Commit 446859c

Browse files
committed
Added tests for aggregates with Peewee
1 parent caed7d6 commit 446859c

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

tests/test_peewee.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from math import sqrt
22
import numpy as np
3-
from peewee import Model, PostgresqlDatabase
3+
from peewee import Model, PostgresqlDatabase, fn
44
from pgvector.peewee import VectorField
55

66
db = PostgresqlDatabase('pgvector_python_test')
@@ -69,6 +69,22 @@ def test_where(self):
6969
items = Item.select().where(Item.embedding.l2_distance([1, 1, 1]) < 1)
7070
assert [v.id for v in items] == [1]
7171

72+
def test_avg(self):
73+
avg = Item.select(fn.avg(Item.embedding)).scalar()
74+
assert avg is None
75+
Item.create(embedding=[1, 2, 3])
76+
Item.create(embedding=[4, 5, 6])
77+
avg = Item.select(fn.avg(Item.embedding)).scalar()
78+
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
79+
80+
def test_sum(self):
81+
sum = Item.select(fn.sum(Item.embedding)).scalar()
82+
assert sum is None
83+
Item.create(embedding=[1, 2, 3])
84+
Item.create(embedding=[4, 5, 6])
85+
sum = Item.select(fn.sum(Item.embedding)).scalar()
86+
assert np.array_equal(sum, np.array([5, 7, 9]))
87+
7288
def test_get_or_create(self):
7389
Item.get_or_create(id=1, defaults={'embedding': [1, 2, 3]})
7490
Item.get_or_create(embedding=np.array([4, 5, 6]))

0 commit comments

Comments
 (0)