Skip to content

Commit 8f00d67

Browse files
committed
Removed type-specific code from IVFFlat - pgvector#527
1 parent 52bfedd commit 8f00d67

9 files changed

Lines changed: 216 additions & 313 deletions

File tree

sql/vector--0.6.2--0.7.0.sql

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,28 +23,40 @@ CREATE OPERATOR || (
2323
);
2424

2525
CREATE FUNCTION ivfflat_bit_max_dims(internal) RETURNS internal
26-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
26+
AS 'MODULE_PATHNAME' LANGUAGE C;
2727

2828
CREATE FUNCTION ivfflat_halfvec_max_dims(internal) RETURNS internal
29-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
29+
AS 'MODULE_PATHNAME' LANGUAGE C;
3030

31-
CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal
32-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
31+
CREATE FUNCTION ivfflat_vector_update_center(internal, internal, internal) RETURNS internal
32+
AS 'MODULE_PATHNAME' LANGUAGE C;
3333

34-
CREATE FUNCTION ivfflat_halfvec_support(internal) RETURNS internal
35-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
34+
CREATE FUNCTION ivfflat_bit_update_center(internal, internal, internal) RETURNS internal
35+
AS 'MODULE_PATHNAME' LANGUAGE C;
36+
37+
CREATE FUNCTION ivfflat_halfvec_update_center(internal, internal, internal) RETURNS internal
38+
AS 'MODULE_PATHNAME' LANGUAGE C;
39+
40+
CREATE FUNCTION ivfflat_vector_sum_center(internal, internal) RETURNS internal
41+
AS 'MODULE_PATHNAME' LANGUAGE C;
42+
43+
CREATE FUNCTION ivfflat_bit_sum_center(internal, internal) RETURNS internal
44+
AS 'MODULE_PATHNAME' LANGUAGE C;
45+
46+
CREATE FUNCTION ivfflat_halfvec_sum_center(internal, internal) RETURNS internal
47+
AS 'MODULE_PATHNAME' LANGUAGE C;
3648

3749
CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal
38-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
50+
AS 'MODULE_PATHNAME' LANGUAGE C;
3951

4052
CREATE FUNCTION hnsw_halfvec_max_dims(internal) RETURNS internal
41-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
53+
AS 'MODULE_PATHNAME' LANGUAGE C;
4254

4355
CREATE FUNCTION hnsw_sparsevec_max_dims(internal) RETURNS internal
44-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
56+
AS 'MODULE_PATHNAME' LANGUAGE C;
4557

4658
CREATE FUNCTION hnsw_sparsevec_check_value(internal) RETURNS internal
47-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
59+
AS 'MODULE_PATHNAME' LANGUAGE C;
4860

4961
CREATE OPERATOR CLASS vector_l1_ops
5062
FOR TYPE vector USING hnsw AS
@@ -73,7 +85,8 @@ CREATE OPERATOR CLASS bit_hamming_ops
7385
FUNCTION 1 hamming_distance(bit, bit),
7486
FUNCTION 3 hamming_distance(bit, bit),
7587
FUNCTION 6 ivfflat_bit_max_dims(internal),
76-
FUNCTION 7 ivfflat_bit_support(internal);
88+
FUNCTION 7 ivfflat_bit_update_center(internal, internal, internal),
89+
FUNCTION 8 ivfflat_bit_sum_center(internal, internal);
7790

7891
CREATE OPERATOR CLASS bit_hamming_ops
7992
FOR TYPE bit USING hnsw AS
@@ -341,7 +354,8 @@ CREATE OPERATOR CLASS halfvec_l2_ops
341354
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
342355
FUNCTION 3 l2_distance(halfvec, halfvec),
343356
FUNCTION 6 ivfflat_halfvec_max_dims(internal),
344-
FUNCTION 7 ivfflat_halfvec_support(internal);
357+
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
358+
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
345359

