|
23 | 23 | #define TYPALIGN_INT 'i' |
24 | 24 | #endif |
25 | 25 |
|
| 26 | +#define STATE_DIMS(x) (ARR_DIMS(x)[0] - 1) |
| 27 | +#define CreateStateDatums(dim) palloc(sizeof(Datum) * (dim + 1)) |
| 28 | + |
26 | 29 | /* |
27 | 30 | * Get a half from a message buffer |
28 | 31 | */ |
@@ -146,6 +149,20 @@ halfvec_isspace(char ch) |
146 | 149 | return false; |
147 | 150 | } |
148 | 151 |
|
| 152 | +/* |
| 153 | + * Check state array |
| 154 | + */ |
| 155 | +static float8 * |
| 156 | +CheckStateArray(ArrayType *statearray, const char *caller) |
| 157 | +{ |
| 158 | + if (ARR_NDIM(statearray) != 1 || |
| 159 | + ARR_DIMS(statearray)[0] < 1 || |
| 160 | + ARR_HASNULL(statearray) || |
| 161 | + ARR_ELEMTYPE(statearray) != FLOAT8OID) |
| 162 | + elog(ERROR, "%s: expected state array", caller); |
| 163 | + return (float8 *) ARR_DATA_PTR(statearray); |
| 164 | +} |
| 165 | + |
149 | 166 | #if PG_VERSION_NUM < 120003 |
150 | 167 | static pg_noinline void |
151 | 168 | float_overflow_error(void) |
@@ -1016,3 +1033,98 @@ halfvec_cmp(PG_FUNCTION_ARGS) |
1016 | 1033 |
|
1017 | 1034 | PG_RETURN_INT32(halfvec_cmp_internal(a, b)); |
1018 | 1035 | } |
| 1036 | + |
| 1037 | +/* |
| 1038 | + * Accumulate half vectors |
| 1039 | + */ |
| 1040 | +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_accum); |
| 1041 | +Datum |
| 1042 | +halfvec_accum(PG_FUNCTION_ARGS) |
| 1043 | +{ |
| 1044 | + ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0); |
| 1045 | + HalfVector *newval = PG_GETARG_HALFVEC_P(1); |
| 1046 | + float8 *statevalues; |
| 1047 | + int16 dim; |
| 1048 | + bool newarr; |
| 1049 | + float8 n; |
| 1050 | + Datum *statedatums; |
| 1051 | + half *x = newval->x; |
| 1052 | + ArrayType *result; |
| 1053 | + |
| 1054 | + /* Check array before using */ |
| 1055 | + statevalues = CheckStateArray(statearray, "halfvec_accum"); |
| 1056 | + dim = STATE_DIMS(statearray); |
| 1057 | + newarr = dim == 0; |
| 1058 | + |
| 1059 | + if (newarr) |
| 1060 | + dim = newval->dim; |
| 1061 | + else |
| 1062 | + CheckExpectedDim(dim, newval->dim); |
| 1063 | + |
| 1064 | + n = statevalues[0] + 1.0; |
| 1065 | + |
| 1066 | + statedatums = CreateStateDatums(dim); |
| 1067 | + statedatums[0] = Float8GetDatum(n); |
| 1068 | + |
| 1069 | + if (newarr) |
| 1070 | + { |
| 1071 | + for (int i = 0; i < dim; i++) |
| 1072 | + statedatums[i + 1] = Float8GetDatum((double) HalfToFloat4(x[i])); |
| 1073 | + } |
| 1074 | + else |
| 1075 | + { |
| 1076 | + for (int i = 0; i < dim; i++) |
| 1077 | + { |
| 1078 | + double v = statevalues[i + 1] + (double) HalfToFloat4(x[i]); |
| 1079 | + |
| 1080 | + /* Check for overflow */ |
| 1081 | + if (isinf(v)) |
| 1082 | + float_overflow_error(); |
| 1083 | + |
| 1084 | + statedatums[i + 1] = Float8GetDatum(v); |
| 1085 | + } |
| 1086 | + } |
| 1087 | + |
| 1088 | + /* Use float8 array like float4_accum */ |
| 1089 | + result = construct_array(statedatums, dim + 1, |
| 1090 | + FLOAT8OID, |
| 1091 | + sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE); |
| 1092 | + |
| 1093 | + pfree(statedatums); |
| 1094 | + |
| 1095 | + PG_RETURN_ARRAYTYPE_P(result); |
| 1096 | +} |
| 1097 | + |
| 1098 | +/* |
| 1099 | + * Average half vectors |
| 1100 | + */ |
| 1101 | +PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_avg); |
| 1102 | +Datum |
| 1103 | +halfvec_avg(PG_FUNCTION_ARGS) |
| 1104 | +{ |
| 1105 | + ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0); |
| 1106 | + float8 *statevalues; |
| 1107 | + float8 n; |
| 1108 | + uint16 dim; |
| 1109 | + HalfVector *result; |
| 1110 | + |
| 1111 | + /* Check array before using */ |
| 1112 | + statevalues = CheckStateArray(statearray, "halfvec_avg"); |
| 1113 | + n = statevalues[0]; |
| 1114 | + |
| 1115 | + /* SQL defines AVG of no values to be NULL */ |
| 1116 | + if (n == 0.0) |
| 1117 | + PG_RETURN_NULL(); |
| 1118 | + |
| 1119 | + /* Create half vector */ |
| 1120 | + dim = STATE_DIMS(statearray); |
| 1121 | + CheckDim(dim); |
| 1122 | + result = InitHalfVector(dim); |
| 1123 | + for (int i = 0; i < dim; i++) |
| 1124 | + { |
| 1125 | + result->x[i] = Float4ToHalf(statevalues[i + 1] / n); |
| 1126 | + CheckElement(result->x[i]); |
| 1127 | + } |
| 1128 | + |
| 1129 | + PG_RETURN_POINTER(result); |
| 1130 | +} |
0 commit comments