Skip to content

Commit 3eef1ff

Browse files
committed
Removed type-specific code from HNSW [skip ci]
1 parent b8bdf31 commit 3eef1ff

6 files changed

Lines changed: 64 additions & 83 deletions

File tree

sql/vector.sql

Lines changed: 20 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -327,7 +327,7 @@ CREATE FUNCTION jaccard_distance(bit, bit) RETURNS float8
327327
CREATE FUNCTION bit_ivfflat_support(internal) RETURNS internal
328328
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
329329

330-
CREATE FUNCTION bit_hnsw_support(internal) RETURNS internal
330+
CREATE FUNCTION bit_hnsw_max_dims(internal) RETURNS internal
331331
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
332332

333333
-- bit operators
@@ -355,13 +355,13 @@ CREATE OPERATOR CLASS bit_hamming_ops
355355
FOR TYPE bit USING hnsw AS
356356
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
357357
FUNCTION 1 hamming_distance(bit, bit),
358-
FUNCTION 4 bit_hnsw_support(internal);
358+
FUNCTION 4 bit_hnsw_max_dims(internal);
359359

360360
CREATE OPERATOR CLASS bit_jaccard_ops
361361
FOR TYPE bit USING hnsw AS
362362
OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops,
363363
FUNCTION 1 jaccard_distance(bit, bit),
364-
FUNCTION 4 bit_hnsw_support(internal);
364+
FUNCTION 4 bit_hnsw_max_dims(internal);
365365

366366
-- halfvec type
367367

@@ -473,7 +473,7 @@ CREATE FUNCTION halfvec_avg(double precision[]) RETURNS halfvec
473473
CREATE FUNCTION halfvec_ivfflat_support(internal) RETURNS internal
474474
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
475475

476-
CREATE FUNCTION halfvec_hnsw_support(internal) RETURNS internal
476+
CREATE FUNCTION halfvec_hnsw_max_dims(internal) RETURNS internal
477477
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
478478

479479
-- halfvec aggregates
@@ -663,27 +663,27 @@ CREATE OPERATOR CLASS halfvec_l2_ops
663663
FOR TYPE halfvec USING hnsw AS
664664
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
665665
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
666-
FUNCTION 4 halfvec_hnsw_support(internal);
666+
FUNCTION 4 halfvec_hnsw_max_dims(internal);
667667

668668
CREATE OPERATOR CLASS halfvec_ip_ops
669669
FOR TYPE halfvec USING hnsw AS
670670
OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops,
671671
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
672-
FUNCTION 4 halfvec_hnsw_support(internal);
672+
FUNCTION 4 halfvec_hnsw_max_dims(internal);
673673

674674
CREATE OPERATOR CLASS halfvec_cosine_ops
675675
FOR TYPE halfvec USING hnsw AS
676676
OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops,
677677
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
678678
FUNCTION 2 l2_norm(halfvec),
679679
FUNCTION 3 l2_normalize(halfvec),
680-
FUNCTION 4 halfvec_hnsw_support(internal);
680+
FUNCTION 4 halfvec_hnsw_max_dims(internal);
681681

682682
CREATE OPERATOR CLASS halfvec_l1_ops
683683
FOR TYPE halfvec USING hnsw AS
684684
OPERATOR 1 <+> (halfvec, halfvec) FOR ORDER BY float_ops,
685685
FUNCTION 1 l1_distance(halfvec, halfvec),
686-
FUNCTION 4 halfvec_hnsw_support(internal);
686+
FUNCTION 4 halfvec_hnsw_max_dims(internal);
687687

688688
--- sparsevec type
689689

@@ -779,7 +779,10 @@ CREATE FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) RETURNS sparseve
779779
CREATE FUNCTION sparsevec_to_halfvec(sparsevec, integer, boolean) RETURNS halfvec
780780
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
781781

782-
CREATE FUNCTION sparsevec_hnsw_support(internal) RETURNS internal
782+
CREATE FUNCTION sparsevec_hnsw_max_dims(internal) RETURNS internal
783+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
784+
785+
CREATE FUNCTION sparsevec_hnsw_check_value(internal) RETURNS internal
783786
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
784787

