Skip to content

Commit 47d5b28

Browse files
committed
Improved support functions for HNSW - #527
1 parent 2bf1175 commit 47d5b28

7 files changed

Lines changed: 96 additions & 99 deletions

File tree

sql/vector--0.6.2--0.7.0.sql

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,13 @@ CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal
2828
CREATE FUNCTION ivfflat_halfvec_support(internal) RETURNS internal
2929
AS 'MODULE_PATHNAME' LANGUAGE C;
3030

31-
CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal
31+
CREATE FUNCTION hnsw_bit_support(internal) RETURNS internal
3232
AS 'MODULE_PATHNAME' LANGUAGE C;
3333

34-
CREATE FUNCTION hnsw_halfvec_max_dims(internal) RETURNS internal
34+
CREATE FUNCTION hnsw_halfvec_support(internal) RETURNS internal
3535
AS 'MODULE_PATHNAME' LANGUAGE C;
3636

37-
CREATE FUNCTION hnsw_sparsevec_max_dims(internal) RETURNS internal
38-
AS 'MODULE_PATHNAME' LANGUAGE C;
39-
40-
CREATE FUNCTION hnsw_sparsevec_check_value(internal) RETURNS internal
37+
CREATE FUNCTION hnsw_sparsevec_support(internal) RETURNS internal
4138
AS 'MODULE_PATHNAME' LANGUAGE C;
4239

4340
CREATE OPERATOR CLASS vector_l1_ops
@@ -72,13 +69,13 @@ CREATE OPERATOR CLASS bit_hamming_ops
7269
FOR TYPE bit USING hnsw AS
7370
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
7471
FUNCTION 1 hamming_distance(bit, bit),
75-
FUNCTION 4 hnsw_bit_max_dims(internal);
72+
FUNCTION 4 hnsw_bit_support(internal);
7673

7774
CREATE OPERATOR CLASS bit_jaccard_ops
7875
FOR TYPE bit USING hnsw AS
7976
OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops,
8077
FUNCTION 1 jaccard_distance(bit, bit),
81-
FUNCTION 4 hnsw_bit_max_dims(internal);
78+
FUNCTION 4 hnsw_bit_support(internal);
8279

8380
CREATE TYPE halfvec;
8481

@@ -358,27 +355,27 @@ CREATE OPERATOR CLASS halfvec_l2_ops
358355
FOR TYPE halfvec USING hnsw AS
359356
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
360357
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
361-
FUNCTION 4 hnsw_halfvec_max_dims(internal);
358+
FUNCTION 4 hnsw_halfvec_support(internal);
362359

363360
CREATE OPERATOR CLASS halfvec_ip_ops
364361
FOR TYPE halfvec USING hnsw AS
365362
OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops,
366363
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
367-
FUNCTION 4 hnsw_halfvec_max_dims(internal);
364+
FUNCTION 4 hnsw_halfvec_support(internal);
368365

369366
CREATE OPERATOR CLASS halfvec_cosine_ops
370367
FOR TYPE halfvec USING hnsw AS
371368
OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops,
372369
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
373370
FUNCTION 2 l2_norm(halfvec),
374371
FUNCTION 3 l2_normalize(halfvec),
375-
FUNCTION 4 hnsw_halfvec_max_dims(internal);
372+
FUNCTION 4 hnsw_halfvec_support(internal);
376373

377374
CREATE OPERATOR CLASS halfvec_l1_ops
378375
FOR TYPE halfvec USING hnsw AS
379376
OPERATOR 1 <+> (halfvec, halfvec) FOR ORDER BY float_ops,
380377
FUNCTION 1 l1_distance(halfvec, halfvec),
381-
FUNCTION 4 hnsw_halfvec_max_dims(internal);
378+
FUNCTION 4 hnsw_halfvec_support(internal);
382379

383380
CREATE TYPE sparsevec;
384381

@@ -550,28 +547,24 @@ CREATE OPERATOR CLASS sparsevec_l2_ops
550547
FOR TYPE sparsevec USING hnsw AS
551548
OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops,
552549
FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec),
553-
FUNCTION 4 hnsw_sparsevec_max_dims(internal),
554-
FUNCTION 5 hnsw_sparsevec_check_value(internal);
550+
FUNCTION 4 hnsw_sparsevec_support(internal);
555551

