Skip to content

Commit 5dec500

Browse files
committed
Reduced support functions for IVFFlat - #527
1 parent 1fdfff7 commit 5dec500

10 files changed

Lines changed: 36 additions & 41 deletions

File tree

sql/vector--0.6.2--0.7.0.sql

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ CREATE OPERATOR CLASS bit_hamming_ops
6363
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
6464
FUNCTION 1 hamming_distance(bit, bit),
6565
FUNCTION 3 hamming_distance(bit, bit),
66-
FUNCTION 6 ivfflat_bit_support(internal);
66+
FUNCTION 5 ivfflat_bit_support(internal);
6767

6868
CREATE OPERATOR CLASS bit_hamming_ops
6969
FOR TYPE bit USING hnsw AS
@@ -330,16 +330,15 @@ CREATE OPERATOR CLASS halfvec_l2_ops
330330
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
331331
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
332332
FUNCTION 3 l2_distance(halfvec, halfvec),
333-
FUNCTION 6 ivfflat_halfvec_support(internal);
333+
FUNCTION 5 ivfflat_halfvec_support(internal);
334334

335335
CREATE OPERATOR CLASS halfvec_ip_ops
336336
FOR TYPE halfvec USING ivfflat AS
337337
OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops,
338338
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
339339
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
340340
FUNCTION 4 l2_norm(halfvec),
341-
FUNCTION 5 l2_normalize(halfvec),
342-
FUNCTION 6 ivfflat_halfvec_support(internal);
341+
FUNCTION 5 ivfflat_halfvec_support(internal);
343342

344343
CREATE OPERATOR CLASS halfvec_cosine_ops
345344
FOR TYPE halfvec USING ivfflat AS
@@ -348,8 +347,7 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
348347
FUNCTION 2 l2_norm(halfvec),
349348
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
350349
FUNCTION 4 l2_norm(halfvec),
351-
FUNCTION 5 l2_normalize(halfvec),
352-
FUNCTION 6 ivfflat_halfvec_support(internal);
350+
FUNCTION 5 ivfflat_halfvec_support(internal);
353351

354352
CREATE OPERATOR CLASS halfvec_l2_ops
355353
FOR TYPE halfvec USING hnsw AS

sql/vector.sql

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -358,7 +358,7 @@ CREATE OPERATOR CLASS bit_hamming_ops
358358
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
359359
FUNCTION 1 hamming_distance(bit, bit),
360360
FUNCTION 3 hamming_distance(bit, bit),
361-
FUNCTION 6 ivfflat_bit_support(internal);
361+
FUNCTION 5 ivfflat_bit_support(internal);
362362

363363
CREATE OPERATOR CLASS bit_hamming_ops
364364
FOR TYPE bit USING hnsw AS
@@ -641,16 +641,15 @@ CREATE OPERATOR CLASS halfvec_l2_ops
641641
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
642642
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
643643
FUNCTION 3 l2_distance(halfvec, halfvec),
644-
FUNCTION 6 ivfflat_halfvec_support(internal);
644+
FUNCTION 5 ivfflat_halfvec_support(internal);
645645

646646
CREATE OPERATOR CLASS halfvec_ip_ops
647647
FOR TYPE halfvec USING ivfflat AS
648648
OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops,
649649
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
650650
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
651651
FUNCTION 4 l2_norm(halfvec),
652-
FUNCTION 5 l2_normalize(halfvec),
653-
FUNCTION 6 ivfflat_halfvec_support(internal);
652+
FUNCTION 5 ivfflat_halfvec_support(internal);
654653

655654
CREATE OPERATOR CLASS halfvec_cosine_ops
656655
FOR TYPE halfvec USING ivfflat AS
@@ -659,8 +658,7 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
659658
FUNCTION 2 l2_norm(halfvec),
660659
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
661660
FUNCTION 4 l2_norm(halfvec),
662-
FUNCTION 5 l2_normalize(halfvec),
663-
FUNCTION 6 ivfflat_halfvec_support(internal);
661+
FUNCTION 5 ivfflat_halfvec_support(internal);
664662

665663
CREATE OPERATOR CLASS halfvec_l2_ops
666664
FOR TYPE halfvec USING hnsw AS