785788
-- sparsevec casts
@@ -872,24 +875,28 @@ CREATE OPERATOR CLASS sparsevec_l2_ops
872875
FOR TYPE sparsevec USING hnsw AS
873876
OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops,
874877
FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec),
875-
FUNCTION 4 sparsevec_hnsw_support(internal);
878+
FUNCTION 4 sparsevec_hnsw_max_dims(internal),
879+
FUNCTION 5 sparsevec_hnsw_check_value(internal);
876880

877881
CREATE OPERATOR CLASS sparsevec_ip_ops
878882
FOR TYPE sparsevec USING hnsw AS
879883
OPERATOR 1 <#> (sparsevec, sparsevec) FOR ORDER BY float_ops,
880884
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
881-
FUNCTION 4 sparsevec_hnsw_support(internal);
885+
FUNCTION 4 sparsevec_hnsw_max_dims(internal),
886+
FUNCTION 5 sparsevec_hnsw_check_value(internal);
882887

883888
CREATE OPERATOR CLASS sparsevec_cosine_ops
884889
FOR TYPE sparsevec USING hnsw AS
885890
OPERATOR 1 <=> (sparsevec, sparsevec) FOR ORDER BY float_ops,
886891
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
887892
FUNCTION 2 l2_norm(sparsevec),
888893
FUNCTION 3 l2_normalize(sparsevec),
889-
FUNCTION 4 sparsevec_hnsw_support(internal);
894+
FUNCTION 4 sparsevec_hnsw_max_dims(internal),
895+
FUNCTION 5 sparsevec_hnsw_check_value(internal);
890896

891897
CREATE OPERATOR CLASS sparsevec_l1_ops
892898
FOR TYPE sparsevec USING hnsw AS
893899
OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops,
894900
FUNCTION 1 l1_distance(sparsevec, sparsevec),
895-
FUNCTION 4 sparsevec_hnsw_support(internal);
901+
FUNCTION 4 sparsevec_hnsw_max_dims(internal),
902+
FUNCTION 5 sparsevec_hnsw_check_value(internal);

src/hnsw.c

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -194,7 +194,7 @@ hnswhandler(PG_FUNCTION_ARGS)
194194
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
195195

196196
amroutine->amstrategies = 0;
197-
amroutine->amsupport = 4;
197+
amroutine->amsupport = 5;
198198
#if PG_VERSION_NUM >= 130000
199199
amroutine->amoptsprocnum = 0;
200200
#endif

src/hnsw.h

Lines changed: 4 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,8 @@
2323
#define HNSW_DISTANCE_PROC 1
2424
#define HNSW_NORM_PROC 2
2525
#define HNSW_NORMALIZE_PROC 3
26-
#define HNSW_TYPE_SUPPORT_PROC 4
26+
#define HNSW_MAX_DIMS_PROC 4
27+
#define HNSW_CHECK_VALUE_PROC 5
2728

2829
#define HNSW_VERSION 1
2930
#define HNSW_MAGIC_NUMBER 0xA953A953
@@ -58,15 +59,6 @@
5859
#define HNSW_UPDATE_ENTRY_GREATER 1
5960
#define HNSW_UPDATE_ENTRY_ALWAYS 2
6061

61-
typedef enum HnswType
62-
{
63-
HNSW_TYPE_VECTOR,
64-
HNSW_TYPE_HALFVEC,
65-
HNSW_TYPE_BIT,
66-
HNSW_TYPE_SPARSEVEC,
67-
HNSW_TYPE_UNSUPPORTED
68-
} HnswType;
69-
7062
/* Build phases */
7163
/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */
7264
#define PROGRESS_HNSW_PHASE_LOAD 2
@@ -254,7 +246,6 @@ typedef struct HnswBuildState
254246
Relation index;
255247
IndexInfo *indexInfo;
256248
ForkNumber forkNum;
257-
HnswType type;
258249

259250
/* Settings */
260251
int dimensions;
@@ -269,6 +260,7 @@ typedef struct HnswBuildState
269260
FmgrInfo *procinfo;
270261
FmgrInfo *normprocinfo;
271262
FmgrInfo *normalizeprocinfo;
263+
FmgrInfo *checkvalueprocinfo;
272264
Oid collation;
273265

