Skip to content

Commit e146f3c

Browse files
committed
Added avg for half vectors [skip ci]
1 parent 92d08bb commit e146f3c

7 files changed

Lines changed: 186 additions & 1 deletion

File tree

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,12 @@ l2_distance(halfvec, halfvec) → double precision | Euclidean distance | unrele
902902
quantize_binary(halfvec) → bit | quantize | unreleased
903903
subvector(halfvec, integer, integer) → halfvec | subvector | unreleased
904904

905+
### Halfvec Aggregate Functions
906+
907+
Function | Description | Added
908+
--- | --- | ---
909+
avg(halfvec) → halfvec | average | unreleased
910+
905911
### Bit Type
906912

907913
Each bit vector takes `dimensions / 8 + 8` bytes of storage. See the [Postgres docs](https://www.postgresql.org/docs/current/datatype-bit.html) for more info.

sql/vector--0.6.2--0.7.0.sql

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,21 @@ CREATE FUNCTION halfvec_negative_inner_product(halfvec, halfvec) RETURNS float8
122122
CREATE FUNCTION halfvec_spherical_distance(halfvec, halfvec) RETURNS float8
123123
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
124124

125+
CREATE FUNCTION halfvec_accum(double precision[], halfvec) RETURNS double precision[]
126+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
127+
128+
CREATE FUNCTION halfvec_avg(double precision[]) RETURNS halfvec
129+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
130+
131+
CREATE AGGREGATE avg(halfvec) (
132+
SFUNC = halfvec_accum,
133+
STYPE = double precision[],
134+
FINALFUNC = halfvec_avg,
135+
COMBINEFUNC = vector_combine,
136+
INITCOND = '{0}',
137+
PARALLEL = SAFE
138+
);
139+
125140
CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec
126141
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
127142

sql/vector.sql

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,23 @@ CREATE FUNCTION halfvec_negative_inner_product(halfvec, halfvec) RETURNS float8
417417
CREATE FUNCTION halfvec_spherical_distance(halfvec, halfvec) RETURNS float8
418418
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
419419

420+
CREATE FUNCTION halfvec_accum(double precision[], halfvec) RETURNS double precision[]
421+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
422+
423+
CREATE FUNCTION halfvec_avg(double precision[]) RETURNS halfvec
424+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
425+
426+
-- halfvec aggregates
427+
428+
CREATE AGGREGATE avg(halfvec) (
429+
SFUNC = halfvec_accum,
430+
STYPE = double precision[],
431+
FINALFUNC = halfvec_avg,
432+
COMBINEFUNC = vector_combine,
433+
INITCOND = '{0}',
434+
PARALLEL = SAFE
435+
);
436+
420437
-- halfvec cast functions
421438

422439
CREATE FUNCTION halfvec(halfvec, integer, boolean) RETURNS halfvec

src/halfvec.c

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
#define TYPALIGN_INT 'i'
2424
#endif
2525

26+
#define STATE_DIMS(x) (ARR_DIMS(x)[0] - 1)
27+
#define CreateStateDatums(dim) palloc(sizeof(Datum) * (dim + 1))
28+
2629
/*
2730
* Get a half from a message buffer
2831
*/
@@ -146,6 +149,20 @@ halfvec_isspace(char ch)
146149
return false;
147150
}
148151

152+
/*
153+
* Check state array
154+
*/
155+
static float8 *
156+
CheckStateArray(ArrayType *statearray, const char *caller)
157+
{
158+
if (ARR_NDIM(statearray) != 1 ||
159+
ARR_DIMS(statearray)[0] < 1 ||
160+
ARR_HASNULL(statearray) ||
161+
ARR_ELEMTYPE(statearray) != FLOAT8OID)
162+
elog(ERROR, "%s: expected state array", caller);
163+
return (float8 *) ARR_DATA_PTR(statearray);
164+
}
165+
149166
#if PG_VERSION_NUM < 120003
150167
static pg_noinline void
151168
float_overflow_error(void)
@@ -1016,3 +1033,98 @@ halfvec_cmp(PG_FUNCTION_ARGS)
10161033