src/ivfbuild.c

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ AddSample(Datum *values, IvfflatBuildState * buildstate)
6363
if (!IvfflatCheckNorm(buildstate->kmeansnormprocinfo, buildstate->collation, value))
6464
return;
6565

66-
value = IvfflatNormValue(buildstate->normalizeprocinfo, buildstate->collation, value);
66+
value = IvfflatNormValue(buildstate->typeInfo, buildstate->collation, value);
6767
}
6868

6969
if (samples->length < targsamples)
@@ -161,7 +161,7 @@ AddTupleToSort(Relation index, ItemPointer tid, Datum *values, IvfflatBuildState
161161
if (!IvfflatCheckNorm(buildstate->normprocinfo, buildstate->collation, value))
162162
return;
163163

164-
value = IvfflatNormValue(buildstate->normalizeprocinfo, buildstate->collation, value);
164+
value = IvfflatNormValue(buildstate->typeInfo, buildstate->collation, value);
165165
}
166166

167167
/* Find the list that minimizes the distance */
@@ -351,7 +351,6 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
351351
buildstate->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
352352
buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
353353
buildstate->kmeansnormprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
354-
buildstate->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
355354
buildstate->collation = index->rd_indcollation[0];
356355

357356
/* Require more than one dimension for spherical k-means */

src/ivfflat.c

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -188,7 +188,7 @@ ivfflathandler(PG_FUNCTION_ARGS)
188188
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
189189

190190
amroutine->amstrategies = 0;
191-
amroutine->amsupport = 6;
191+
amroutine->amsupport = 5;
192192
#if PG_VERSION_NUM >= 130000
193193
amroutine->amoptsprocnum = 0;
194194
#endif

src/ivfflat.h

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -28,8 +28,7 @@
2828
#define IVFFLAT_NORM_PROC 2
2929
#define IVFFLAT_KMEANS_DISTANCE_PROC 3
3030
#define IVFFLAT_KMEANS_NORM_PROC 4
31-
#define IVFFLAT_NORMALIZE_PROC 5
32-
#define IVFFLAT_TYPE_INFO_PROC 6
31+
#define IVFFLAT_TYPE_INFO_PROC 5
3332

3433
#define IVFFLAT_VERSION 1
3534
#define IVFFLAT_MAGIC_NUMBER 0x14FF1A7
@@ -153,6 +152,7 @@ typedef struct IvfflatLeader
153152
typedef struct IvfflatTypeInfo
154153
{
155154
int maxDimensions;
155+
Datum (*normalize) (PG_FUNCTION_ARGS);
156156
void (*updateCenter) (Pointer v, int dimensions, float *x);
157157
void (*sumCenter) (Pointer v, float *x);
158158
} IvfflatTypeInfo;
@@ -177,7 +177,6 @@ typedef struct IvfflatBuildState
177177
FmgrInfo *procinfo;
178178
FmgrInfo *normprocinfo;
179179
FmgrInfo *kmeansnormprocinfo;
180-
FmgrInfo *normalizeprocinfo;
181180
Oid collation;
182181

183182
/* Variables */
@@ -245,6 +244,7 @@ typedef struct IvfflatScanList
245244

246245
typedef struct IvfflatScanOpaqueData
247246
{
247+
const IvfflatTypeInfo *typeInfo;
248248
int probes;
249249
int dimensions;
250250
bool first;
@@ -258,7 +258,6 @@ typedef struct IvfflatScanOpaqueData
258258
/* Support functions */
259259
FmgrInfo *procinfo;
260260
FmgrInfo *normprocinfo;
261-
FmgrInfo *normalizeprocinfo;
262261
Oid collation;
263262
Datum (*distfunc) (FmgrInfo *flinfo, Oid collation, Datum arg1, Datum arg2);
264263

@@ -279,7 +278,7 @@ VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize);
279278
void VectorArrayFree(VectorArray arr);
280279
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfflatTypeInfo * typeInfo);
281280
FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
282-
Datum IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
281+
Datum IvfflatNormValue(const IvfflatTypeInfo * typeInfo, Oid collation, Datum value);
283282
bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
284283
int IvfflatGetLists(Relation index);
285284
void IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions);

