|
1 | 1 | from math import sqrt |
2 | 2 | import numpy as np |
3 | | -from peewee import Model, PostgresqlDatabase |
| 3 | +from peewee import Model, PostgresqlDatabase, fn |
4 | 4 | from pgvector.peewee import VectorField |
5 | 5 |
|
6 | 6 | db = PostgresqlDatabase('pgvector_python_test') |
@@ -69,6 +69,22 @@ def test_where(self): |
69 | 69 | items = Item.select().where(Item.embedding.l2_distance([1, 1, 1]) < 1) |
70 | 70 | assert [v.id for v in items] == [1] |
71 | 71 |
|
| 72 | + def test_avg(self): |
| 73 | + avg = Item.select(fn.avg(Item.embedding)).scalar() |
| 74 | + assert avg is None |
| 75 | + Item.create(embedding=[1, 2, 3]) |
| 76 | + Item.create(embedding=[4, 5, 6]) |
| 77 | + avg = Item.select(fn.avg(Item.embedding)).scalar() |
| 78 | + assert np.array_equal(avg, np.array([2.5, 3.5, 4.5])) |
| 79 | + |
| 80 | + def test_sum(self): |
| 81 | + sum = Item.select(fn.sum(Item.embedding)).scalar() |
| 82 | + assert sum is None |
| 83 | + Item.create(embedding=[1, 2, 3]) |
| 84 | + Item.create(embedding=[4, 5, 6]) |
| 85 | + sum = Item.select(fn.sum(Item.embedding)).scalar() |
| 86 | + assert np.array_equal(sum, np.array([5, 7, 9])) |
| 87 | + |
72 | 88 | def test_get_or_create(self): |
73 | 89 | Item.get_or_create(id=1, defaults={'embedding': [1, 2, 3]}) |
74 | 90 | Item.get_or_create(embedding=np.array([4, 5, 6])) |
|
0 commit comments