Skip to content

Commit 872d1a4

Browse files
committed
Added more tests for Doctrine [skip ci]
1 parent 77f1147 commit 872d1a4

1 file changed

Lines changed: 90 additions & 2 deletions

File tree

tests/DoctrineTest.php

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,47 @@ public function testVectorL1Distance()
128128
$this->assertEquals([1, 3, 2], array_map(fn ($v) => $v->getId(), $neighbors));
129129
}
130130

131+
public function testHalfvecL2Distance()
132+
{
133+
$this->createItems('halfEmbedding');
134+
$neighbors = self::$em->createQuery('SELECT i FROM DoctrineItem i ORDER BY l2_distance(i.halfEmbedding, ?1)')
135+
->setParameter(1, new HalfVector([1, 1, 1]))
136+
->setMaxResults(5)
137+
->getResult();
138+
$this->assertEquals([1, 3, 2], array_map(fn ($v) => $v->getId(), $neighbors));
139+
$this->assertEquals([[1, 1, 1], [1, 1, 2], [2, 2, 2]], array_map(fn ($v) => $v->getHalfEmbedding()->toArray(), $neighbors));
140+
}
141+
142+
public function testHalfvecMaxInnerProduct()
143+
{
144+
$this->createItems('halfEmbedding');
145+
$neighbors = self::$em->createQuery('SELECT i FROM DoctrineItem i ORDER BY max_inner_product(i.halfEmbedding, ?1)')
146+
->setParameter(1, new HalfVector([1, 1, 1]))
147+
->setMaxResults(5)
148+
->getResult();
149+
$this->assertEquals([2, 3, 1], array_map(fn ($v) => $v->getId(), $neighbors));
150+
}
151+
152+
public function testHalfvecCosineDistance()
153+
{
154+
$this->createItems('halfEmbedding');
155+
$neighbors = self::$em->createQuery('SELECT i FROM DoctrineItem i ORDER BY cosine_distance(i.halfEmbedding, ?1)')
156+
->setParameter(1, new HalfVector([1, 1, 1]))
157+
->setMaxResults(5)
158+
->getResult();
159+
$this->assertEquals([1, 2, 3], array_map(fn ($v) => $v->getId(), $neighbors));
160+
}
161+
162+
public function testHalfvecL1Distance()
163+
{
164+
$this->createItems('halfEmbedding');
165+
$neighbors = self::$em->createQuery('SELECT i FROM DoctrineItem i ORDER BY l1_distance(i.halfEmbedding, ?1)')
166+
->setParameter(1, new HalfVector([1, 1, 1]))
167+
->setMaxResults(5)
168+
->getResult();
169+
$this->assertEquals([1, 3, 2], array_map(fn ($v) => $v->getId(), $neighbors));
170+
}
171+
131172
public function testBitHammingDistance()
132173
{
133174
$this->createBitItems();
@@ -148,11 +189,58 @@ public function testBitJaccardDistance()
148189
$this->assertEquals([2, 3, 1], array_map(fn ($v) => $v->getId(), $neighbors));
149190
}
150191

151-
private function createItems()
192+
public function testSparsevecL2Distance()
193+
{
194+
$this->createItems('sparseEmbedding');
195+
$neighbors = self::$em->createQuery('SELECT i FROM DoctrineItem i ORDER BY l2_distance(i.sparseEmbedding, ?1)')
196+
->setParameter(1, new SparseVector([1, 1, 1]))
197+
->setMaxResults(5)
198+
->getResult();
199+
$this->assertEquals([1, 3, 2], array_map(fn ($v) => $v->getId(), $neighbors));
200+
$this->assertEquals([[1, 1, 1], [1, 1, 2], [2, 2, 2]], array_map(fn ($v) => $v->getSparseEmbedding()->toArray(), $neighbors));
201+
}
202+
203+
public function testSparsevecMaxInnerProduct()
204+
{
205+
$this->createItems('sparseEmbedding');
206+
$neighbors = self::$em->createQuery('SELECT i FROM DoctrineItem i ORDER BY max_inner_product(i.sparseEmbedding, ?1)')
207+
->setParameter(1, new SparseVector([1, 1, 1]))
208+
->setMaxResults(5)
209+
->getResult();
210+
$this->assertEquals([2, 3, 1], array_map(fn ($v) => $v->getId(), $neighbors));
211+
}
212+
213+
public function testSparsevecCosineDistance()
214+
{
215+
$this->createItems('sparseEmbedding');
216+
$neighbors = self::$em->createQuery('SELECT i FROM DoctrineItem i ORDER BY cosine_distance(i.sparseEmbedding, ?1)')
217+
->setParameter(1, new SparseVector([1, 1, 1]))
218+
->setMaxResults(5)
219+
->getResult();
220+
$this->assertEquals([1, 2, 3], array_map(fn ($v) => $v->getId(), $neighbors));
221+
}
222+
223+
public function testSparsevecL1Distance()
224+
{
225+
$this->createItems('sparseEmbedding');
226+
$neighbors = self::$em->createQuery('SELECT i FROM DoctrineItem i ORDER BY l1_distance(i.sparseEmbedding, ?1)')
227+
->setParameter(1, new SparseVector([1, 1, 1]))
228+
->setMaxResults(5)
229+
->getResult();
230+
$this->assertEquals([1, 3, 2], array_map(fn ($v) => $v->getId(), $neighbors));
231+
}
232+
233+
private function createItems($attribute = 'embedding')
152234
{
153235
foreach ([[1, 1, 1], [2, 2, 2], [1, 1, 2]] as $v) {
154236
$item = new DoctrineItem();
155-
$item->setEmbedding(new Vector($v));
237+
if ($attribute == 'halfEmbedding') {
238+
$item->setHalfEmbedding(new HalfVector($v));
239+
} else if ($attribute == 'sparseEmbedding') {
240+
$item->setSparseEmbedding(new SparseVector($v));
241+
} else {
242+
$item->setEmbedding(new Vector($v));
243+
}
156244
self::$em->persist($item);
157245
}
158246
self::$em->flush();

0 commit comments

Comments
 (0)