diff --git a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_type_map.py b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_type_map.py index a11298e9b81..5deb8b76f59 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_type_map.py +++ b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_type_map.py @@ -22,6 +22,9 @@ def trino_to_feast_value_type(trino_type_as_str: str) -> ValueType: "boolean": ValueType.BOOL, "real": ValueType.FLOAT, "date": ValueType.STRING, + "binary": ValueType.STRING, + "varbinary": ValueType.STRING, + "json": ValueType.STRING, } _trino_type_as_str: str = trino_type_as_str trino_type_as_str = trino_type_as_str.lower() @@ -36,6 +39,8 @@ def trino_to_feast_value_type(trino_type_as_str: str) -> ValueType: trino_type_as_str = "decimal64" else: trino_type_as_str = "decimal32" + else: + trino_type_as_str = "decimal64" elif trino_type_as_str.startswith("timestamp"): trino_type_as_str = "timestamp" @@ -43,6 +48,9 @@ def trino_to_feast_value_type(trino_type_as_str: str) -> ValueType: elif trino_type_as_str.startswith("varchar"): trino_type_as_str = "varchar" + elif trino_type_as_str.startswith("char"): + trino_type_as_str = "char" + if trino_type_as_str not in type_map: raise ValueError(f"Trino type not supported by feast {_trino_type_as_str}") return type_map[trino_type_as_str] @@ -55,7 +63,11 @@ def pa_to_trino_value_type(pa_type_as_str: str) -> str: trino_type = "{}" if pa_type_as_str.startswith("list"): trino_type = "array<{}>" - pa_type_as_str = re.search(r"^list$", pa_type_as_str).group(1) + match = re.search(r"^list$", pa_type_as_str) + if match: + pa_type_as_str = match.group(1) + else: + return trino_type.format("varchar") if pa_type_as_str.startswith("date"): return trino_type.format("date") @@ -67,7 +79,10 @@ def pa_to_trino_value_type(pa_type_as_str: str) -> str: return trino_type.format("timestamp") if pa_type_as_str.startswith("decimal"): - return trino_type.format(pa_type_as_str) + # PyArrow renders decimal types as decimal128(10, 2) or decimal256(10, 2), + # but Trino expects just decimal(10, 2) + normalized = re.sub(r"^decimal\d+", "decimal", pa_type_as_str) + return trino_type.format(normalized) if pa_type_as_str.startswith("map<"): return trino_type.format("varchar") @@ -92,33 +107,43 @@ def pa_to_trino_value_type(pa_type_as_str: str) -> str: "float": "double", "double": "double", "binary": "binary", + "varbinary": "binary", "string": "varchar", + "char": "varchar", } return trino_type.format(type_map[pa_type_as_str]) -_TRINO_TO_PA_TYPE_MAP = { +_TRINO_TO_PA_TYPE_MAP: Dict[str, pa.DataType] = { "null": pa.null(), "boolean": pa.bool_(), "date": pa.date32(), "tinyint": pa.int8(), "smallint": pa.int16(), "integer": pa.int32(), + "int": pa.int32(), "bigint": pa.int64(), "double": pa.float64(), "binary": pa.binary(), + "varbinary": pa.binary(), "char": pa.string(), + "json": pa.string(), "real": pa.float32(), } +def _trino_array_item_type(trino_type_as_str: str) -> str | None: + if trino_type_as_str.startswith("array(") and trino_type_as_str.endswith(")"): + return trino_type_as_str[6:-1].strip() + return None + + def trino_to_pa_value_type(trino_type_as_str: str) -> pa.DataType: - trino_type_as_str = trino_type_as_str.lower() + trino_type_as_str = trino_type_as_str.lower().strip() - _is_list: bool = False - if trino_type_as_str.startswith("array"): - _is_list = True - trino_type_as_str = re.search(r"^array\((\w+)\)$", trino_type_as_str).group(1) + array_item_type = _trino_array_item_type(trino_type_as_str) + if array_item_type is not None: + return pa.list_(trino_to_pa_value_type(array_item_type)) if trino_type_as_str.startswith("decimal"): search_precision = re.search( @@ -127,20 +152,24 @@ def trino_to_pa_value_type(trino_type_as_str: str) -> pa.DataType: if search_precision: precision = int(search_precision.group(1)) if precision > 32: - pa_type = pa.float64() + return pa.float64() else: - pa_type = pa.float32() + return pa.float32() + return pa.float64() - elif trino_type_as_str.startswith("timestamp"): - pa_type = pa.timestamp("us") + if trino_type_as_str.startswith("timestamp"): + return pa.timestamp("us") - elif trino_type_as_str.startswith("varchar"): - pa_type = pa.string() + if trino_type_as_str.startswith("varchar"): + return pa.string() + + if trino_type_as_str.startswith("char"): + return pa.string() + + if trino_type_as_str.startswith("row("): + return pa.string() - else: - pa_type = _TRINO_TO_PA_TYPE_MAP[trino_type_as_str] + if trino_type_as_str.startswith("map("): + return pa.string() - if _is_list: - return pa.list_(pa_type) - else: - return pa_type + return _TRINO_TO_PA_TYPE_MAP[trino_type_as_str] diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/trino_offline_store/test_trino_type_map.py b/sdk/python/tests/unit/infra/offline_stores/contrib/trino_offline_store/test_trino_type_map.py new file mode 100644 index 00000000000..012f80fea2d --- /dev/null +++ b/sdk/python/tests/unit/infra/offline_stores/contrib/trino_offline_store/test_trino_type_map.py @@ -0,0 +1,204 @@ +import pyarrow as pa +import pytest + +from feast import ValueType +from feast.infra.offline_stores.contrib.trino_offline_store.trino_type_map import ( + _trino_array_item_type, + pa_to_trino_value_type, + trino_to_feast_value_type, + trino_to_pa_value_type, +) + + +class TestTrinoArrayItemType: + def test_simple_type(self) -> None: + assert _trino_array_item_type("array(bigint)") == "bigint" + + def test_parameterized_type(self) -> None: + assert _trino_array_item_type("array(varchar(10))") == "varchar(10)" + + def test_parameterized_with_comma(self) -> None: + assert _trino_array_item_type("array(decimal(10, 2))") == "decimal(10, 2)" + + def test_nested_array(self) -> None: + assert _trino_array_item_type("array(array(varchar))") == "array(varchar)" + + def test_complex_row(self) -> None: + assert ( + _trino_array_item_type("array(row(x bigint, y varchar(10)))") + == "row(x bigint, y varchar(10))" + ) + + def test_not_an_array(self) -> None: + assert _trino_array_item_type("varchar") is None + + def test_partial_prefix(self) -> None: + assert _trino_array_item_type("array") is None + assert _trino_array_item_type("array(") is None + + +class TestTrinoToFeastValueType: + def test_simple_types(self) -> None: + assert trino_to_feast_value_type("boolean") == ValueType.BOOL + assert trino_to_feast_value_type("bigint") == ValueType.INT64 + assert trino_to_feast_value_type("integer") == ValueType.INT32 + assert trino_to_feast_value_type("int") == ValueType.INT32 + assert trino_to_feast_value_type("double") == ValueType.DOUBLE + assert trino_to_feast_value_type("real") == ValueType.FLOAT + assert trino_to_feast_value_type("date") == ValueType.STRING + assert trino_to_feast_value_type("tinyint") == ValueType.INT32 + assert trino_to_feast_value_type("smallint") == ValueType.INT32 + + def test_parameterized_varchar(self) -> None: + assert trino_to_feast_value_type("varchar(10)") == ValueType.STRING + + def test_parameterized_char(self) -> None: + assert trino_to_feast_value_type("char(10)") == ValueType.STRING + assert trino_to_feast_value_type("char") == ValueType.STRING + + def test_timestamp_with_precision(self) -> None: + assert trino_to_feast_value_type("timestamp(3)") == ValueType.UNIX_TIMESTAMP + assert trino_to_feast_value_type("timestamp") == ValueType.UNIX_TIMESTAMP + + def test_decimal_with_precision(self) -> None: + assert trino_to_feast_value_type("decimal(10, 2)") == ValueType.FLOAT + assert trino_to_feast_value_type("decimal(38, 2)") == ValueType.DOUBLE + assert trino_to_feast_value_type("decimal(32)") == ValueType.FLOAT + assert trino_to_feast_value_type("decimal(33)") == ValueType.DOUBLE + + def test_bare_decimal(self) -> None: + assert trino_to_feast_value_type("decimal") == ValueType.DOUBLE + + def test_binary_types(self) -> None: + assert trino_to_feast_value_type("binary") == ValueType.STRING + assert trino_to_feast_value_type("varbinary") == ValueType.STRING + + def test_json(self) -> None: + assert trino_to_feast_value_type("json") == ValueType.STRING + + def test_unsupported_type(self) -> None: + with pytest.raises(ValueError, match="Trino type not supported"): + trino_to_feast_value_type("unknown_type") + + +class TestTrinoToPaValueType: + def test_simple_types(self) -> None: + assert trino_to_pa_value_type("boolean") == pa.bool_() + assert trino_to_pa_value_type("bigint") == pa.int64() + assert trino_to_pa_value_type("integer") == pa.int32() + assert trino_to_pa_value_type("int") == pa.int32() + assert trino_to_pa_value_type("double") == pa.float64() + assert trino_to_pa_value_type("real") == pa.float32() + assert trino_to_pa_value_type("date") == pa.date32() + assert trino_to_pa_value_type("tinyint") == pa.int8() + assert trino_to_pa_value_type("smallint") == pa.int16() + + def test_parameterized_varchar(self) -> None: + assert trino_to_pa_value_type("varchar(10)") == pa.string() + + def test_parameterized_char(self) -> None: + assert trino_to_pa_value_type("char(10)") == pa.string() + assert trino_to_pa_value_type("char") == pa.string() + + def test_binary_types(self) -> None: + assert trino_to_pa_value_type("binary") == pa.binary() + assert trino_to_pa_value_type("varbinary") == pa.binary() + + def test_json(self) -> None: + assert trino_to_pa_value_type("json") == pa.string() + + def test_timestamp(self) -> None: + assert trino_to_pa_value_type("timestamp") == pa.timestamp("us") + assert trino_to_pa_value_type("timestamp(3)") == pa.timestamp("us") + + def test_decimal_bare(self) -> None: + assert trino_to_pa_value_type("decimal") == pa.float64() + + def test_decimal_with_precision(self) -> None: + assert trino_to_pa_value_type("decimal(10, 2)") == pa.float32() + assert trino_to_pa_value_type("decimal(38, 2)") == pa.float64() + assert trino_to_pa_value_type("decimal(32)") == pa.float32() + assert trino_to_pa_value_type("decimal(33)") == pa.float64() + + def test_array_simple(self) -> None: + assert trino_to_pa_value_type("array(bigint)") == pa.list_(pa.int64()) + + def test_array_parameterized_varchar(self) -> None: + assert trino_to_pa_value_type("array(varchar(10))") == pa.list_(pa.string()) + + def test_array_parameterized_decimal(self) -> None: + assert trino_to_pa_value_type("array(decimal(10, 2))") == pa.list_(pa.float32()) + + def test_array_nested(self) -> None: + assert trino_to_pa_value_type("array(array(bigint))") == pa.list_( + pa.list_(pa.int64()) + ) + + def test_row_type(self) -> None: + assert trino_to_pa_value_type("row(x bigint)") == pa.string() + assert trino_to_pa_value_type("row(x bigint, y varchar)") == pa.string() + + def test_map_type(self) -> None: + assert trino_to_pa_value_type("map(varchar, bigint)") == pa.string() + + def test_array_of_row(self) -> None: + assert trino_to_pa_value_type( + "array(row(x bigint, y varchar(10)))" + ) == pa.list_(pa.string()) + + def test_unsupported_type(self) -> None: + with pytest.raises(KeyError): + trino_to_pa_value_type("unknown_type") + + +class TestPaToTrinoValueType: + def test_simple_types(self) -> None: + assert pa_to_trino_value_type(str(pa.bool_())) == "boolean" + assert pa_to_trino_value_type(str(pa.int8())) == "tinyint" + assert pa_to_trino_value_type(str(pa.int16())) == "smallint" + assert pa_to_trino_value_type(str(pa.int32())) == "int" + assert pa_to_trino_value_type(str(pa.int64())) == "bigint" + assert pa_to_trino_value_type(str(pa.float32())) == "double" + assert pa_to_trino_value_type(str(pa.float64())) == "double" + assert pa_to_trino_value_type(str(pa.binary())) == "binary" + + def test_string(self) -> None: + assert pa_to_trino_value_type(str(pa.string())) == "varchar" + assert pa_to_trino_value_type("large_string") == "varchar" + assert pa_to_trino_value_type("char") == "varchar" + + def test_varbinary(self) -> None: + assert pa_to_trino_value_type("varbinary") == "binary" + + def test_date(self) -> None: + assert pa_to_trino_value_type(str(pa.date32())) == "date" + + def test_timestamp(self) -> None: + assert pa_to_trino_value_type(str(pa.timestamp("us"))) == "timestamp" + assert ( + pa_to_trino_value_type(str(pa.timestamp("us", tz="UTC"))) + == "timestamp with time zone" + ) + + def test_decimal128(self) -> None: + assert pa_to_trino_value_type(str(pa.decimal128(10, 2))) == "decimal(10, 2)" + + def test_decimal256(self) -> None: + assert pa_to_trino_value_type(str(pa.decimal256(10, 2))) == "decimal(10, 2)" + + def test_list(self) -> None: + assert pa_to_trino_value_type(str(pa.list_(pa.int64()))) == "array" + + def test_list_of_string(self) -> None: + assert pa_to_trino_value_type(str(pa.list_(pa.string()))) == "array" + + def test_map_degrades_to_varchar(self) -> None: + type_str = str(pa.map_(pa.string(), pa.int64())) + assert pa_to_trino_value_type(type_str) == "varchar" + + def test_struct_degrades_to_varchar(self) -> None: + type_str = str(pa.struct([("x", pa.int64())])) + assert pa_to_trino_value_type(type_str) == "varchar" + + def test_null(self) -> None: + assert pa_to_trino_value_type(str(pa.null())) == "null"