Skip to content

Commit 58ec529

Browse files
committed
Reduced support functions for HNSW - pgvector#527
1 parent 47d5b28 commit 58ec529

8 files changed

Lines changed: 39 additions & 49 deletions

File tree

sql/vector--0.6.2--0.7.0.sql

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -69,13 +69,13 @@ CREATE OPERATOR CLASS bit_hamming_ops
6969
FOR TYPE bit USING hnsw AS
7070
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
7171
FUNCTION 1 hamming_distance(bit, bit),
72-
FUNCTION 4 hnsw_bit_support(internal);
72+
FUNCTION 3 hnsw_bit_support(internal);
7373

7474
CREATE OPERATOR CLASS bit_jaccard_ops
7575
FOR TYPE bit USING hnsw AS
7676
OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops,
7777
FUNCTION 1 jaccard_distance(bit, bit),
78-
FUNCTION 4 hnsw_bit_support(internal);
78+
FUNCTION 3 hnsw_bit_support(internal);
7979

8080
CREATE TYPE halfvec;
8181

@@ -355,27 +355,26 @@ CREATE OPERATOR CLASS halfvec_l2_ops
355355
FOR TYPE halfvec USING hnsw AS
356356
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
357357
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
358-
FUNCTION 4 hnsw_halfvec_support(internal);
358+
FUNCTION 3 hnsw_halfvec_support(internal);
359359

360360
CREATE OPERATOR CLASS halfvec_ip_ops
361361
FOR TYPE halfvec USING hnsw AS
362362
OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops,
363363
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
364-
FUNCTION 4 hnsw_halfvec_support(internal);
364+
FUNCTION 3 hnsw_halfvec_support(internal);
365365

366366
CREATE OPERATOR CLASS halfvec_cosine_ops
367367
FOR TYPE halfvec USING hnsw AS
368368
OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops,
369369
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
370370
FUNCTION 2 l2_norm(halfvec),
371-
FUNCTION 3 l2_normalize(halfvec),
372-
FUNCTION 4 hnsw_halfvec_support(internal);
371+
FUNCTION 3 hnsw_halfvec_support(internal);
373372

374373
CREATE OPERATOR CLASS halfvec_l1_ops
375374
FOR TYPE halfvec USING hnsw AS
376375
OPERATOR 1 <+> (halfvec, halfvec) FOR ORDER BY float_ops,
377376
FUNCTION 1 l1_distance(halfvec, halfvec),
378-
FUNCTION 4 hnsw_halfvec_support(internal);
377+
FUNCTION 3 hnsw_halfvec_support(internal);
379378

380379
CREATE TYPE sparsevec;
381380

@@ -547,24 +546,23 @@ CREATE OPERATOR CLASS sparsevec_l2_ops
547546
FOR TYPE sparsevec USING hnsw AS
548547
OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops,
549548
FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec),
550-
FUNCTION 4 hnsw_sparsevec_support(internal);
549+
FUNCTION 3 hnsw_sparsevec_support(internal);
551550

552551
CREATE OPERATOR CLASS sparsevec_ip_ops
553552
FOR TYPE sparsevec USING hnsw AS
554553
OPERATOR 1 <#> (sparsevec, sparsevec) FOR ORDER BY float_ops,
555554
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
556-
FUNCTION 4 hnsw_sparsevec_support(internal);
555+
FUNCTION 3 hnsw_sparsevec_support(internal);
557556

558557
CREATE OPERATOR CLASS sparsevec_cosine_ops
559558
FOR TYPE sparsevec USING hnsw AS
560559
OPERATOR 1 <=> (sparsevec, sparsevec) FOR ORDER BY float_ops,
561560
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
562561
FUNCTION 2 l2_norm(sparsevec),
563-
FUNCTION 3 l2_normalize(sparsevec),
564-
FUNCTION 4 hnsw_sparsevec_support(internal);
562+
FUNCTION 3 hnsw_sparsevec_support(internal);
565563

