Skip to content

Commit 557983b

Browse files
jerinpetergeorgeankane
authored andcommitted
Enhances DjangoAdmin support (pgvector#22)
1 parent 4ea06c0 commit 557983b

File tree

4 files changed

+38
-3
lines changed

4 files changed

+38
-3
lines changed

pgvector/django/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from django.contrib.postgres.operations import CreateExtension
22
from django.contrib.postgres.indexes import PostgresIndex
33
from django.db.models import Field, FloatField, Func, Value
4+
from .forms import VectorFormField
45
from ..utils import from_db, to_db
56

67
__all__ = ['VectorExtension', 'VectorField', 'IvfflatIndex', 'L2Distance', 'MaxInnerProduct', 'CosineDistance']
@@ -48,6 +49,9 @@ def validate(self, value, model_instance):
4849
def run_validators(self, value):
4950
super().run_validators(value.tolist())
5051

52+
def formfield(self, **kwargs):
53+
return super().formfield(form_class=VectorFormField, **kwargs)
54+
5155

5256
class IvfflatIndex(PostgresIndex):
5357
suffix = 'ivfflat'

pgvector/django/forms.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from django import forms
2+
from .widgets import VectorWidget
3+
4+
5+
class VectorFormField(forms.CharField):
6+
widget = VectorWidget
7+
8+
def has_changed(self, initial, data):
9+
try:
10+
initial = initial.tolist()
11+
except AttributeError:
12+
# initial could be None
13+
pass
14+
return super().has_changed(initial, data)

pgvector/django/widgets.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from django import forms
2+
3+
4+
class VectorWidget(forms.TextInput):
5+
def format_value(self, value):
6+
try:
7+
value = value.tolist()
8+
except AttributeError:
9+
# value could be None
10+
pass
11+
return super().format_value(value)

tests/test_django.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,16 @@ def test_serialization(self):
136136
obj.save()
137137

138138
def test_form(self):
139-
form = ItemForm(data={'embedding': [1, 2, 3]})
139+
form = ItemForm(data={'embedding': "[1, 2, 3]"})
140140
assert form.is_valid()
141141
assert 'value="[1, 2, 3]"' in form.as_div()
142142

143+
def test_form_has_changed(self):
144+
Item(id=1, embedding=[1, 2, 3]).save()
145+
item = Item.objects.get(pk=1)
146+
form = ItemForm(instance=item, data={'embedding': "[1, 2, 4]"})
147+
assert form.has_changed()
148+
143149
def test_form_instance(self):
144150
Item(id=1, embedding=[1, 2, 3]).save()
145151
item = Item.objects.get(pk=1)
@@ -150,7 +156,7 @@ def test_form_instance(self):
150156
def test_form_save(self):
151157
Item(id=1, embedding=[1, 2, 3]).save()
152158
item = Item.objects.get(pk=1)
153-
form = ItemForm(instance=item, data={'embedding': [4, 5, 6]})
159+
form = ItemForm(instance=item, data={'embedding': "[4, 5, 6]"})
154160
assert form.is_valid()
155161
assert form.save()
156-
assert [4, 5, 6], Item.objects.get(pk=1).embedding
162+
assert [4, 5, 6] == Item.objects.get(pk=1).embedding.tolist()

0 commit comments

Comments
 (0)