Skip to content

Commit 127ecdd

Browse files
committed
Added l2_normalize function for sparsevec
1 parent 10dacfd commit 127ecdd

8 files changed

Lines changed: 89 additions & 14 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -952,6 +952,7 @@ inner_product(sparsevec, sparsevec) → double precision | inner product | unrel
952952
l1_distance(sparsevec, sparsevec) → double precision | taxicab distance | unreleased
953953
l2_distance(sparsevec, sparsevec) → double precision | Euclidean distance | unreleased
954954
l2_norm(sparsevec) → double precision | Euclidean norm | unreleased
955+
l2_normalize(sparsevec) → sparsevec | Normalize with Euclidean norm | unreleased
955956

956957
## Installation Notes - Linux and Mac
957958

sql/vector--0.6.2--0.7.0.sql

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,9 @@ CREATE FUNCTION l1_distance(sparsevec, sparsevec) RETURNS float8
364364
CREATE FUNCTION l2_norm(sparsevec) RETURNS float8
365365
AS 'MODULE_PATHNAME', 'sparsevec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
366366

367+
CREATE FUNCTION l2_normalize(sparsevec) RETURNS sparsevec
368+
AS 'MODULE_PATHNAME', 'sparsevec_l2_normalize' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
369+
367370
CREATE FUNCTION sparsevec_lt(sparsevec, sparsevec) RETURNS bool
368371
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
369372

sql/vector.sql

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,9 @@ CREATE FUNCTION l1_distance(sparsevec, sparsevec) RETURNS float8
673673
CREATE FUNCTION l2_norm(sparsevec) RETURNS float8
674674
AS 'MODULE_PATHNAME', 'sparsevec_l2_norm' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
675675

676+
CREATE FUNCTION l2_normalize(sparsevec) RETURNS sparsevec
677+
AS 'MODULE_PATHNAME', 'sparsevec_l2_normalize' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
678+
676679
-- sparsevec private functions
677680

678681
CREATE FUNCTION sparsevec_lt(sparsevec, sparsevec) RETURNS bool

src/hnswutils.c

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -213,20 +213,7 @@ HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type)
213213
else if (type == HNSW_TYPE_HALFVEC)
214214
*value = DirectFunctionCall1(halfvec_l2_normalize, *value);
215215
else if (type == HNSW_TYPE_SPARSEVEC)
216-
{
217-
SparseVector *v = DatumGetSparseVector(*value);
218-
SparseVector *result = InitSparseVector(v->dim, v->nnz);
219-
float *vx = SPARSEVEC_VALUES(v);
220-
float *rx = SPARSEVEC_VALUES(result);
221-
222-
for (int i = 0; i < v->nnz; i++)
223-
{
224-
result->indices[i] = v->indices[i];
225-
rx[i] = vx[i] / norm;
226-
}
227-
228-
*value = PointerGetDatum(result);
229-
}
216+
*value = DirectFunctionCall1(sparsevec_l2_normalize, *value);
230217
else
231218
elog(ERROR, "Unsupported type");
232219

src/sparsevec.c

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,48 @@ sparsevec_l2_norm(PG_FUNCTION_ARGS)
848848
PG_RETURN_FLOAT8(sqrt(norm));
849849
}
850850

851+
/*
852+
* Normalize a sparse vector with the L2 norm
853+
*/
854+
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_l2_normalize);
855+
Datum
856+
sparsevec_l2_normalize(PG_FUNCTION_ARGS)
857+
{
858+
SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
859+
float *ax = SPARSEVEC_VALUES(a);
860+
double norm = 0;
861+
SparseVector *result;
862+
float *rx;
863+
864+
result = InitSparseVector(a->dim, a->nnz);
865+
rx = SPARSEVEC_VALUES(result);
866+
867+
/* Auto-vectorized */
868+
for (int i = 0; i < a->nnz; i++)
869+
norm += (double) ax[i] * (double) ax[i];
870+
871+
norm = sqrt(norm);
872+
873+
/* Return zero vector for zero norm */
874+
if (norm > 0)
875+
{
876+
for (int i = 0; i < a->nnz; i++)
877+
{
878+
result->indices[i] = a->indices[i];
879+
rx[i] = ax[i] / norm;
880+
}
881+
882+
/* Check for overflow */
883+
for (int i = 0; i < a->nnz; i++)
884+
{
885+
if (isinf(rx[i]))
886+
float_overflow_error();
887+
}
888+
}
889+
890+
PG_RETURN_POINTER(result);
891+
}
892+
851893
/*
852894
* Internal helper to compare sparse vectors
853895
*/

src/sparsevec.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#ifndef SPARSEVEC_H
22
#define SPARSEVEC_H
33

4+
#include "fmgr.h"
5+
46
#define SPARSEVEC_MAX_DIM 100000
57
#define SPARSEVEC_MAX_NNZ 16000
68

@@ -21,5 +23,6 @@ typedef struct SparseVector
2123
} SparseVector;
2224

2325
SparseVector *InitSparseVector(int dim, int nnz);
26+
Datum sparsevec_l2_normalize(PG_FUNCTION_ARGS);
2427

2528
#endif

test/expected/sparsevec_functions.out

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,33 @@ SELECT l1_distance('{1:1,3:3,5:5,7:7,9:9}/9'::sparsevec, '{2:2,4:4,6:6,8:8}/9');
292292
45
293293
(1 row)
294294

295+
SELECT l2_normalize('{1:3,2:4}/2'::sparsevec);
296+
l2_normalize
297+
-----------------
298+
{1:0.6,2:0.8}/2
299+
(1 row)
300+
301+
SELECT l2_normalize('{1:3}/2'::sparsevec);
302+
l2_normalize
303+
--------------
304+
{1:1}/2
305+
(1 row)
306+
307+
SELECT l2_normalize('{2:0.1}/2'::sparsevec);
308+
l2_normalize
309+
--------------
310+
{2:1}/2
311+
(1 row)
312+
313+
SELECT l2_normalize('{}/2'::sparsevec);
314+
l2_normalize
315+
--------------
316+
{}/2
317+
(1 row)
318+
319+
SELECT l2_normalize('{1:3e38}/1'::sparsevec);
320+
l2_normalize
321+
--------------
322+
{1:1}/1
323+
(1 row)
324+

test/sql/sparsevec_functions.sql

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,9 @@ SELECT l1_distance('{1:1,2:2}/2'::sparsevec, '{1:3}/1');
5555
SELECT l1_distance('{1:3e38}/1'::sparsevec, '{1:-3e38}/1');
5656
SELECT l1_distance('{1:1,3:3,5:5,7:7}/8'::sparsevec, '{2:2,4:4,6:6,8:8}/8');
5757
SELECT l1_distance('{1:1,3:3,5:5,7:7,9:9}/9'::sparsevec, '{2:2,4:4,6:6,8:8}/9');
58+
59+
SELECT l2_normalize('{1:3,2:4}/2'::sparsevec);
60+
SELECT l2_normalize('{1:3}/2'::sparsevec);
61+
SELECT l2_normalize('{2:0.1}/2'::sparsevec);
62+
SELECT l2_normalize('{}/2'::sparsevec);
63+
SELECT l2_normalize('{1:3e38}/1'::sparsevec);

0 commit comments

Comments
 (0)