566564
CREATE OPERATOR CLASS sparsevec_l1_ops
567565
FOR TYPE sparsevec USING hnsw AS
568566
OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops,
569567
FUNCTION 1 l1_distance(sparsevec, sparsevec),
570-
FUNCTION 4 hnsw_sparsevec_support(internal);
568+
FUNCTION 3 hnsw_sparsevec_support(internal);

sql/vector.sql

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -364,13 +364,13 @@ CREATE OPERATOR CLASS bit_hamming_ops
364364
FOR TYPE bit USING hnsw AS
365365
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
366366
FUNCTION 1 hamming_distance(bit, bit),
367-
FUNCTION 4 hnsw_bit_support(internal);
367+
FUNCTION 3 hnsw_bit_support(internal);
368368

369369
CREATE OPERATOR CLASS bit_jaccard_ops
370370
FOR TYPE bit USING hnsw AS
371371
OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops,
372372
FUNCTION 1 jaccard_distance(bit, bit),
373-
FUNCTION 4 hnsw_bit_support(internal);
373+
FUNCTION 3 hnsw_bit_support(internal);
374374

375375
-- halfvec type
376376

@@ -666,27 +666,26 @@ CREATE OPERATOR CLASS halfvec_l2_ops
666666
FOR TYPE halfvec USING hnsw AS
667667
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
668668
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
669-
FUNCTION 4 hnsw_halfvec_support(internal);
669+
FUNCTION 3 hnsw_halfvec_support(internal);
670670

671671
CREATE OPERATOR CLASS halfvec_ip_ops
672672
FOR TYPE halfvec USING hnsw AS
673673
OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops,
674674
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
675-
FUNCTION 4 hnsw_halfvec_support(internal);
675+
FUNCTION 3 hnsw_halfvec_support(internal);
676676

677677
CREATE OPERATOR CLASS halfvec_cosine_ops
678678
FOR TYPE halfvec USING hnsw AS
679679
OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops,
680680
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
681681
FUNCTION 2 l2_norm(halfvec),
682-
FUNCTION 3 l2_normalize(halfvec),
683-
FUNCTION 4 hnsw_halfvec_support(internal);
682+
FUNCTION 3 hnsw_halfvec_support(internal);
684683

685684
CREATE OPERATOR CLASS halfvec_l1_ops
686685
FOR TYPE halfvec USING hnsw AS
687686
OPERATOR 1 <+> (halfvec, halfvec) FOR ORDER BY float_ops,
688687
FUNCTION 1 l1_distance(halfvec, halfvec),
689-
FUNCTION 4 hnsw_halfvec_support(internal);
688+
FUNCTION 3 hnsw_halfvec_support(internal);
690689

691690
--- sparsevec type
692691

@@ -872,24 +871,23 @@ CREATE OPERATOR CLASS sparsevec_l2_ops
872871
FOR TYPE sparsevec USING hnsw AS
873872
OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops,
874873
FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec),
875-
FUNCTION 4 hnsw_sparsevec_support(internal);
874+
FUNCTION 3 hnsw_sparsevec_support(internal);
876875

877876
CREATE OPERATOR CLASS sparsevec_ip_ops
878877
FOR TYPE sparsevec USING hnsw AS
879878
OPERATOR 1 <#> (sparsevec, sparsevec) FOR ORDER BY float_ops,
880879
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
881-
FUNCTION 4 hnsw_sparsevec_support(internal);
880+
FUNCTION 3 hnsw_sparsevec_support(internal);
882881

883882
CREATE OPERATOR CLASS sparsevec_cosine_ops
884883
FOR TYPE sparsevec USING hnsw AS
885884
OPERATOR 1 <=> (sparsevec, sparsevec) FOR ORDER BY float_ops,
886885
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
887886
FUNCTION 2 l2_norm(sparsevec),
888-
FUNCTION 3 l2_normalize(sparsevec),
889-
FUNCTION 4 hnsw_sparsevec_support(internal);
887+
FUNCTION 3 hnsw_sparsevec_support(internal);
890888