346360
CREATE OPERATOR CLASS halfvec_ip_ops
347361
FOR TYPE halfvec USING ivfflat AS
@@ -351,7 +365,8 @@ CREATE OPERATOR CLASS halfvec_ip_ops
351365
FUNCTION 4 l2_norm(halfvec),
352366
FUNCTION 5 l2_normalize(halfvec),
353367
FUNCTION 6 ivfflat_halfvec_max_dims(internal),
354-
FUNCTION 7 ivfflat_halfvec_support(internal);
368+
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
369+
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
355370

356371
CREATE OPERATOR CLASS halfvec_cosine_ops
357372
FOR TYPE halfvec USING ivfflat AS
@@ -362,7 +377,8 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
362377
FUNCTION 4 l2_norm(halfvec),
363378
FUNCTION 5 l2_normalize(halfvec),
364379
FUNCTION 6 ivfflat_halfvec_max_dims(internal),
365-
FUNCTION 7 ivfflat_halfvec_support(internal);
380+
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
381+
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
366382

367383
CREATE OPERATOR CLASS halfvec_l2_ops
368384
FOR TYPE halfvec USING hnsw AS

sql/vector.sql

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -264,28 +264,40 @@ COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method';
264264
-- access method private functions
265265

266266
CREATE FUNCTION ivfflat_bit_max_dims(internal) RETURNS internal
267-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
267+
AS 'MODULE_PATHNAME' LANGUAGE C;
268268

269269
CREATE FUNCTION ivfflat_halfvec_max_dims(internal) RETURNS internal
270-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
270+
AS 'MODULE_PATHNAME' LANGUAGE C;
271271

272-
CREATE FUNCTION ivfflat_bit_support(internal) RETURNS internal
273-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
272+
CREATE FUNCTION ivfflat_vector_update_center(internal, internal, internal) RETURNS internal
273+
AS 'MODULE_PATHNAME' LANGUAGE C;
274274

275-
CREATE FUNCTION ivfflat_halfvec_support(internal) RETURNS internal
276-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
275+
CREATE FUNCTION ivfflat_bit_update_center(internal, internal, internal) RETURNS internal
276+
AS 'MODULE_PATHNAME' LANGUAGE C;
277+
278+
CREATE FUNCTION ivfflat_halfvec_update_center(internal, internal, internal) RETURNS internal
279+
AS 'MODULE_PATHNAME' LANGUAGE C;
280+
281+
CREATE FUNCTION ivfflat_vector_sum_center(internal, internal) RETURNS internal
282+
AS 'MODULE_PATHNAME' LANGUAGE C;
283+
284+
CREATE FUNCTION ivfflat_bit_sum_center(internal, internal) RETURNS internal
285+
AS 'MODULE_PATHNAME' LANGUAGE C;
286+
287+
CREATE FUNCTION ivfflat_halfvec_sum_center(internal, internal) RETURNS internal
288+
AS 'MODULE_PATHNAME' LANGUAGE C;
277289

278290
CREATE FUNCTION hnsw_bit_max_dims(internal) RETURNS internal
279-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
291+
AS 'MODULE_PATHNAME' LANGUAGE C;
280292

281293
CREATE FUNCTION hnsw_halfvec_max_dims(internal) RETURNS internal
282-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
294+
AS 'MODULE_PATHNAME' LANGUAGE C;
283295

284296
CREATE FUNCTION hnsw_sparsevec_max_dims(internal) RETURNS internal
285-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
297+
AS 'MODULE_PATHNAME' LANGUAGE C;
286298

287299
CREATE FUNCTION hnsw_sparsevec_check_value(internal) RETURNS internal
288-
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
300+
AS 'MODULE_PATHNAME' LANGUAGE C;
289301

290302
-- vector opclasses
291303