src/ivfinsert.c

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,7 @@ FindInsertPage(Relation index, Datum *values, BlockNumber *insertPage, ListInfo
6767
static void
6868
InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel)
6969
{
70+
const IvfflatTypeInfo *typeInfo = IvfflatGetTypeInfo(index);
7071
IndexTuple itup;
7172
Datum value;
7273
FmgrInfo *normprocinfo;
@@ -90,7 +91,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, R
9091
if (!IvfflatCheckNorm(normprocinfo, collation, value))
9192
return;
9293

93-
value = IvfflatNormValue(IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC), collation, value);
94+
value = IvfflatNormValue(typeInfo, collation, value);
9495
}
9596

9697
/* Find the insert page - sets the page and list info */

src/ivfkmeans.c

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
9292
* Norm centers
9393
*/
9494
static void
95-
NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers)
95+
NormCenters(const IvfflatTypeInfo * typeInfo, Oid collation, VectorArray centers)
9696
{
9797
MemoryContext normCtx = AllocSetContextCreate(CurrentMemoryContext,
9898
"Ivfflat norm temporary context",
@@ -102,7 +102,7 @@ NormCenters(FmgrInfo *normalizeprocinfo, Oid collation, VectorArray centers)
102102
for (int j = 0; j < centers->length; j++)
103103
{
104104
Datum center = PointerGetDatum(VectorArrayGet(centers, j));
105-
Datum newCenter = IvfflatNormValue(normalizeprocinfo, collation, center);
105+
Datum newCenter = IvfflatNormValue(typeInfo, collation, center);
106106
Size size = VARSIZE_ANY(DatumGetPointer(newCenter));
107107

108108
if (size > centers->itemsize)
@@ -123,9 +123,8 @@ static void
123123
RandomCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeInfo)
124124
{
125125
int dimensions = centers->dim;
126-
Oid collation = index->rd_indcollation[0];
127126
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
128-
FmgrInfo *normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
127+
Oid collation = index->rd_indcollation[0];
129128
float *x = (float *) palloc(sizeof(float) * dimensions);
130129

131130
/* Fill with random data */
@@ -142,7 +141,7 @@ RandomCenters(Relation index, VectorArray centers, const IvfflatTypeInfo * typeI
142141
}
143142

144143
if (normprocinfo != NULL)
145-
NormCenters(normalizeprocinfo, collation, centers);
144+
NormCenters(typeInfo, collation, centers);
146145

147146
pfree(x);
148147
}
@@ -196,7 +195,7 @@ UpdateCenters(float *agg, VectorArray centers, const IvfflatTypeInfo * typeInfo)
196195
* Compute new centers
197196
*/
198197
static void
199-
ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, FmgrInfo *normalizeprocinfo, Oid collation, const IvfflatTypeInfo * typeInfo)
198+
ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *centerCounts, int *closestCenters, FmgrInfo *normprocinfo, Oid collation, const IvfflatTypeInfo * typeInfo)
200199
{
201200
int dimensions = newCenters->dim;
202201
int numCenters = newCenters->length;
@@ -251,7 +250,7 @@ ComputeNewCenters(VectorArray samples, float *agg, VectorArray newCenters, int *
251250

252251
/* Normalize if needed */
253252
if (normprocinfo != NULL)
254-
NormCenters(normalizeprocinfo, collation, newCenters);
253+
NormCenters(typeInfo, collation, newCenters);
255254
}
256255

257256
/*
@@ -267,7 +266,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff
267266
{
268267
FmgrInfo *procinfo;
269268
FmgrInfo *normprocinfo;
270-
FmgrInfo *normalizeprocinfo;
271269
Oid collation;
272270
int dimensions = centers->dim;
273271
int numCenters = centers->maxlen;
@@ -315,7 +313,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff
315313
/* Set support functions */
316314
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
317315
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
318-
normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
319316
collation = index->rd_indcollation[0];
320317

321318
/* Use memory context */
@@ -477,7 +474,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const Ivff
477474
}
478475

