Skip to content

Commit ff9012e

Browse files
ankanebormanjo
andcommitted
Added support for Peewee - closes pgvector#26 and closes pgvector#27
Co-authored-by: John-Craig Borman <borman.johncraig@gmail.com>
1 parent 6867e8b commit ff9012e

6 files changed

Lines changed: 141 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
## 0.2.2 (unreleased)
22

3+
- Added support for Peewee
34
- Added `HnswIndex` for Django
45

56
## 0.2.1 (2023-07-31)

README.md

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
[pgvector](https://github.com/pgvector/pgvector) support for Python
44

5-
Supports [Django](https://github.com/django/django), [SQLAlchemy](https://github.com/sqlalchemy/sqlalchemy), [SQLModel](https://github.com/tiangolo/sqlmodel), [Psycopg 3](https://github.com/psycopg/psycopg), [Psycopg 2](https://github.com/psycopg/psycopg2), and [asyncpg](https://github.com/MagicStack/asyncpg)
5+
Supports [Django](https://github.com/django/django), [SQLAlchemy](https://github.com/sqlalchemy/sqlalchemy), [SQLModel](https://github.com/tiangolo/sqlmodel), [Psycopg 3](https://github.com/psycopg/psycopg), [Psycopg 2](https://github.com/psycopg/psycopg2), [asyncpg](https://github.com/MagicStack/asyncpg), and [Peewee](https://github.com/coleifer/peewee)
66

77
[![Build Status](https://github.com/pgvector/pgvector-python/workflows/build/badge.svg?branch=master)](https://github.com/pgvector/pgvector-python/actions)
88

@@ -22,6 +22,7 @@ And follow the instructions for your database library:
2222
- [Psycopg 3](#psycopg-3)
2323
- [Psycopg 2](#psycopg-2)
2424
- [asyncpg](#asyncpg)
25+
- [Peewee](#peewee) [unreleased]
2526

2627
Or check out some examples:
2728

@@ -270,6 +271,43 @@ Get the nearest neighbors to a vector
270271
await conn.fetch('SELECT * FROM item ORDER BY embedding <-> $1 LIMIT 5', embedding)
271272
```
272273

274+
## Peewee
275+
276+
Add a vector column
277+
278+
```python
279+
from pgvector.peewee import VectorField
280+
281+
class Item(BaseModel):
282+
embedding = VectorField(dimensions=3)
283+
```
284+
285+
Insert a vector
286+
287+
```python
288+
item = Item.create(embedding=[1, 2, 3])
289+
```
290+
291+
Get the nearest neighbors to a vector
292+
293+
```python
294+
Item.select().order_by(Item.embedding.l2_distance([3, 1, 2])).limit(5)
295+
```
296+
297+
Also supports `max_inner_product` and `cosine_distance`
298+
299+
Get the distance
300+
301+
```python
302+
Item.select(Item.embedding.l2_distance([3, 1, 2]).alias('distance'))
303+
```
304+
305+
Get items within a certain distance
306+
307+
```python
308+
Item.select().where(Item.embedding.l2_distance([3, 1, 2]) < 5)
309+
```
310+
273311
## History
274312

275313
View the [changelog](https://github.com/pgvector/pgvector-python/blob/master/CHANGELOG.md)

pgvector/peewee/__init__.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from peewee import Expression, Field, Value
2+
from ..utils import from_db, to_db
3+
4+
5+
class VectorField(Field):
6+
field_type = 'vector'
7+
8+
def __init__(self, dimensions=None, *args, **kwargs):
9+
self.dimensions = dimensions
10+
super(VectorField, self).__init__(*args, **kwargs)
11+
12+
def get_modifiers(self):
13+
return self.dimensions and [self.dimensions] or None
14+
15+
def db_value(self, value):
16+
return to_db(value)
17+
18+
def python_value(self, value):
19+
return from_db(value)
20+
21+
def _distance(self, op, vector):
22+
return Expression(lhs=self, op=op, rhs=Value(vector, converter=to_db, unpack=False))
23+
24+
def l2_distance(self, vector):
25+
return self._distance('<->', vector)
26+
27+
def max_inner_product(self, vector):
28+
return self._distance('<#>', vector)
29+
30+
def cosine_distance(self, vector):
31+
return self._distance('<=>', vector)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
asyncpg
22
Django
33
numpy
4+
peewee
45
psycopg[binary]
56
psycopg2-binary
67
pytest

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
packages=[
1717
'pgvector.asyncpg',
1818
'pgvector.django',
19+
'pgvector.peewee',
1920
'pgvector.psycopg',
2021
'pgvector.psycopg2',
2122
'pgvector.sqlalchemy',

tests/test_peewee.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from math import sqrt
2+
import numpy as np
3+
from peewee import Model, PostgresqlDatabase
4+
from pgvector.peewee import VectorField
5+
6+
db = PostgresqlDatabase('pgvector_python_test')
7+
8+
9+
class BaseModel(Model):
10+
class Meta:
11+
database = db
12+
13+
14+
class Item(BaseModel):
15+
embedding = VectorField(dimensions=3)
16+
17+
18+
db.connect()
19+
db.execute_sql('CREATE EXTENSION IF NOT EXISTS vector')
20+
db.drop_tables([Item])
21+
db.create_tables([Item])
22+
23+
24+
def create_items():
25+
vectors = [
26+
[1, 1, 1],
27+
[2, 2, 2],
28+
[1, 1, 2]
29+
]
30+
for i, v in enumerate(vectors):
31+
Item.create(id=i + 1, embedding=v)
32+
33+
34+
class TestPeewee:
35+
def setup_method(self, test_method):
36+
Item.truncate_table()
37+
38+
def test_works(self):
39+
Item.create(id=1, embedding=[1, 2, 3])
40+
item = Item.get_by_id(1)
41+
assert np.array_equal(item.embedding, np.array([1, 2, 3]))
42+
assert item.embedding.dtype == np.float32
43+
44+
def test_l2_distance(self):
45+
create_items()
46+
distance = Item.embedding.l2_distance([1, 1, 1])
47+
items = Item.select(Item.id, distance.alias('distance')).order_by(distance).limit(5)
48+
assert [v.id for v in items] == [1, 3, 2]
49+
assert [v.distance for v in items] == [0, 1, sqrt(3)]
50+
51+
def test_max_inner_product(self):
52+
create_items()
53+
distance = Item.embedding.max_inner_product([1, 1, 1])
54+
items = Item.select(Item.id, distance.alias('distance')).order_by(distance).limit(5)
55+
assert [v.id for v in items] == [2, 3, 1]
56+
assert [v.distance for v in items] == [-6, -4, -3]
57+
58+
def test_cosine_distance(self):
59+
create_items()
60+
distance = Item.embedding.cosine_distance([1, 1, 1])
61+
items = Item.select(Item.id, distance.alias('distance')).order_by(distance).limit(5)
62+
assert [v.id for v in items] == [1, 2, 3]
63+
assert [v.distance for v in items] == [0, 0, 0.05719095841793653]
64+
65+
def test_where(self):
66+
create_items()
67+
items = Item.select().where(Item.embedding.l2_distance([1, 1, 1]) < 1)
68+
assert [v.id for v in items] == [1]

0 commit comments

Comments
 (0)