@@ -368,7 +380,8 @@ CREATE OPERATOR CLASS bit_hamming_ops
368380
FUNCTION 1 hamming_distance(bit, bit),
369381
FUNCTION 3 hamming_distance(bit, bit),
370382
FUNCTION 6 ivfflat_bit_max_dims(internal),
371-
FUNCTION 7 ivfflat_bit_support(internal);
383+
FUNCTION 7 ivfflat_bit_update_center(internal, internal, internal),
384+
FUNCTION 8 ivfflat_bit_sum_center(internal, internal);
372385

373386
CREATE OPERATOR CLASS bit_hamming_ops
374387
FOR TYPE bit USING hnsw AS
@@ -652,7 +665,8 @@ CREATE OPERATOR CLASS halfvec_l2_ops
652665
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
653666
FUNCTION 3 l2_distance(halfvec, halfvec),
654667
FUNCTION 6 ivfflat_halfvec_max_dims(internal),
655-
FUNCTION 7 ivfflat_halfvec_support(internal);
668+
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
669+
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
656670

657671
CREATE OPERATOR CLASS halfvec_ip_ops
658672
FOR TYPE halfvec USING ivfflat AS
@@ -662,7 +676,8 @@ CREATE OPERATOR CLASS halfvec_ip_ops
662676
FUNCTION 4 l2_norm(halfvec),
663677
FUNCTION 5 l2_normalize(halfvec),
664678
FUNCTION 6 ivfflat_halfvec_max_dims(internal),
665-
FUNCTION 7 ivfflat_halfvec_support(internal);
679+
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
680+
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
666681

667682
CREATE OPERATOR CLASS halfvec_cosine_ops
668683
FOR TYPE halfvec USING ivfflat AS
@@ -673,7 +688,8 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
673688
FUNCTION 4 l2_norm(halfvec),
674689
FUNCTION 5 l2_normalize(halfvec),
675690
FUNCTION 6 ivfflat_halfvec_max_dims(internal),
676-
FUNCTION 7 ivfflat_halfvec_support(internal);
691+
FUNCTION 7 ivfflat_halfvec_update_center(internal, internal, internal),
692+
FUNCTION 8 ivfflat_halfvec_sum_center(internal, internal);
677693

678694
CREATE OPERATOR CLASS halfvec_l2_ops
679695
FOR TYPE halfvec USING hnsw AS

src/halfvec.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -978,7 +978,7 @@ halfvec_subvector(PG_FUNCTION_ARGS)
978978
/*
979979
* Internal helper to compare half vectors
980980
*/
981-
int
981+
static int
982982
halfvec_cmp_internal(HalfVector * a, HalfVector * b)
983983
{
984984
int dim = Min(a->dim, b->dim);

src/halfvec.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,5 @@ typedef struct HalfVector
4444
} HalfVector;
4545

4646
HalfVector *InitHalfVector(int dim);
47-
int halfvec_cmp_internal(HalfVector * a, HalfVector * b);
4847

4948
#endif

src/ivfbuild.c

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -319,27 +319,6 @@ InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum)
319319
}
320320
}
321321