274266
/* Variables */
@@ -381,10 +373,9 @@ typedef struct HnswVacuumState
381373
int HnswGetM(Relation index);
382374
int HnswGetEfConstruction(Relation index);
383375
FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum);
384-
HnswType HnswGetType(Relation index);
385376
Datum HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
386377
bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
387-
void HnswCheckValue(Datum value, HnswType type);
378+
void HnswCheckValue(FmgrInfo *procinfo, Oid collation, Datum value);
388379
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
389380
void HnswInitPage(Buffer buf, Page page);
390381
void HnswInit(void);

src/hnswbuild.c

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -488,7 +488,8 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
488488
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
489489

490490
/* Check value */
491-
HnswCheckValue(value, buildstate->type);
491+
if (buildstate->checkvalueprocinfo != NULL)
492+
HnswCheckValue(buildstate->checkvalueprocinfo, buildstate->collation, value);
492493

493494
/* Normalize if needed */
494495
if (buildstate->normprocinfo != NULL)
@@ -675,18 +676,14 @@ HnswSharedMemoryAlloc(Size size, void *state)
675676
* Get max dimensions
676677
*/
677678
static int
678-
GetMaxDimensions(HnswType type)
679+
GetMaxDimensions(Relation index)
679680
{
680-
int maxDimensions = HNSW_MAX_DIM;
681+
FmgrInfo *procinfo = HnswOptionalProcInfo(index, HNSW_MAX_DIMS_PROC);
681682

682-
if (type == HNSW_TYPE_HALFVEC)
683-
maxDimensions *= 2;
684-
else if (type == HNSW_TYPE_BIT)
685-
maxDimensions *= 32;
686-
else if (type == HNSW_TYPE_SPARSEVEC)
687-
maxDimensions = INT_MAX;
683+
if (procinfo == NULL)
684+
return HNSW_MAX_DIM;
688685

689-
return maxDimensions;
686+
return DatumGetInt32(FunctionCall1(procinfo, PointerGetDatum(NULL)));
690687
}
691688

692689
/*
@@ -701,13 +698,16 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
701698
buildstate->index = index;
702699
buildstate->indexInfo = indexInfo;
703700
buildstate->forkNum = forkNum;
704-
buildstate->type = HnswGetType(index);
705701

706702
buildstate->m = HnswGetM(index);
707703
buildstate->efConstruction = HnswGetEfConstruction(index);
708704
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
709705

710-
maxDimensions = GetMaxDimensions(buildstate->type);
706+
/* Disallow varbit since require fixed dimensions */
707+
if (TupleDescAttr(index->rd_att, 0)->atttypid == VARBITOID)
708+
elog(ERROR, "type not supported for hnsw index");
709+
710+
maxDimensions = GetMaxDimensions(index);
711711

712712
/* Require column to have dimensions to be indexed */
713713
if (buildstate->dimensions < 0)
@@ -726,6 +726,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
726726
buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
727727
buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
728728
buildstate->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC);
729+
buildstate->checkvalueprocinfo = HnswOptionalProcInfo(index, HNSW_CHECK_VALUE_PROC);
729730
buildstate->collation = index->rd_indcollation[0];
730731

731732
InitGraph(&buildstate->graphData, NULL, maintenance_work_mem * 1024L);

src/hnswinsert.c

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -612,14 +612,16 @@ static void
612612
HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid)
613613
{
614614
Datum value;
615+
FmgrInfo *checkvalueprocinfo = HnswOptionalProcInfo(index, HNSW_CHECK_VALUE_PROC);
615616
FmgrInfo *normprocinfo;
616617
Oid collation = index->rd_indcollation[0];
617618

618619
/* Detoast once for all calls */
619620
value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
620621

621622
/* Check value */
622-
HnswCheckValue(value, HnswGetType(index));
623+
if (checkvalueprocinfo != NULL)
624+
HnswCheckValue(checkvalueprocinfo, collation, value);
623625

624626
/* Normalize if needed */
625627
normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);

src/hnswutils.c