479476
/* Step 4: For each center c, let m(c) be mean of all points assigned */
480-
ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, normalizeprocinfo, collation, typeInfo);
477+
ComputeNewCenters(samples, agg, newCenters, centerCounts, closestCenters, normprocinfo, collation, typeInfo);
481478

482479
/* Step 5 */
483480
for (int j = 0; j < numCenters; j++)

src/ivfscan.c

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -209,9 +209,9 @@ GetScanValue(IndexScanDesc scan)
209209
Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));
210210
Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
211211

212-
/* Check normprocinfo since normalizeprocinfo not set for vector */
212+
/* Normalize if needed */
213213
if (so->normprocinfo != NULL)
214-
value = IvfflatNormValue(so->normalizeprocinfo, so->collation, value);
214+
value = IvfflatNormValue(so->typeInfo, so->collation, value);
215215
}
216216

217217
return value;
@@ -242,14 +242,14 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys)
242242
probes = lists;
243243

244244
so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + probes * sizeof(IvfflatScanList));
245+
so->typeInfo = IvfflatGetTypeInfo(index);
245246
so->first = true;
246247
so->probes = probes;
247248
so->dimensions = dimensions;
248249

249250
/* Set support functions */
250251
so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
251252
so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
252-
so->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
253253
so->collation = index->rd_indcollation[0];
254254

255255
/* Create tuple description for sorting */

src/ivfutils.c

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -68,12 +68,9 @@ IvfflatOptionalProcInfo(Relation index, uint16 procnum)
6868
* Normalize value
6969
*/
7070
Datum
71-
IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value)
71+
IvfflatNormValue(const IvfflatTypeInfo * typeInfo, Oid collation, Datum value)
7272
{
73-
if (procinfo == NULL)
74-
return DirectFunctionCall1(l2_normalize, value);
75-
76-
return FunctionCall1Coll(procinfo, collation, value);
73+
return DirectFunctionCall1Coll(typeInfo->normalize, collation, value);
7774
}
7875

7976
/*
@@ -228,6 +225,10 @@ IvfflatUpdateList(Relation index, ListInfo listInfo,
228225
}
229226
}
230227

228+
PGDLLEXPORT Datum l2_normalize(PG_FUNCTION_ARGS);
229+
PGDLLEXPORT Datum halfvec_l2_normalize(PG_FUNCTION_ARGS);
230+
PGDLLEXPORT Datum sparsevec_l2_normalize(PG_FUNCTION_ARGS);
231+
231232
static void
232233
VectorUpdateCenter(Pointer v, int dimensions, float *x)
233234
{
@@ -307,6 +308,7 @@ IvfflatGetTypeInfo(Relation index)
307308
{
308309
static const IvfflatTypeInfo typeInfo = {
309310
.maxDimensions = IVFFLAT_MAX_DIM,
311+
.normalize = l2_normalize,
310312
.updateCenter = VectorUpdateCenter,
311313
.sumCenter = VectorSumCenter
312314
};
@@ -323,6 +325,7 @@ ivfflat_halfvec_support(PG_FUNCTION_ARGS)
323325
{
324326
static const IvfflatTypeInfo typeInfo = {
325327
.maxDimensions = IVFFLAT_MAX_DIM * 2,
328+
.normalize = halfvec_l2_normalize,
326329
.updateCenter = HalfvecUpdateCenter,
327330
.sumCenter = HalfvecSumCenter
328331
};
@@ -336,6 +339,7 @@ ivfflat_bit_support(PG_FUNCTION_ARGS)
336339
{
337340
static const IvfflatTypeInfo typeInfo = {
338341
.maxDimensions = IVFFLAT_MAX_DIM * 32,
342+
.normalize = NULL,
339343
.updateCenter = BitUpdateCenter,
340344
.sumCenter = BitSumCenter
341345
};

src/vector.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,5 @@ typedef struct Vector
2121
Vector *InitVector(int dim);
2222
void PrintVector(char *msg, Vector * vector);
2323
int vector_cmp_internal(Vector * a, Vector * b);
24-
PGDLLEXPORT Datum l2_normalize(PG_FUNCTION_ARGS);
2524

2625
#endif

0 commit comments

Comments
 (0)