322-
/*
323-
* Get type
324-
*/
325-
static IvfflatType
326-
IvfflatGetType(Relation index)
327-
{
328-
FmgrInfo *procinfo = IvfflatOptionalProcInfo(index, IVFFLAT_TYPE_SUPPORT_PROC);
329-
Oid typid = TupleDescAttr(index->rd_att, 0)->atttypid;
330-
IvfflatType result;
331-
332-
if (procinfo == NULL)
333-
return IVFFLAT_TYPE_VECTOR;
334-
335-
result = (IvfflatType) DatumGetInt32(FunctionCall1(procinfo, ObjectIdGetDatum(typid)));
336-
337-
if (result == IVFFLAT_TYPE_UNSUPPORTED)
338-
elog(ERROR, "type not supported for ivfflat index");
339-
340-
return result;
341-
}
342-
343322
/*
344323
* Get max dimensions
345324
*/
@@ -358,13 +337,14 @@ GetMaxDimensions(Relation index)
358337
* Get item size
359338
*/
360339
static Size
361-
GetItemSize(IvfflatType type, int dimensions)
340+
GetItemSize(int maxDimensions, int dimensions)
362341
{
363-
if (type == IVFFLAT_TYPE_VECTOR)
342+
/* TODO Improve */
343+
if (maxDimensions == IVFFLAT_MAX_DIM)
364344
return VECTOR_SIZE(dimensions);
365-
else if (type == IVFFLAT_TYPE_HALFVEC)
345+
else if (maxDimensions == IVFFLAT_MAX_DIM * 2)
366346
return HALFVEC_SIZE(dimensions);
367-
else if (type == IVFFLAT_TYPE_BIT)
347+
else if (maxDimensions == IVFFLAT_MAX_DIM * 32)
368348
return VARBITTOTALLEN(dimensions);
369349
else
370350
elog(ERROR, "Unsupported type");
@@ -381,7 +361,6 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
381361
buildstate->heap = heap;
382362
buildstate->index = index;
383363
buildstate->indexInfo = indexInfo;
384-
buildstate->type = IvfflatGetType(index);
385364

386365
buildstate->lists = IvfflatGetLists(index);
387366
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
@@ -421,7 +400,7 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
421400

422401
buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual);
423402

424-
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, GetItemSize(buildstate->type, buildstate->dimensions));
403+
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, GetItemSize(maxDimensions, buildstate->dimensions));
425404
buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists);
426405

427406
buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
@@ -491,7 +470,7 @@ ComputeCenters(IvfflatBuildState * buildstate)
491470
}
492471

493472
/* Calculate centers */
494-
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers, buildstate->type));
473+
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers));
495474

496475
/* Free samples before we allocate more memory */
497476
VectorArrayFree(buildstate->samples);

src/ivfflat.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ ivfflathandler(PG_FUNCTION_ARGS)
188188
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
189189

190190
amroutine->amstrategies = 0;
191-
amroutine->amsupport = 7;
191+
amroutine->amsupport = 8;
192192
#if PG_VERSION_NUM >= 130000
193193
amroutine->amoptsprocnum = 0;
194194
#endif

src/ivfflat.h

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
#define IVFFLAT_KMEANS_NORM_PROC 4
3131
#define IVFFLAT_NORMALIZE_PROC 5
3232
#define IVFFLAT_MAX_DIMS_PROC 6
33-
#define IVFFLAT_TYPE_SUPPORT_PROC 7
33+
#define IVFFLAT_UPDATE_CENTER_PROC 7
34+
#define IVFFLAT_SUM_CENTER_PROC 8
3435

3536
#define IVFFLAT_VERSION 1
3637
#define IVFFLAT_MAGIC_NUMBER 0x14FF1A7
@@ -46,14 +47,6 @@
4647
#define IVFFLAT_MAX_LISTS 32768
4748
#define IVFFLAT_DEFAULT_PROBES 1
4849

49-
typedef enum IvfflatType
50-
{
51-
IVFFLAT_TYPE_VECTOR,
52-
IVFFLAT_TYPE_HALFVEC,
53-
IVFFLAT_TYPE_BIT,
54-
IVFFLAT_TYPE_UNSUPPORTED
55-
} IvfflatType;
56-
5750
/* Build phases */
5851
/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */
5952
#define PROGRESS_IVFFLAT_PHASE_KMEANS 2
@@ -165,7 +158,6 @@ typedef struct IvfflatBuildState
165158
Relation heap;
166159
Relation index;
167160
IndexInfo *indexInfo;
168-
IvfflatType type;
169161

170162
/* Settings */
171163
int dimensions;
@@ -279,7 +271,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
279271
/* Methods */
280272
VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize);
281273
void VectorArrayFree(VectorArray arr);
282-
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type);
274+
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
283275
FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
284276
Datum IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
285277
bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);

0 commit comments

Comments
 (0)