556552
CREATE OPERATOR CLASS sparsevec_ip_ops
557553
FOR TYPE sparsevec USING hnsw AS
558554
OPERATOR 1 <#> (sparsevec, sparsevec) FOR ORDER BY float_ops,
559555
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
560-
FUNCTION 4 hnsw_sparsevec_max_dims(internal),
561-
FUNCTION 5 hnsw_sparsevec_check_value(internal);
556+
FUNCTION 4 hnsw_sparsevec_support(internal);
562557

563558
CREATE OPERATOR CLASS sparsevec_cosine_ops
564559
FOR TYPE sparsevec USING hnsw AS
565560
OPERATOR 1 <=> (sparsevec, sparsevec) FOR ORDER BY float_ops,
566561
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
567562
FUNCTION 2 l2_norm(sparsevec),
568563
FUNCTION 3 l2_normalize(sparsevec),
569-
FUNCTION 4 hnsw_sparsevec_max_dims(internal),
570-
FUNCTION 5 hnsw_sparsevec_check_value(internal);
564+
FUNCTION 4 hnsw_sparsevec_support(internal);
571565

572566
CREATE OPERATOR CLASS sparsevec_l1_ops
573567
FOR TYPE sparsevec USING hnsw AS
574568
OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops,
575569
FUNCTION 1 l1_distance(sparsevec, sparsevec),
576-
FUNCTION 4 hnsw_sparsevec_max_dims(internal),
577-
FUNCTION 5 hnsw_sparsevec_check_value(internal);
570+
FUNCTION 4 hnsw_sparsevec_support(internal);

sql/vector.sql

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -269,16 +269,13 @@ CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal
269269
CREATE FUNCTION ivfflat_halfvec_support(internal) RETURNS internal
270270
AS 'MODULE_PATHNAME' LANGUAGE C;
271271

272-
CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal
272+
CREATE FUNCTION hnsw_bit_support(internal) RETURNS internal
273273
AS 'MODULE_PATHNAME' LANGUAGE C;
274274

275-
CREATE FUNCTION hnsw_halfvec_max_dims(internal) RETURNS internal
275+
CREATE FUNCTION hnsw_halfvec_support(internal) RETURNS internal
276276
AS 'MODULE_PATHNAME' LANGUAGE C;
277277

278-
CREATE FUNCTION hnsw_sparsevec_max_dims(internal) RETURNS internal
279-
AS 'MODULE_PATHNAME' LANGUAGE C;
280-
281-
CREATE FUNCTION hnsw_sparsevec_check_value(internal) RETURNS internal
278+
CREATE FUNCTION hnsw_sparsevec_support(internal) RETURNS internal
282279
AS 'MODULE_PATHNAME' LANGUAGE C;
283280

284281
-- vector opclasses
@@ -367,13 +364,13 @@ CREATE OPERATOR CLASS bit_hamming_ops
367364
FOR TYPE bit USING hnsw AS
368365
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
369366
FUNCTION 1 hamming_distance(bit, bit),
370-
FUNCTION 4 hnsw_bit_max_dims(internal);
367+
FUNCTION 4 hnsw_bit_support(internal);
371368

372369
CREATE OPERATOR CLASS bit_jaccard_ops
373370
FOR TYPE bit USING hnsw AS
374371
OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops,
375372
FUNCTION 1 jaccard_distance(bit, bit),
376-
FUNCTION 4 hnsw_bit_max_dims(internal);
373+
FUNCTION 4 hnsw_bit_support(internal);
377374

378375
-- halfvec type
379376

@@ -669,27 +666,27 @@ CREATE OPERATOR CLASS halfvec_l2_ops
669666
FOR TYPE halfvec USING hnsw AS
670667
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
671668
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
672-
FUNCTION 4 hnsw_halfvec_max_dims(internal);
669+
FUNCTION 4 hnsw_halfvec_support(internal);
673670

674671
CREATE OPERATOR CLASS halfvec_ip_ops
675672
FOR TYPE halfvec USING hnsw AS
676673
OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops,
677674
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
678-
FUNCTION 4 hnsw_halfvec_max_dims(internal);
675+
FUNCTION 4 hnsw_halfvec_support(internal);
679676

680677
CREATE OPERATOR CLASS halfvec_cosine_ops
681678
FOR TYPE halfvec USING hnsw AS
682679
OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops,
683680
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
684681
FUNCTION 2 l2_norm(halfvec),
685682
FUNCTION 3 l2_normalize(halfvec),
686-
FUNCTION 4 hnsw_halfvec_max_dims(internal);
683+
FUNCTION 4 hnsw_halfvec_support(internal);
687684

