Skip to content

Commit 38b223b

Browse files
committed
Added concatenate operator for vectors
1 parent 4f6c485 commit 38b223b

10 files changed

Lines changed: 101 additions & 0 deletions

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- Added `jaccard_distance` function
99
- Added `l2_normalize` function
1010
- Added `subvector` function
11+
- Added concatenate operator for vectors
1112
- Added CPU dispatching for distance functions on Linux x86-64
1213
- Updated comparison operators to support vectors with different dimensions
1314

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,7 @@ Operator | Description | Added
850850
\+ | element-wise addition |
851851
\- | element-wise subtraction |
852852
\* | element-wise multiplication | 0.5.0
853+
\|\| | concatenate | unreleased
853854
<-> | Euclidean distance |
854855
<#> | negative inner product |
855856
<=> | cosine distance |
@@ -886,6 +887,7 @@ Operator | Description | Added
886887
\+ | element-wise addition | unreleased
887888
\- | element-wise subtraction | unreleased
888889
\* | element-wise multiplication | unreleased
890+
\|\| | concatenate | unreleased
889891
<-> | Euclidean distance | unreleased
890892
<#> | negative inner product | unreleased
891893
<=> | cosine distance | unreleased

sql/vector--0.6.2--0.7.0.sql

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ CREATE FUNCTION binary_quantize(vector) RETURNS bit
1010
CREATE FUNCTION subvector(vector, int, int) RETURNS vector
1111
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
1212

13+
CREATE FUNCTION vector_concat(vector, vector) RETURNS vector
14+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
15+
16+
CREATE OPERATOR || (
17+
LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_concat
18+
);
19+
1320
CREATE FUNCTION hamming_distance(bit, bit) RETURNS float8
1421
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
1522

@@ -98,6 +105,9 @@ CREATE FUNCTION halfvec_sub(halfvec, halfvec) RETURNS halfvec
98105
CREATE FUNCTION halfvec_mul(halfvec, halfvec) RETURNS halfvec
99106
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
100107

108+
CREATE FUNCTION halfvec_concat(halfvec, halfvec) RETURNS halfvec
109+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
110+
101111
CREATE FUNCTION halfvec_lt(halfvec, halfvec) RETURNS bool
102112
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
103113

@@ -227,6 +237,10 @@ CREATE OPERATOR * (
227237
COMMUTATOR = *
228238
);
229239

240+
CREATE OPERATOR || (
241+
LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = halfvec_concat
242+
);
243+
230244
CREATE OPERATOR < (
231245
LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = halfvec_lt,
232246
COMMUTATOR = > , NEGATOR = >= ,

sql/vector.sql

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ CREATE FUNCTION vector_sub(vector, vector) RETURNS vector
6969
CREATE FUNCTION vector_mul(vector, vector) RETURNS vector
7070
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
7171

72+
CREATE FUNCTION vector_concat(vector, vector) RETURNS vector
73+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
74+
7275
CREATE FUNCTION vector_lt(vector, vector) RETURNS bool
7376
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
7477

@@ -197,6 +200,10 @@ CREATE OPERATOR * (
197200
COMMUTATOR = *
198201
);
199202

203+
CREATE OPERATOR || (
204+
LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_concat
205+
);
206+
200207
CREATE OPERATOR < (
201208
LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_lt,
202209
COMMUTATOR = > , NEGATOR = >= ,
@@ -393,6 +400,9 @@ CREATE FUNCTION halfvec_sub(halfvec, halfvec) RETURNS halfvec
393400
CREATE FUNCTION halfvec_mul(halfvec, halfvec) RETURNS halfvec
394401
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
395402

403+
CREATE FUNCTION halfvec_concat(halfvec, halfvec) RETURNS halfvec
404+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
405+
396406
CREATE FUNCTION halfvec_lt(halfvec, halfvec) RETURNS bool
397407
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
398408

@@ -530,6 +540,10 @@ CREATE OPERATOR * (
530540
COMMUTATOR = *
531541
);
532542

543+
CREATE OPERATOR || (
544+
LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = halfvec_concat
545+
);
546+
533547
CREATE OPERATOR < (
534548
LEFTARG = halfvec, RIGHTARG = halfvec, PROCEDURE = halfvec_lt,
535549
COMMUTATOR = > , NEGATOR = >= ,

src/halfvec.c

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,30 @@ halfvec_mul(PG_FUNCTION_ARGS)
905905
PG_RETURN_POINTER(result);
906906
}
907907

908+
/*
909+
* Concatenate half vectors
910+
*/
911+
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_concat);
912+
Datum
913+
halfvec_concat(PG_FUNCTION_ARGS)
914+
{
915+
HalfVector *a = PG_GETARG_HALFVEC_P(0);
916+
HalfVector *b = PG_GETARG_HALFVEC_P(1);
917+
HalfVector *result;
918+
int dim = a->dim + b->dim;
919+
920+
CheckDim(dim);
921+
result = InitHalfVector(dim);
922+
923+
for (int i = 0; i < a->dim; i++)
924+
result->x[i] = a->x[i];
925+
926+
for (int i = 0; i < b->dim; i++)
927+
result->x[i + a->dim] = b->x[i];
928+
929+
PG_RETURN_POINTER(result);
930+
}
931+
908932
/*
909933
* Quantize a half vector
910934
*/

src/vector.c

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,30 @@ vector_mul(PG_FUNCTION_ARGS)
916916
PG_RETURN_POINTER(result);
917917
}
918918