891889
CREATE OPERATOR CLASS sparsevec_l1_ops
892890
FOR TYPE sparsevec USING hnsw AS
893891
OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops,
894892
FUNCTION 1 l1_distance(sparsevec, sparsevec),
895-
FUNCTION 4 hnsw_sparsevec_support(internal);
893+
FUNCTION 3 hnsw_sparsevec_support(internal);

src/hnsw.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ hnswhandler(PG_FUNCTION_ARGS)
194194
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
195195

196196
amroutine->amstrategies = 0;
197-
amroutine->amsupport = 4;
197+
amroutine->amsupport = 3;
198198
#if PG_VERSION_NUM >= 130000
199199
amroutine->amoptsprocnum = 0;
200200
#endif

src/hnsw.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222
/* Support functions */
2323
#define HNSW_DISTANCE_PROC 1
2424
#define HNSW_NORM_PROC 2
25-
#define HNSW_NORMALIZE_PROC 3
26-
#define HNSW_TYPE_INFO_PROC 4
25+
#define HNSW_TYPE_INFO_PROC 3
2726

2827
#define HNSW_VERSION 1
2928
#define HNSW_MAGIC_NUMBER 0xA953A953
@@ -241,6 +240,7 @@ typedef struct HnswAllocator
241240
typedef struct HnswTypeInfo
242241
{
243242
int maxDimensions;
243+
Datum (*normalize) (PG_FUNCTION_ARGS);
244244
void (*checkValue) (Pointer v);
245245
} HnswTypeInfo;
246246

@@ -265,7 +265,6 @@ typedef struct HnswBuildState
265265
/* Support functions */
266266
FmgrInfo *procinfo;
267267
FmgrInfo *normprocinfo;
268-
FmgrInfo *normalizeprocinfo;
269268
Oid collation;
270269

271270
/* Variables */
@@ -335,14 +334,14 @@ typedef HnswNeighborTupleData * HnswNeighborTuple;
335334

336335
typedef struct HnswScanOpaqueData
337336
{
337+
const HnswTypeInfo *typeInfo;
338338
bool first;
339339
List *w;
340340
MemoryContext tmpCtx;
341341

342342
/* Support functions */
343343
FmgrInfo *procinfo;
344344
FmgrInfo *normprocinfo;
345-
FmgrInfo *normalizeprocinfo;
346345
Oid collation;
347346
} HnswScanOpaqueData;
348347

@@ -378,7 +377,6 @@ typedef struct HnswVacuumState
378377
int HnswGetM(Relation index);
379378
int HnswGetEfConstruction(Relation index);
380379
FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum);
381-
Datum HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
382380
bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
383381
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
384382
void HnswInitPage(Buffer buf, Page page);

src/hnswbuild.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,7 @@ InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element)
476476
static bool
477477
InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, HnswBuildState * buildstate)
478478
{
479+
const HnswTypeInfo *typeInfo = buildstate->typeInfo;
479480
HnswGraph *graph = buildstate->graph;
480481
HnswElement element;
481482
HnswAllocator *allocator = &buildstate->allocator;
@@ -488,16 +489,16 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
488489
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
489490

490491
/* Check value */
491-
if (buildstate->typeInfo->checkValue != NULL)
492-
buildstate->typeInfo->checkValue(DatumGetPointer(value));
492+
if (typeInfo->checkValue != NULL)
493+
typeInfo->checkValue(DatumGetPointer(value));
493494

494495
/* Normalize if needed */
495496
if (buildstate->normprocinfo != NULL)
496497
{
497498
if (!HnswCheckNorm(buildstate->normprocinfo, buildstate->collation, value))
498499
return false;
499500

500-
value = HnswNormValue(buildstate->normalizeprocinfo, buildstate->collation, value);
501+
value = DirectFunctionCall1(typeInfo->normalize, value);
501502
}
502503

503504
/* Get datum size */
@@ -708,7 +709,6 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
708709
/* Get support functions */
709710
buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
710711
buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
711-
buildstate->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC);
712712
buildstate->collation = index->rd_indcollation[0];
713713

714714
InitGraph(&buildstate->graphData, NULL, maintenance_work_mem * 1024L);

