Skip to content

Commit 2080fa3

Browse files
authored
Refactor pa_to_feast_value_type (feast-dev#2246)
* Refactor `pa_to_feast_value_type` This refactoring is intented to make it more difficult to forget to add conversion for LIST versions of non-LIST types. Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com> * Tidy up `assert_expected_arrow_types` Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com>
1 parent 53539cf commit 2080fa3

2 files changed

Lines changed: 37 additions & 37 deletions

File tree

sdk/python/feast/type_map.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import re
1615
from datetime import datetime, timezone
1716
from typing import (
1817
Any,
@@ -416,27 +415,30 @@ def _proto_value_to_value_type(proto_value: ProtoValue) -> ValueType:
416415

417416

418417
def pa_to_feast_value_type(pa_type_as_str: str) -> ValueType:
419-
if re.match(r"^timestamp", pa_type_as_str):
420-
return ValueType.INT64
418+
is_list = False
419+
if pa_type_as_str.startswith("list<item: "):
420+
is_list = True
421+
pa_type_as_str = pa_type_as_str.replace("list<item: ", "").replace(">", "")
421422

422-
type_map = {
423-
"int32": ValueType.INT32,
424-
"int64": ValueType.INT64,
425-
"double": ValueType.DOUBLE,
426-
"float": ValueType.FLOAT,
427-
"string": ValueType.STRING,
428-
"binary": ValueType.BYTES,
429-
"bool": ValueType.BOOL,
430-
"list<item: int32>": ValueType.INT32_LIST,
431-
"list<item: int64>": ValueType.INT64_LIST,
432-
"list<item: double>": ValueType.DOUBLE_LIST,
433-
"list<item: float>": ValueType.FLOAT_LIST,
434-
"list<item: string>": ValueType.STRING_LIST,
435-
"list<item: binary>": ValueType.BYTES_LIST,
436-
"list<item: bool>": ValueType.BOOL_LIST,
437-
"null": ValueType.NULL,
438-
}
439-
return type_map[pa_type_as_str]
423+
if pa_type_as_str.startswith("timestamp"):
424+
value_type = ValueType.UNIX_TIMESTAMP
425+
else:
426+
type_map = {
427+
"int32": ValueType.INT32,
428+
"int64": ValueType.INT64,
429+
"double": ValueType.DOUBLE,
430+
"float": ValueType.FLOAT,
431+
"string": ValueType.STRING,
432+
"binary": ValueType.BYTES,
433+
"bool": ValueType.BOOL,
434+
"null": ValueType.NULL,
435+
}
436+
value_type = type_map[pa_type_as_str]
437+
438+
if is_list:
439+
value_type = ValueType[value_type.name + "_LIST"]
440+
441+
return value_type
440442

441443

442444
def bq_to_feast_value_type(bq_type_as_str: str) -> ValueType:

sdk/python/tests/integration/registration/test_universal_types.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import logging
2-
import re
32
from dataclasses import dataclass
43
from datetime import datetime, timedelta
54
from typing import Any, Dict, List, Tuple, Union
65

76
import numpy as np
87
import pandas as pd
8+
import pyarrow as pa
99
import pytest
1010

1111
from feast.infra.offline_stores.offline_store import RetrievalJob
@@ -339,23 +339,21 @@ def assert_expected_arrow_types(
339339
historical_features_arrow = historical_features.to_arrow()
340340
print(historical_features_arrow)
341341
feature_list_dtype_to_expected_historical_feature_arrow_type = {
342-
"int32": r"int64",
343-
"int64": r"int64",
344-
"float": r"double",
345-
"string": r"string",
346-
"bool": r"bool",
347-
"datetime": r"timestamp\[.+\]",
342+
"int32": pa.types.is_int64,
343+
"int64": pa.types.is_int64,
344+
"float": pa.types.is_float64,
345+
"string": pa.types.is_string,
346+
"bool": pa.types.is_boolean,
347+
"date": pa.types.is_date,
348+
"datetime": pa.types.is_timestamp,
348349
}
349-
arrow_type = feature_list_dtype_to_expected_historical_feature_arrow_type[
350+
arrow_type_checker = feature_list_dtype_to_expected_historical_feature_arrow_type[
350351
feature_dtype
351352
]
353+
pa_type = historical_features_arrow.schema.field("value").type
354+
352355
if feature_is_list:
353-
assert re.match(
354-
f"list<item: {arrow_type}>",
355-
str(historical_features_arrow.schema.field_by_name("value").type),
356-
)
356+
assert pa.types.is_list(pa_type)
357+
assert arrow_type_checker(pa_type.value_type)
357358
else:
358-
assert re.match(
359-
arrow_type,
360-
str(historical_features_arrow.schema.field_by_name("value").type),
361-
)
359+
assert arrow_type_checker(pa_type)

0 commit comments

Comments
 (0)