Skip to content

Commit 10d33d3

Browse files
author
Bhargav Dodla
committed
fix: Handles null values in data during GO Feature retrieval
Signed-off-by: Bhargav Dodla <bdodla@expediagroup.com>
1 parent 2478831 commit 10d33d3

File tree

4 files changed

+164
-89
lines changed

4 files changed

+164
-89
lines changed

go/types/typeconversion.go

Lines changed: 84 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,9 @@ import (
1111
)
1212

1313
func ProtoTypeToArrowType(sample *types.Value) (arrow.DataType, error) {
14+
if sample.Val == nil {
15+
return nil, nil
16+
}
1417
switch sample.Val.(type) {
1518
case *types.Value_BytesVal:
1619
return arrow.BinaryTypes.Binary, nil
@@ -91,81 +94,71 @@ func ValueTypeEnumToArrowType(t types.ValueType_Enum) (arrow.DataType, error) {
9194
}
9295

9396
func CopyProtoValuesToArrowArray(builder array.Builder, values []*types.Value) error {
94-
switch fieldBuilder := builder.(type) {
95-
case *array.BooleanBuilder:
96-
for _, v := range values {
97-
fieldBuilder.Append(v.GetBoolVal())
98-
}
99-
case *array.BinaryBuilder:
100-
for _, v := range values {
101-
fieldBuilder.Append(v.GetBytesVal())
102-
}
103-
case *array.StringBuilder:
104-
for _, v := range values {
105-
fieldBuilder.Append(v.GetStringVal())
106-
}
107-
case *array.Int32Builder:
108-
for _, v := range values {
109-
fieldBuilder.Append(v.GetInt32Val())
110-
}
111-
case *array.Int64Builder:
112-
for _, v := range values {
113-
fieldBuilder.Append(v.GetInt64Val())
114-
}
115-
case *array.Float32Builder:
116-
for _, v := range values {
117-
fieldBuilder.Append(v.GetFloatVal())
97+
for _, value := range values {
98+
if value == nil || value.Val == nil {
99+
builder.AppendNull()
100+
continue
118101
}
119-
case *array.Float64Builder:
120-
for _, v := range values {
121-
fieldBuilder.Append(v.GetDoubleVal())
122-
}
123-
case *array.TimestampBuilder:
124-
for _, v := range values {
125-
fieldBuilder.Append(arrow.Timestamp(v.GetUnixTimestampVal()))
126-
}
127-
case *array.ListBuilder:
128-
for _, list := range values {
102+
103+
switch fieldBuilder := builder.(type) {
104+
105+
case *array.BooleanBuilder:
106+
fieldBuilder.Append(value.GetBoolVal())
107+
case *array.BinaryBuilder:
108+
fieldBuilder.Append(value.GetBytesVal())
109+
case *array.StringBuilder:
110+
fieldBuilder.Append(value.GetStringVal())
111+
case *array.Int32Builder:
112+
fieldBuilder.Append(value.GetInt32Val())
113+
case *array.Int64Builder:
114+
fieldBuilder.Append(value.GetInt64Val())
115+
case *array.Float32Builder:
116+
fieldBuilder.Append(value.GetFloatVal())
117+
case *array.Float64Builder:
118+
fieldBuilder.Append(value.GetDoubleVal())
119+
case *array.TimestampBuilder:
120+
fieldBuilder.Append(arrow.Timestamp(value.GetUnixTimestampVal()))
121+
case *array.ListBuilder:
129122
fieldBuilder.Append(true)
130123

131124
switch valueBuilder := fieldBuilder.ValueBuilder().(type) {
132125

133126
case *array.BooleanBuilder:
134-
for _, v := range list.GetBoolListVal().GetVal() {
127+
for _, v := range value.GetBoolListVal().GetVal() {
135128
valueBuilder.Append(v)
136129
}
137130
case *array.BinaryBuilder:
138-
for _, v := range list.GetBytesListVal().GetVal() {
131+
for _, v := range value.GetBytesListVal().GetVal() {
139132
valueBuilder.Append(v)
140133
}
141134
case *array.StringBuilder:
142-
for _, v := range list.GetStringListVal().GetVal() {
135+
for _, v := range value.GetStringListVal().GetVal() {
143136
valueBuilder.Append(v)
144137
}
145138
case *array.Int32Builder:
146-
for _, v := range list.GetInt32ListVal().GetVal() {
139+
for _, v := range value.GetInt32ListVal().GetVal() {
147140
valueBuilder.Append(v)
148141
}
149142
case *array.Int64Builder:
150-
for _, v := range list.GetInt64ListVal().GetVal() {
143+
for _, v := range value.GetInt64ListVal().GetVal() {
151144
valueBuilder.Append(v)
152145
}
153146
case *array.Float32Builder:
154-
for _, v := range list.GetFloatListVal().GetVal() {
147+
for _, v := range value.GetFloatListVal().GetVal() {
155148
valueBuilder.Append(v)
156149
}
157150
case *array.Float64Builder:
158-
for _, v := range list.GetDoubleListVal().GetVal() {
151+
for _, v := range value.GetDoubleListVal().GetVal() {
159152
valueBuilder.Append(v)
160153
}
161154
case *array.TimestampBuilder:
162-
for _, v := range list.GetUnixTimestampListVal().GetVal() {
155+
for _, v := range value.GetUnixTimestampListVal().GetVal() {
163156
valueBuilder.Append(arrow.Timestamp(v))
164157
}
165158
}
159+
default:
160+
return fmt.Errorf("unsupported array builder: %s", builder)
166161
}
167-
default:
168-
return fmt.Errorf("unsupported array builder: %s", builder)
169162
}
170163
return nil
171164
}
@@ -249,41 +242,68 @@ func ArrowValuesToProtoValues(arr arrow.Array) ([]*types.Value, error) {
249242

250243
switch arr.DataType() {
251244
case arrow.PrimitiveTypes.Int32:
252-
for _, v := range arr.(*array.Int32).Int32Values() {
253-
values = append(values, &types.Value{Val: &types.Value_Int32Val{Int32Val: v}})
245+
for idx := 0; idx < arr.Len(); idx++ {
246+
if arr.IsNull(idx) {
247+
values = append(values, &types.Value{})
248+
} else {
249+
values = append(values, &types.Value{Val: &types.Value_Int32Val{Int32Val: arr.(*array.Int32).Value(idx)}})
250+
}
254251
}
255252
case arrow.PrimitiveTypes.Int64:
256-
for _, v := range arr.(*array.Int64).Int64Values() {
257-
values = append(values, &types.Value{Val: &types.Value_Int64Val{Int64Val: v}})
253+
for idx := 0; idx < arr.Len(); idx++ {
254+
if arr.IsNull(idx) {
255+
values = append(values, &types.Value{})
256+
} else {
257+
values = append(values, &types.Value{Val: &types.Value_Int64Val{Int64Val: arr.(*array.Int64).Value(idx)}})
258+
}
258259
}
259260
case arrow.PrimitiveTypes.Float32:
260-
for _, v := range arr.(*array.Float32).Float32Values() {
261-
values = append(values, &types.Value{Val: &types.Value_FloatVal{FloatVal: v}})
261+
for idx := 0; idx < arr.Len(); idx++ {
262+
if arr.IsNull(idx) {
263+
values = append(values, &types.Value{})
264+
} else {
265+
values = append(values, &types.Value{Val: &types.Value_FloatVal{FloatVal: arr.(*array.Float32).Value(idx)}})
266+
}
262267
}
263268
case arrow.PrimitiveTypes.Float64:
264-
for _, v := range arr.(*array.Float64).Float64Values() {
265-
values = append(values, &types.Value{Val: &types.Value_DoubleVal{DoubleVal: v}})
269+
for idx := 0; idx < arr.Len(); idx++ {
270+
if arr.IsNull(idx) {
271+
values = append(values, &types.Value{})
272+
} else {
273+
values = append(values, &types.Value{Val: &types.Value_DoubleVal{DoubleVal: arr.(*array.Float64).Value(idx)}})
274+
}
266275
}
267276
case arrow.FixedWidthTypes.Boolean:
268277
for idx := 0; idx < arr.Len(); idx++ {
269-
values = append(values,
270-
&types.Value{Val: &types.Value_BoolVal{BoolVal: arr.(*array.Boolean).Value(idx)}})
278+
if arr.IsNull(idx) {
279+
values = append(values, &types.Value{})
280+
} else {
281+
values = append(values, &types.Value{Val: &types.Value_BoolVal{BoolVal: arr.(*array.Boolean).Value(idx)}})
282+
}
271283
}
272284
case arrow.BinaryTypes.Binary:
273285
for idx := 0; idx < arr.Len(); idx++ {
274-
values = append(values,
275-
&types.Value{Val: &types.Value_BytesVal{BytesVal: arr.(*array.Binary).Value(idx)}})
286+
if arr.IsNull(idx) {
287+
values = append(values, &types.Value{})
288+
} else {
289+
values = append(values, &types.Value{Val: &types.Value_BytesVal{BytesVal: arr.(*array.Binary).Value(idx)}})
290+
}
276291
}
277292
case arrow.BinaryTypes.String:
278293
for idx := 0; idx < arr.Len(); idx++ {
279-
values = append(values,
280-
&types.Value{Val: &types.Value_StringVal{StringVal: arr.(*array.String).Value(idx)}})
294+
if arr.IsNull(idx) {
295+
values = append(values, &types.Value{})
296+
} else {
297+
values = append(values, &types.Value{Val: &types.Value_StringVal{StringVal: arr.(*array.String).Value(idx)}})
298+
}
281299
}
282300
case arrow.FixedWidthTypes.Timestamp_s:
283301
for idx := 0; idx < arr.Len(); idx++ {
284-
values = append(values,
285-
&types.Value{Val: &types.Value_UnixTimestampVal{
286-
UnixTimestampVal: int64(arr.(*array.Timestamp).Value(idx))}})
302+
if arr.IsNull(idx) {
303+
values = append(values, &types.Value{})
304+
} else {
305+
values = append(values, &types.Value{Val: &types.Value_UnixTimestampVal{UnixTimestampVal: int64(arr.(*array.Timestamp).Value(idx))}})
306+
}
287307
}
288308
case arrow.Null:
289309
for idx := 0; idx < arr.Len(); idx++ {
@@ -306,7 +326,9 @@ func ProtoValuesToArrowArray(protoValues []*types.Value, arrowAllocator memory.A
306326
if err != nil {
307327
return nil, err
308328
}
309-
break
329+
if fieldType != nil {
330+
break
331+
}
310332
}
311333
}
312334

go/types/typeconversion_test.go

Lines changed: 27 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,27 +1,46 @@
11
package types
22

33
import (
4+
"math"
45
"testing"
56
"time"
67

78
"github.com/apache/arrow/go/v8/arrow/memory"
8-
"github.com/golang/protobuf/proto"
99
"github.com/stretchr/testify/assert"
10+
"google.golang.org/protobuf/proto"
1011

1112
"github.com/feast-dev/feast/go/protos/feast/types"
1213
)
1314

15+
var nil_or_null_val = &types.Value{}
16+
1417
var (
1518
PROTO_VALUES = [][]*types.Value{
19+
{{Val: nil}},
20+
{{Val: nil}, {Val: nil}},
21+
{nil_or_null_val, nil_or_null_val},
22+
{nil_or_null_val, {Val: nil}},
23+
{{Val: &types.Value_Int32Val{10}}, {Val: nil}, nil_or_null_val, {Val: &types.Value_Int32Val{20}}},
24+
{{Val: &types.Value_Int32Val{10}}, nil_or_null_val},
25+
{nil_or_null_val, {Val: &types.Value_Int32Val{20}}},
1626
{{Val: &types.Value_Int32Val{10}}, {Val: &types.Value_Int32Val{20}}},
27+
{{Val: &types.Value_Int64Val{10}}, nil_or_null_val},
1728
{{Val: &types.Value_Int64Val{10}}, {Val: &types.Value_Int64Val{20}}},
29+
{nil_or_null_val, {Val: &types.Value_FloatVal{2.0}}},
1830
{{Val: &types.Value_FloatVal{1.0}}, {Val: &types.Value_FloatVal{2.0}}},
31+
{{Val: &types.Value_FloatVal{1.0}}, {Val: &types.Value_FloatVal{2.0}}, {Val: &types.Value_FloatVal{float32(math.NaN())}}},
1932
{{Val: &types.Value_DoubleVal{1.0}}, {Val: &types.Value_DoubleVal{2.0}}},
33+
{{Val: &types.Value_DoubleVal{1.0}}, {Val: &types.Value_DoubleVal{2.0}}, {Val: &types.Value_DoubleVal{math.NaN()}}},
34+
{{Val: &types.Value_DoubleVal{1.0}}, nil_or_null_val},
35+
{nil_or_null_val, {Val: &types.Value_StringVal{"bbb"}}},
2036
{{Val: &types.Value_StringVal{"aaa"}}, {Val: &types.Value_StringVal{"bbb"}}},
37+
{{Val: &types.Value_BytesVal{[]byte{1, 2, 3}}}, nil_or_null_val},
2138
{{Val: &types.Value_BytesVal{[]byte{1, 2, 3}}}, {Val: &types.Value_BytesVal{[]byte{4, 5, 6}}}},
39+
{nil_or_null_val, {Val: &types.Value_BoolVal{false}}},
2240
{{Val: &types.Value_BoolVal{true}}, {Val: &types.Value_BoolVal{false}}},
23-
{{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}},
24-
{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}},
41+
{{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}, nil_or_null_val},
42+
{{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}, {Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}},
43+
{{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}, {Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}, {Val: &types.Value_UnixTimestampVal{-9223372036854775808}}},
2544

2645
{
2746
{Val: &types.Value_Int32ListVal{&types.Int32List{Val: []int32{0, 1, 2}}}},
@@ -55,6 +74,11 @@ var (
5574
{Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{time.Now().Unix()}}}},
5675
{Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{time.Now().Unix()}}}},
5776
},
77+
{
78+
{Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{time.Now().Unix(), time.Now().Unix()}}}},
79+
{Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{time.Now().Unix(), time.Now().Unix()}}}},
80+
{Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{-9223372036854775808, time.Now().Unix()}}}},
81+
},
5882
}
5983
)
6084

0 commit comments

Comments
 (0)