10171034
PG_RETURN_INT32(halfvec_cmp_internal(a, b));
10181035
}
1036+
1037+
/*
1038+
* Accumulate half vectors
1039+
*/
1040+
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_accum);
1041+
Datum
1042+
halfvec_accum(PG_FUNCTION_ARGS)
1043+
{
1044+
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
1045+
HalfVector *newval = PG_GETARG_HALFVEC_P(1);
1046+
float8 *statevalues;
1047+
int16 dim;
1048+
bool newarr;
1049+
float8 n;
1050+
Datum *statedatums;
1051+
half *x = newval->x;
1052+
ArrayType *result;
1053+
1054+
/* Check array before using */
1055+
statevalues = CheckStateArray(statearray, "halfvec_accum");
1056+
dim = STATE_DIMS(statearray);
1057+
newarr = dim == 0;
1058+
1059+
if (newarr)
1060+
dim = newval->dim;
1061+
else
1062+
CheckExpectedDim(dim, newval->dim);
1063+
1064+
n = statevalues[0] + 1.0;
1065+
1066+
statedatums = CreateStateDatums(dim);
1067+
statedatums[0] = Float8GetDatum(n);
1068+
1069+
if (newarr)
1070+
{
1071+
for (int i = 0; i < dim; i++)
1072+
statedatums[i + 1] = Float8GetDatum((double) HalfToFloat4(x[i]));
1073+
}
1074+
else
1075+
{
1076+
for (int i = 0; i < dim; i++)
1077+
{
1078+
double v = statevalues[i + 1] + (double) HalfToFloat4(x[i]);
1079+
1080+
/* Check for overflow */
1081+
if (isinf(v))
1082+
float_overflow_error();
1083+
1084+
statedatums[i + 1] = Float8GetDatum(v);
1085+
}
1086+
}
1087+
1088+
/* Use float8 array like float4_accum */
1089+
result = construct_array(statedatums, dim + 1,
1090+
FLOAT8OID,
1091+
sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE);
1092+
1093+
pfree(statedatums);
1094+
1095+
PG_RETURN_ARRAYTYPE_P(result);
1096+
}
1097+
1098+
/*
1099+
* Average half vectors
1100+
*/
1101+
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_avg);
1102+
Datum
1103+
halfvec_avg(PG_FUNCTION_ARGS)
1104+
{
1105+
ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0);
1106+
float8 *statevalues;
1107+
float8 n;
1108+
uint16 dim;
1109+
HalfVector *result;
1110+
1111+
/* Check array before using */
1112+
statevalues = CheckStateArray(statearray, "halfvec_avg");
1113+
n = statevalues[0];
1114+
1115+
/* SQL defines AVG of no values to be NULL */
1116+
if (n == 0.0)
1117+
PG_RETURN_NULL();
1118+
1119+
/* Create half vector */
1120+
dim = STATE_DIMS(statearray);
1121+
CheckDim(dim);
1122+
result = InitHalfVector(dim);
1123+
for (int i = 0; i < dim; i++)
1124+
{
1125+
result->x[i] = Float4ToHalf(statevalues[i + 1] / n);
1126+
CheckElement(result->x[i]);
1127+
}
1128+
1129+
PG_RETURN_POINTER(result);
1130+
}

src/vector.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1100,7 +1100,7 @@ vector_accum(PG_FUNCTION_ARGS)
11001100
}
11011101

11021102
/*
1103-
* Combine vectors
1103+
* Combine vectors or half vectors
11041104
*/
11051105
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_combine);
11061106
Datum

test/expected/halfvec_functions.out

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,31 @@ SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1);
320320
ERROR: halfvec must have at least 1 dimension
321321
SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2);
322322
ERROR: halfvec must have at least 1 dimension
323+
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]']) v;
324+
avg
325+
-----------
326+
[2,3.5,5]
327+
(1 row)
328+
329+
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]', NULL]) v;
330+
avg
331+
-----------
332+
[2,3.5,5]
333+
(1 row)
334+
335+
SELECT avg(v) FROM unnest(ARRAY[]::halfvec[]) v;
336+
avg
337+
-----
338+
339+
(1 row)
340+
341+
SELECT avg(v) FROM unnest(ARRAY['[1,2]'::halfvec, '[3]']) v;
342+
ERROR: expected 2 dimensions, not 1
343+
SELECT avg(v) FROM unnest(ARRAY['[65504]'::halfvec, '[65504]']) v;
344+
avg
345+
---------
346+
[65504]
347+
(1 row)
348+
349+
SELECT halfvec_avg(array_agg(n)) FROM generate_series(1, 16002) n;
350+
ERROR: halfvec cannot have more than 16000 dimensions

test/sql/halfvec_functions.sql

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,10 @@ SELECT subvector('[1,2,3,4,5]'::halfvec, 3, 9);
6969
SELECT subvector('[1,2,3,4,5]'::halfvec, 1, 0);
7070
SELECT subvector('[1,2,3,4,5]'::halfvec, 3, -1);
7171
SELECT subvector('[1,2,3,4,5]'::halfvec, -1, 2);
72+
73+
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]']) v;
74+
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::halfvec, '[3,5,7]', NULL]) v;
75+
SELECT avg(v) FROM unnest(ARRAY[]::halfvec[]) v;
76+
SELECT avg(v) FROM unnest(ARRAY['[1,2]'::halfvec, '[3]']) v;
77+
SELECT avg(v) FROM unnest(ARRAY['[65504]'::halfvec, '[65504]']) v;
78+
SELECT halfvec_avg(array_agg(n)) FROM generate_series(1, 16002) n;

0 commit comments

Comments
 (0)