|
| 1 | +from math import sqrt |
| 2 | +import numpy as np |
| 3 | +from peewee import Model, PostgresqlDatabase |
| 4 | +from pgvector.peewee import VectorField |
| 5 | + |
# Peewee handle for the PostgreSQL database named 'pgvector_python_test'.
# Models attach to it through BaseModel.Meta, and the module-level setup
# further down connects through it.
db = PostgresqlDatabase('pgvector_python_test')
| 7 | + |
| 8 | + |
class BaseModel(Model):
    """Base model that binds its subclasses to the module-level ``db``."""

    class Meta:
        # Database used by this model and the models derived from it.
        database = db
| 12 | + |
| 13 | + |
class Item(BaseModel):
    """Model under test: a single pgvector column of 3-dimensional vectors."""

    # NOTE(review): the tests reference ``Item.id`` although no such field is
    # declared here -- presumably peewee's implicit primary key; confirm.

    # Vector column declared with three dimensions; the tests store and query
    # 3-element lists against it.
    embedding = VectorField(dimensions=3)
| 16 | + |
| 17 | + |
# Module-level setup, executed when this file is imported: open the
# connection, make sure the ``vector`` extension exists, then rebuild the
# Item table.
db.connect()
db.execute_sql('CREATE EXTENSION IF NOT EXISTS vector')
# Drop before create so the table is rebuilt from the current model definition.
db.drop_tables([Item])
db.create_tables([Item])
| 22 | + |
| 23 | + |
def create_items():
    """Insert the three fixture rows shared by the distance tests.

    Rows receive ids 1, 2 and 3, in that order, with embeddings
    ``[1, 1, 1]``, ``[2, 2, 2]`` and ``[1, 1, 2]`` respectively.
    """
    fixture_vectors = (
        [1, 1, 1],
        [2, 2, 2],
        [1, 1, 2],
    )
    # Ids are assigned explicitly so the tests can assert on them.
    for item_id, vector in enumerate(fixture_vectors, start=1):
        Item.create(id=item_id, embedding=vector)
| 32 | + |
| 33 | + |
class TestPeewee:
    """Integration tests for ``VectorField`` and its distance expressions.

    Every test starts from an empty ``Item`` table (see ``setup_method``).
    Distances computed by the database are compared with a tolerance rather
    than with ``==`` so the tests do not depend on bit-exact floating-point
    results from the server.
    """

    def setup_method(self, test_method):
        """Empty the ``Item`` table before each test."""
        Item.truncate_table()

    def test_works(self):
        """A stored vector is read back as an equal float32 NumPy array."""
        Item.create(id=1, embedding=[1, 2, 3])
        item = Item.get_by_id(1)
        assert np.array_equal(item.embedding, np.array([1, 2, 3]))
        assert item.embedding.dtype == np.float32

    def test_l2_distance(self):
        """Rows are ordered by L2 distance to ``[1, 1, 1]``."""
        create_items()
        distance = Item.embedding.l2_distance([1, 1, 1])
        query = Item.select(Item.id, distance.alias('distance')).order_by(distance).limit(5)
        rows = list(query)
        # The id assertion comes first so both lists are known to have the
        # same length before the element-wise tolerance comparison.
        assert [row.id for row in rows] == [1, 3, 2]
        assert np.allclose([row.distance for row in rows], [0, 1, sqrt(3)])

    def test_max_inner_product(self):
        """Rows are ordered by the max-inner-product expression."""
        create_items()
        distance = Item.embedding.max_inner_product([1, 1, 1])
        query = Item.select(Item.id, distance.alias('distance')).order_by(distance).limit(5)
        rows = list(query)
        assert [row.id for row in rows] == [2, 3, 1]
        # Expected values are small integers, so exact comparison is kept.
        assert [row.distance for row in rows] == [-6, -4, -3]

    def test_cosine_distance(self):
        """Rows are ordered by cosine distance to ``[1, 1, 1]``."""
        create_items()
        distance = Item.embedding.cosine_distance([1, 1, 1])
        # Rows 1 and 2 are both expected at distance 0; ``Item.id`` breaks the
        # tie so the asserted order does not rely on unspecified row ordering.
        query = (
            Item.select(Item.id, distance.alias('distance'))
            .order_by(distance, Item.id)
            .limit(5)
        )
        rows = list(query)
        assert [row.id for row in rows] == [1, 2, 3]
        assert np.allclose([row.distance for row in rows], [0, 0, 0.05719095841793653])

    def test_where(self):
        """A distance expression can be used as a ``WHERE`` filter."""
        create_items()
        items = Item.select().where(Item.embedding.l2_distance([1, 1, 1]) < 1)
        assert [v.id for v in items] == [1]
0 commit comments