Lines changed: 23 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -152,27 +152,6 @@ HnswOptionalProcInfo(Relation index, uint16 procnum)
152152
return index_getprocinfo(index, 1, procnum);
153153
}
154154

155-
/*
156-
* Get type
157-
*/
158-
HnswType
159-
HnswGetType(Relation index)
160-
{
161-
FmgrInfo *procinfo = HnswOptionalProcInfo(index, HNSW_TYPE_SUPPORT_PROC);
162-
Oid typid = TupleDescAttr(index->rd_att, 0)->atttypid;
163-
HnswType result;
164-
165-
if (procinfo == NULL)
166-
return HNSW_TYPE_VECTOR;
167-
168-
result = (HnswType) DatumGetInt32(FunctionCall1(procinfo, ObjectIdGetDatum(typid)));
169-
170-
if (result == HNSW_TYPE_UNSUPPORTED)
171-
elog(ERROR, "type not supported for hnsw index");
172-
173-
return result;
174-
}
175-
176155
/*
177156
* Normalize value
178157
*/
@@ -198,15 +177,9 @@ HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value)
198177
* Check if a value can be indexed
199178
*/
200179
void
201-
HnswCheckValue(Datum value, HnswType type)
180+
HnswCheckValue(FmgrInfo *procinfo, Oid collation, Datum value)
202181
{
203-
if (type == HNSW_TYPE_SPARSEVEC)
204-
{
205-
SparseVector *vec = DatumGetSparseVector(value);
206-
207-
if (vec->nnz > HNSW_MAX_NNZ)
208-
elog(ERROR, "sparsevec cannot have more than %d non-zero elements for hnsw index", HNSW_MAX_NNZ);
209-
}
182+
FunctionCall1Coll(procinfo, collation, value);
210183
}
211184

212185
/*
@@ -1303,28 +1276,35 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
13031276
}
13041277
}
13051278

1306-
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_hnsw_support);
1279+
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_hnsw_max_dims);
13071280
Datum
1308-
halfvec_hnsw_support(PG_FUNCTION_ARGS)
1281+
halfvec_hnsw_max_dims(PG_FUNCTION_ARGS)
13091282
{
1310-
PG_RETURN_INT32(HNSW_TYPE_HALFVEC);
1283+
PG_RETURN_INT32(HNSW_MAX_DIM * 2);
13111284
};
13121285

1313-
PGDLLEXPORT PG_FUNCTION_INFO_V1(bit_hnsw_support);
1286+
PGDLLEXPORT PG_FUNCTION_INFO_V1(bit_hnsw_max_dims);
13141287
Datum
1315-
bit_hnsw_support(PG_FUNCTION_ARGS)
1288+
bit_hnsw_max_dims(PG_FUNCTION_ARGS)
13161289
{
1317-
Oid typid = PG_GETARG_OID(0);
1318-
1319-
if (typid == BITOID)
1320-
PG_RETURN_INT32(HNSW_TYPE_BIT);
1321-
else
1322-
PG_RETURN_INT32(HNSW_TYPE_UNSUPPORTED);
1290+
PG_RETURN_INT32(HNSW_MAX_DIM * 32);
13231291
};
13241292

1325-
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_hnsw_support);
1293+
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_hnsw_max_dims);
13261294
Datum
1327-
sparsevec_hnsw_support(PG_FUNCTION_ARGS)
1295+
sparsevec_hnsw_max_dims(PG_FUNCTION_ARGS)
13281296
{
1329-
PG_RETURN_INT32(HNSW_TYPE_SPARSEVEC);
1297+
PG_RETURN_INT32(SPARSEVEC_MAX_DIM);
13301298
};
1299+
1300+
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_hnsw_check_value);
1301+
Datum
1302+
sparsevec_hnsw_check_value(PG_FUNCTION_ARGS)
1303+
{
1304+
SparseVector *vec = PG_GETARG_SPARSEVEC_P(0);
1305+
1306+
if (vec->nnz > HNSW_MAX_NNZ)
1307+
elog(ERROR, "sparsevec cannot have more than %d non-zero elements for hnsw index", HNSW_MAX_NNZ);
1308+
1309+
PG_RETURN_VOID();
1310+
}

0 commit comments

Comments
 (0)