919+
/*
920+
* Concatenate vectors
921+
*/
922+
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_concat);
923+
Datum
924+
vector_concat(PG_FUNCTION_ARGS)
925+
{
926+
Vector *a = PG_GETARG_VECTOR_P(0);
927+
Vector *b = PG_GETARG_VECTOR_P(1);
928+
Vector *result;
929+
int dim = a->dim + b->dim;
930+
931+
CheckDim(dim);
932+
result = InitVector(dim);
933+
934+
for (int i = 0; i < a->dim; i++)
935+
result->x[i] = a->x[i];
936+
937+
for (int i = 0; i < b->dim; i++)
938+
result->x[i + a->dim] = b->x[i];
939+
940+
PG_RETURN_POINTER(result);
941+
}
942+
919943
/*
920944
* Quantize a vector
921945
*/

test/expected/halfvec_functions.out

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ SELECT '[65519]'::halfvec * '[65519]';
2424
ERROR: value out of range: overflow
2525
SELECT '[1e-7]'::halfvec * '[1e-7]';
2626
ERROR: value out of range: underflow
27+
SELECT '[1,2,3]'::halfvec || '[4,5]'::halfvec;
28+
?column?
29+
-------------
30+
[1,2,3,4,5]
31+
(1 row)
32+
33+
SELECT array_fill(0, ARRAY[16000])::halfvec || '[1]'::halfvec;
34+
ERROR: halfvec cannot have more than 16000 dimensions
2735
SELECT '[1,2,3]'::halfvec < '[1,2,3]';
2836
?column?
2937
----------

test/expected/vector_functions.out

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ SELECT '[1e37]'::vector * '[1e37]';
2424
ERROR: value out of range: overflow
2525
SELECT '[1e-37]'::vector * '[1e-37]';
2626
ERROR: value out of range: underflow
27+
SELECT '[1,2,3]'::vector || '[4,5]'::vector;
28+
?column?
29+
-------------
30+
[1,2,3,4,5]
31+
(1 row)
32+
33+
SELECT array_fill(0, ARRAY[16000])::vector || '[1]'::vector;
34+
ERROR: vector cannot have more than 16000 dimensions
2735
SELECT '[1,2,3]'::vector < '[1,2,3]';
2836
?column?
2937
----------

test/sql/halfvec_functions.sql

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ SELECT '[1,2,3]'::halfvec * '[4,5,6]';
66
SELECT '[65519]'::halfvec * '[65519]';
77
SELECT '[1e-7]'::halfvec * '[1e-7]';
88

9+
SELECT '[1,2,3]'::halfvec || '[4,5]'::halfvec;
10+
SELECT array_fill(0, ARRAY[16000])::halfvec || '[1]'::halfvec;
11+
912
SELECT '[1,2,3]'::halfvec < '[1,2,3]';
1013
SELECT '[1,2,3]'::halfvec < '[1,2]';
1114
SELECT '[1,2,3]'::halfvec <= '[1,2,3]';

test/sql/vector_functions.sql

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ SELECT '[1,2,3]'::vector * '[4,5,6]';
66
SELECT '[1e37]'::vector * '[1e37]';
77
SELECT '[1e-37]'::vector * '[1e-37]';
88

9+
SELECT '[1,2,3]'::vector || '[4,5]'::vector;
10+
SELECT array_fill(0, ARRAY[16000])::vector || '[1]'::vector;
11+
912
SELECT '[1,2,3]'::vector < '[1,2,3]';
1013
SELECT '[1,2,3]'::vector < '[1,2]';
1114
SELECT '[1,2,3]'::vector <= '[1,2,3]';

0 commit comments

Comments
 (0)