688685
CREATE OPERATOR CLASS halfvec_l1_ops
689686
FOR TYPE halfvec USING hnsw AS
690687
OPERATOR 1 <+> (halfvec, halfvec) FOR ORDER BY float_ops,
691688
FUNCTION 1 l1_distance(halfvec, halfvec),
692-
FUNCTION 4 hnsw_halfvec_max_dims(internal);
689+
FUNCTION 4 hnsw_halfvec_support(internal);
693690

694691
--- sparsevec type
695692

@@ -875,28 +872,24 @@ CREATE OPERATOR CLASS sparsevec_l2_ops
875872
FOR TYPE sparsevec USING hnsw AS
876873
OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops,
877874
FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec),
878-
FUNCTION 4 hnsw_sparsevec_max_dims(internal),
879-
FUNCTION 5 hnsw_sparsevec_check_value(internal);
875+
FUNCTION 4 hnsw_sparsevec_support(internal);
880876

881877
CREATE OPERATOR CLASS sparsevec_ip_ops
882878
FOR TYPE sparsevec USING hnsw AS
883879
OPERATOR 1 <#> (sparsevec, sparsevec) FOR ORDER BY float_ops,
884880
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
885-
FUNCTION 4 hnsw_sparsevec_max_dims(internal),
886-
FUNCTION 5 hnsw_sparsevec_check_value(internal);
881+
FUNCTION 4 hnsw_sparsevec_support(internal);
887882

888883
CREATE OPERATOR CLASS sparsevec_cosine_ops
889884
FOR TYPE sparsevec USING hnsw AS
890885
OPERATOR 1 <=> (sparsevec, sparsevec) FOR ORDER BY float_ops,
891886
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
892887
FUNCTION 2 l2_norm(sparsevec),
893888
FUNCTION 3 l2_normalize(sparsevec),
894-
FUNCTION 4 hnsw_sparsevec_max_dims(internal),
895-
FUNCTION 5 hnsw_sparsevec_check_value(internal);
889+
FUNCTION 4 hnsw_sparsevec_support(internal);
896890

897891
CREATE OPERATOR CLASS sparsevec_l1_ops
898892
FOR TYPE sparsevec USING hnsw AS
899893
OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops,
900894
FUNCTION 1 l1_distance(sparsevec, sparsevec),
901-
FUNCTION 4 hnsw_sparsevec_max_dims(internal),
902-
FUNCTION 5 hnsw_sparsevec_check_value(internal);
895+
FUNCTION 4 hnsw_sparsevec_support(internal);

src/hnsw.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ hnswhandler(PG_FUNCTION_ARGS)
194194
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
195195

196196
amroutine->amstrategies = 0;
197-
amroutine->amsupport = 5;
197+
amroutine->amsupport = 4;
198198
#if PG_VERSION_NUM >= 130000
199199
amroutine->amoptsprocnum = 0;
200200
#endif

src/hnsw.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323
#define HNSW_DISTANCE_PROC 1
2424
#define HNSW_NORM_PROC 2
2525
#define HNSW_NORMALIZE_PROC 3
26-
#define HNSW_MAX_DIMS_PROC 4
27-
#define HNSW_CHECK_VALUE_PROC 5
26+
#define HNSW_TYPE_INFO_PROC 4
2827

2928
#define HNSW_VERSION 1
3029
#define HNSW_MAGIC_NUMBER 0xA953A953
@@ -239,13 +238,20 @@ typedef struct HnswAllocator
239238
void *state;
240239
} HnswAllocator;
241240

241+
typedef struct HnswTypeInfo
242+
{
243+
int maxDimensions;
244+
void (*checkValue) (Pointer v);
245+
} HnswTypeInfo;
246+
242247
typedef struct HnswBuildState
243248
{
244249
/* Info */
245250
Relation heap;
246251
Relation index;
247252
IndexInfo *indexInfo;
248253
ForkNumber forkNum;
254+
const HnswTypeInfo *typeInfo;
249255

250256
/* Settings */
251257
int dimensions;
@@ -260,7 +266,6 @@ typedef struct HnswBuildState
260266
FmgrInfo *procinfo;
261267
FmgrInfo *normprocinfo;
262268
FmgrInfo *normalizeprocinfo;
263-
FmgrInfo *checkvalueprocinfo;
264269
Oid collation;
265270

266271
/* Variables */
@@ -375,7 +380,6 @@ int HnswGetEfConstruction(Relation index);
375380
FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum);
376381
Datum HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
377382
bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
378-
void HnswCheckValue(FmgrInfo *procinfo, Oid collation, Datum value);
379383
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
380384
void HnswInitPage(Buffer buf, Page page);
381385
void HnswInit(void);
@@ -399,6 +403,7 @@ void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element
399403
void HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation);
400404
void HnswLoadNeighbors(HnswElement element, Relation index, int m);
401405
void HnswInitLockTranche(void);
406+
const HnswTypeInfo *HnswGetTypeInfo(Relation index);
402407
PGDLLEXPORT void HnswParallelBuildMain(dsm_segment *seg, shm_toc *toc);
403408

