diff --git a/sdk/python/feast/types.py b/sdk/python/feast/types.py index fe74bd38bd..7542c73613 100644 --- a/sdk/python/feast/types.py +++ b/sdk/python/feast/types.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod from enum import Enum -from typing import Union +from typing import Dict, Union from feast.protos.feast.types.Value_pb2 import ValueType as ValueTypeProto @@ -40,17 +40,22 @@ def __init__(self): pass @abstractmethod - def to_int(self) -> int: + def to_value_type(self) -> ValueTypeProto.Enum: """ - Converts a ComplexFeastType object to the appropriate int value corresponding to - the correct ValueTypeProto.Enum value. + Converts a ComplexFeastType object to the corresponding ValueTypeProto.Enum value. """ raise NotImplementedError + def __eq__(self, other): + return self.to_value_type() == other.to_value_type() + class PrimitiveFeastType(Enum): """ A PrimitiveFeastType represents a primitive type in Feast. + + Note that these values must match the values in ValueTypeProto.Enum. See + /feast/protos/types/Value.proto for the exact values. """ INVALID = 0 @@ -58,12 +63,15 @@ class PrimitiveFeastType(Enum): STRING = 2 INT32 = 3 INT64 = 4 - FLOAT32 = 5 - FLOAT64 = 6 + FLOAT64 = 5 + FLOAT32 = 6 BOOL = 7 UNIX_TIMESTAMP = 8 - def to_int(self) -> int: + def to_value_type(self) -> ValueTypeProto.Enum: + """ + Converts a PrimitiveFeastType object to the corresponding ValueTypeProto.Enum value. + """ value_type_name = PRIMITIVE_FEAST_TYPES_TO_VALUE_TYPES[self.name] return ValueTypeProto.Enum.Value(value_type_name) @@ -110,8 +118,49 @@ def __init__(self, base_type: Union[PrimitiveFeastType, ComplexFeastType]): self.base_type = base_type - def to_int(self) -> int: + def to_value_type(self) -> int: assert isinstance(self.base_type, PrimitiveFeastType) value_type_name = PRIMITIVE_FEAST_TYPES_TO_VALUE_TYPES[self.base_type.name] value_type_list_name = value_type_name + "_LIST" return ValueTypeProto.Enum.Value(value_type_list_name) + + +VALUE_TYPES_TO_FEAST_TYPES: Dict[ + "ValueTypeProto.Enum", Union[ComplexFeastType, PrimitiveFeastType] +] = { + ValueTypeProto.Enum.INVALID: Invalid, + ValueTypeProto.Enum.BYTES: Bytes, + ValueTypeProto.Enum.STRING: String, + ValueTypeProto.Enum.INT32: Int32, + ValueTypeProto.Enum.INT64: Int64, + ValueTypeProto.Enum.DOUBLE: Float64, + ValueTypeProto.Enum.FLOAT: Float32, + ValueTypeProto.Enum.BOOL: Bool, + ValueTypeProto.Enum.UNIX_TIMESTAMP: UnixTimestamp, + ValueTypeProto.Enum.BYTES_LIST: Array(Bytes), + ValueTypeProto.Enum.STRING_LIST: Array(String), + ValueTypeProto.Enum.INT32_LIST: Array(Int32), + ValueTypeProto.Enum.INT64_LIST: Array(Int64), + ValueTypeProto.Enum.DOUBLE_LIST: Array(Float64), + ValueTypeProto.Enum.FLOAT_LIST: Array(Float32), + ValueTypeProto.Enum.BOOL_LIST: Array(Bool), + ValueTypeProto.Enum.UNIX_TIMESTAMP_LIST: Array(UnixTimestamp), +} + + +def from_value_type( + value_type: ValueTypeProto.Enum, +) -> Union[ComplexFeastType, PrimitiveFeastType]: + """ + Converts a ValueTypeProto.Enum to a Feast type. + + Args: + value_type: The ValueTypeProto.Enum to be converted. + + Raises: + ValueError: The conversion could not be performed. + """ + if value_type in VALUE_TYPES_TO_FEAST_TYPES: + return VALUE_TYPES_TO_FEAST_TYPES[value_type] + + raise ValueError(f"Could not convert value type {value_type} to FeastType.") diff --git a/sdk/python/tests/unit/test_types.py b/sdk/python/tests/unit/test_types.py index 8252a0e181..5a721737f6 100644 --- a/sdk/python/tests/unit/test_types.py +++ b/sdk/python/tests/unit/test_types.py @@ -1,23 +1,35 @@ import pytest from feast.protos.feast.types.Value_pb2 import ValueType as ValueTypeProto -from feast.types import Array, Float32, String +from feast.types import Array, Float32, String, from_value_type def test_primitive_feast_type(): - assert String.to_int() == ValueTypeProto.Enum.Value("STRING") - assert Float32.to_int() == ValueTypeProto.Enum.Value("FLOAT") + assert String.to_value_type() == ValueTypeProto.Enum.Value("STRING") + assert from_value_type(String.to_value_type()) == String + assert Float32.to_value_type() == ValueTypeProto.Enum.Value("FLOAT") + assert from_value_type(Float32.to_value_type()) == Float32 def test_array_feast_type(): array_float_32 = Array(Float32) - assert array_float_32.to_int() == ValueTypeProto.Enum.Value("FLOAT_LIST") + assert array_float_32.to_value_type() == ValueTypeProto.Enum.Value("FLOAT_LIST") + assert from_value_type(array_float_32.to_value_type()) == array_float_32 array_string = Array(String) - assert array_string.to_int() == ValueTypeProto.Enum.Value("STRING_LIST") + assert array_string.to_value_type() == ValueTypeProto.Enum.Value("STRING_LIST") + assert from_value_type(array_string.to_value_type()) == array_string with pytest.raises(ValueError): _ = Array(Array) with pytest.raises(ValueError): _ = Array(Array(String)) + + +def test_all_value_types(): + values = ValueTypeProto.Enum.values() + for value in values: + # We do not support the NULL type. + if value != ValueTypeProto.Enum.Value("NULL"): + assert from_value_type(value).to_value_type() == value