Skip to content

Commit b53a055

Browse files
committed
Improved Doctrine tests
1 parent 0b08597 commit b53a055

2 files changed

Lines changed: 78 additions & 17 deletions

File tree

tests/DoctrineTest.php

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
use Doctrine\DBAL\Types\Type;
88
use Doctrine\ORM\EntityManager;
99
use Doctrine\ORM\ORMSetup;
10+
use Doctrine\ORM\Query\ResultSetMappingBuilder;
1011
use Doctrine\ORM\Tools\SchemaTool;
1112
use Pgvector\HalfVector;
1213
use Pgvector\SparseVector;
@@ -56,6 +57,12 @@ public static function setUpBeforeClass(): void
5657
self::$em = $entityManager;
5758
}
5859

60+
public function setUp(): void
61+
{
62+
self::$em->getConnection()->executeStatement('TRUNCATE doctrine_items RESTART IDENTITY');
63+
self::$em->clear();
64+
}
65+
5966
public function testTypes()
6067
{
6168
$item = new DoctrineItem();
@@ -73,4 +80,48 @@ public function testTypes()
7380
$this->assertEquals('101', $item->getBinaryEmbedding());
7481
$this->assertEquals([7, 8, 9], $item->getSparseEmbedding()->toArray());
7582
}
83+
84+
public function testVectorL2Distance()
85+
{
86+
$this->createItems();
87+
$rsm = new ResultSetMappingBuilder(self::$em);
88+
$rsm->addRootEntityFromClassMetadata('DoctrineItem', 'i');
89+
$neighbors = self::$em->createNativeQuery('SELECT * FROM doctrine_items i ORDER BY embedding <-> ? LIMIT 5', $rsm)
90+
->setParameter(1, new Vector([1, 1, 1]))
91+
->getResult();
92+
$this->assertEquals([1, 3, 2], array_map(fn ($v) => $v->getId(), $neighbors));
93+
$this->assertEquals([[1, 1, 1], [1, 1, 2], [2, 2, 2]], array_map(fn ($v) => $v->getEmbedding()->toArray(), $neighbors));
94+
}
95+
96+
public function testVectorMaxInnerProduct()
97+
{
98+
$this->createItems();
99+
$rsm = new ResultSetMappingBuilder(self::$em);
100+
$rsm->addRootEntityFromClassMetadata('DoctrineItem', 'i');
101+
$neighbors = self::$em->createNativeQuery('SELECT * FROM doctrine_items i ORDER BY embedding <#> ? LIMIT 5', $rsm)
102+
->setParameter(1, new Vector([1, 1, 1]))
103+
->getResult();
104+
$this->assertEquals([2, 3, 1], array_map(fn ($v) => $v->getId(), $neighbors));
105+
}
106+
107+
public function testVectorCosineDistance()
108+
{
109+
$this->createItems();
110+
$rsm = new ResultSetMappingBuilder(self::$em);
111+
$rsm->addRootEntityFromClassMetadata('DoctrineItem', 'i');
112+
$neighbors = self::$em->createNativeQuery('SELECT * FROM doctrine_items i ORDER BY embedding <=> ? LIMIT 5', $rsm)
113+
->setParameter(1, new Vector([1, 1, 1]))
114+
->getResult();
115+
$this->assertEquals([1, 2, 3], array_map(fn ($v) => $v->getId(), $neighbors));
116+
}
117+
118+
private function createItems()
119+
{
120+
foreach ([[1, 1, 1], [2, 2, 2], [1, 1, 2]] as $i => $v) {
121+
$item = new DoctrineItem();
122+
$item->setEmbedding(new Vector($v));
123+
self::$em->persist($item);
124+
}
125+
self::$em->flush();
126+
}
76127
}

tests/models/DoctrineItem.php

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,56 +12,66 @@ class DoctrineItem
1212
#[ORM\Id]
1313
#[ORM\Column(type: 'integer')]
1414
#[ORM\GeneratedValue]
15-
private int|null $id = null;
15+
private ?int $id = null;
1616

17-
#[ORM\Column(type: 'vector', length: 3)]
18-
private Vector $embedding;
17+
#[ORM\Column(type: 'vector', length: 3, nullable: true)]
18+
private ?Vector $embedding;
1919

20-
#[ORM\Column(type: 'halfvec', length: 3)]
21-
private HalfVector $halfEmbedding;
20+
#[ORM\Column(type: 'halfvec', length: 3, nullable: true)]
21+
private ?HalfVector $halfEmbedding;
2222

23-
#[ORM\Column(type: 'bit', length: 3)]
24-
private string $binaryEmbedding;
23+
#[ORM\Column(type: 'bit', length: 3, nullable: true)]
24+
private ?string $binaryEmbedding;
2525

26-
#[ORM\Column(type: 'sparsevec', length: 3)]
27-
private SparseVector $sparseEmbedding;
26+
#[ORM\Column(type: 'sparsevec', length: 3, nullable: true)]
27+
private ?SparseVector $sparseEmbedding;
2828

29-
public function getEmbedding(): Vector
29+
public function getId(): ?int
30+
{
31+
return $this->id;
32+
}
33+
34+
public function setId(?int $id): void
35+
{
36+
$this->id = $id;
37+
}
38+
39+
public function getEmbedding(): ?Vector
3040
{
3141
return $this->embedding;
3242
}
3343

34-
public function setEmbedding(Vector $embedding): void
44+
public function setEmbedding(?Vector $embedding): void
3545
{
3646
$this->embedding = $embedding;
3747
}
3848

39-
public function getHalfEmbedding(): HalfVector
49+
public function getHalfEmbedding(): ?HalfVector
4050
{
4151
return $this->halfEmbedding;
4252
}
4353

44-
public function setHalfEmbedding(HalfVector $embedding): void
54+
public function setHalfEmbedding(?HalfVector $embedding): void
4555
{
4656
$this->halfEmbedding = $embedding;
4757
}
4858

49-
public function getBinaryEmbedding(): string
59+
public function getBinaryEmbedding(): ?string
5060
{
5161
return $this->binaryEmbedding;
5262
}
5363

54-
public function setBinaryEmbedding(string $embedding): void
64+
public function setBinaryEmbedding(?string $embedding): void
5565
{
5666
$this->binaryEmbedding = $embedding;
5767
}
5868

59-
public function getSparseEmbedding(): SparseVector
69+
public function getSparseEmbedding(): ?SparseVector
6070
{
6171
return $this->sparseEmbedding;
6272
}
6373

64-
public function setSparseEmbedding(SparseVector $embedding): void
74+
public function setSparseEmbedding(?SparseVector $embedding): void
6575
{
6676
$this->sparseEmbedding = $embedding;
6777
}

0 commit comments

Comments
 (0)