404409
/* Index access methods */

src/hnswbuild.c

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -488,8 +488,8 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
488488
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
489489

490490
/* Check value */
491-
if (buildstate->checkvalueprocinfo != NULL)
492-
HnswCheckValue(buildstate->checkvalueprocinfo, buildstate->collation, value);
491+
if (buildstate->typeInfo->checkValue != NULL)
492+
buildstate->typeInfo->checkValue(DatumGetPointer(value));
493493

494494
/* Normalize if needed */
495495
if (buildstate->normprocinfo != NULL)
@@ -672,32 +672,17 @@ HnswSharedMemoryAlloc(Size size, void *state)
672672
return chunk;
673673
}
674674

675-
/*
676-
* Get max dimensions
677-
*/
678-
static int
679-
GetMaxDimensions(Relation index)
680-
{
681-
FmgrInfo *procinfo = HnswOptionalProcInfo(index, HNSW_MAX_DIMS_PROC);
682-
683-
if (procinfo == NULL)
684-
return HNSW_MAX_DIM;
685-
686-
return DatumGetInt32(FunctionCall1(procinfo, PointerGetDatum(NULL)));
687-
}
688-
689675
/*
690676
* Initialize the build state
691677
*/
692678
static void
693679
InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo, ForkNumber forkNum)
694680
{
695-
int maxDimensions;
696-
697681
buildstate->heap = heap;
698682
buildstate->index = index;
699683
buildstate->indexInfo = indexInfo;
700684
buildstate->forkNum = forkNum;
685+
buildstate->typeInfo = HnswGetTypeInfo(index);
701686

702687
buildstate->m = HnswGetM(index);
703688
buildstate->efConstruction = HnswGetEfConstruction(index);
@@ -707,14 +692,12 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
707692
if (TupleDescAttr(index->rd_att, 0)->atttypid == VARBITOID)
708693
elog(ERROR, "type not supported for hnsw index");
709694

710-
maxDimensions = GetMaxDimensions(index);
711-
712695
/* Require column to have dimensions to be indexed */
713696
if (buildstate->dimensions < 0)
714697
elog(ERROR, "column does not have dimensions");
715698

716-
if (buildstate->dimensions > maxDimensions)
717-
elog(ERROR, "column cannot have more than %d dimensions for hnsw index", maxDimensions);
699+
if (buildstate->dimensions > buildstate->typeInfo->maxDimensions)
700+
elog(ERROR, "column cannot have more than %d dimensions for hnsw index", buildstate->typeInfo->maxDimensions);
718701

719702
if (buildstate->efConstruction < 2 * buildstate->m)
720703
elog(ERROR, "ef_construction must be greater than or equal to 2 * m");
@@ -726,7 +709,6 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
726709
buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
727710
buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
728711
buildstate->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC);
729-
buildstate->checkvalueprocinfo = HnswOptionalProcInfo(index, HNSW_CHECK_VALUE_PROC);
730712
buildstate->collation = index->rd_indcollation[0];
731713

732714
InitGraph(&buildstate->graphData, NULL, maintenance_work_mem * 1024L);

src/hnswinsert.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -612,16 +612,16 @@ static void
612612
HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid)
613613
{
614614
Datum value;
615-
FmgrInfo *checkvalueprocinfo = HnswOptionalProcInfo(index, HNSW_CHECK_VALUE_PROC);
615+
const HnswTypeInfo *typeInfo = HnswGetTypeInfo(index);
616616
FmgrInfo *normprocinfo;
617617
Oid collation = index->rd_indcollation[0];
618618

619619
/* Detoast once for all calls */
620620
value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
621621

622622
/* Check value */
623-
if (checkvalueprocinfo != NULL)
624-
HnswCheckValue(checkvalueprocinfo, collation, value);
623+
if (typeInfo->checkValue != NULL)
624+
typeInfo->checkValue(DatumGetPointer(value));
625625

626626
/* Normalize if needed */
627627
normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);

0 commit comments

Comments
 (0)