src/hnswinsert.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
630630
if (!HnswCheckNorm(normprocinfo, collation, value))
631631
return;
632632

633-
value = HnswNormValue(HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC), collation, value);
633+
value = DirectFunctionCall1(typeInfo->normalize, value);
634634
}
635635

636636
HnswInsertTupleOnDisk(index, value, values, isnull, heap_tid, false);

src/hnswscan.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ GetScanValue(IndexScanDesc scan)
6161

6262
/* Fine if normalization fails */
6363
if (so->normprocinfo != NULL)
64-
value = HnswNormValue(so->normalizeprocinfo, so->collation, value);
64+
value = DirectFunctionCall1(so->typeInfo->normalize, value);
6565
}
6666

6767
return value;
@@ -79,6 +79,7 @@ hnswbeginscan(Relation index, int nkeys, int norderbys)
7979
scan = RelationGetIndexScan(index, nkeys, norderbys);
8080

8181
so = (HnswScanOpaque) palloc(sizeof(HnswScanOpaqueData));
82+
so->typeInfo = HnswGetTypeInfo(index);
8283
so->first = true;
8384
so->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
8485
"Hnsw scan temporary context",
@@ -87,7 +88,6 @@ hnswbeginscan(Relation index, int nkeys, int norderbys)
8788
/* Set support functions */
8889
so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
8990
so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
90-
so->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC);
9191
so->collation = index->rd_indcollation[0];
9292

9393
scan->opaque = so;

src/hnswutils.c

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -152,18 +152,6 @@ HnswOptionalProcInfo(Relation index, uint16 procnum)
152152
return index_getprocinfo(index, 1, procnum);
153153
}
154154

155-
/*
156-
* Normalize value
157-
*/
158-
Datum
159-
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum value)
160-
{
161-
if (procinfo == NULL)
162-
return DirectFunctionCall1(l2_normalize, value);
163-
164-
return FunctionCall1Coll(procinfo, collation, value);
165-
}
166-
167155
/*
168156
* Check if non-zero norm
169157
*/
@@ -1267,6 +1255,10 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
12671255
}
12681256
}
12691257

1258+
PGDLLEXPORT Datum l2_normalize(PG_FUNCTION_ARGS);
1259+
PGDLLEXPORT Datum halfvec_l2_normalize(PG_FUNCTION_ARGS);
1260+
PGDLLEXPORT Datum sparsevec_l2_normalize(PG_FUNCTION_ARGS);
1261+
12701262
static void
12711263
SparsevecCheckValue(Pointer v)
12721264
{
@@ -1288,6 +1280,7 @@ HnswGetTypeInfo(Relation index)
12881280
{
12891281
static const HnswTypeInfo typeInfo = {
12901282
.maxDimensions = HNSW_MAX_DIM,
1283+
.normalize = l2_normalize,
12911284
.checkValue = NULL
12921285
};
12931286

@@ -1303,6 +1296,7 @@ hnsw_halfvec_support(PG_FUNCTION_ARGS)
13031296
{
13041297
static const HnswTypeInfo typeInfo = {
13051298
.maxDimensions = HNSW_MAX_DIM * 2,
1299+
.normalize = halfvec_l2_normalize,
13061300
.checkValue = NULL
13071301
};
13081302

@@ -1315,6 +1309,7 @@ hnsw_bit_support(PG_FUNCTION_ARGS)
13151309
{
13161310
static const HnswTypeInfo typeInfo = {
13171311
.maxDimensions = HNSW_MAX_DIM * 32,
1312+
.normalize = NULL,
13181313
.checkValue = NULL
13191314
};
13201315

@@ -1327,6 +1322,7 @@ hnsw_sparsevec_support(PG_FUNCTION_ARGS)
13271322
{
13281323
static const HnswTypeInfo typeInfo = {
13291324
.maxDimensions = SPARSEVEC_MAX_DIM,
1325+
.normalize = sparsevec_l2_normalize,
13301326
.checkValue = SparsevecCheckValue
13311327
};
13321328

0 commit comments

Comments
 (0)