def to_tensor(
    self,
    kind: str = "torch",
    default_value: Any = float("nan"),
    timeout: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Converts historical features into a dictionary of 1D torch tensors or lists (for non-numeric types).

    A column is turned into a tensor only when it is non-empty and every value
    left after missing-value substitution is an int, float or bool. Any other
    column (strings, timestamps, mixed content, empty columns) is returned as a
    plain Python list.

    Args:
        kind: "torch" (default and only supported kind).
        default_value: Value to replace missing (None or NaN) entries. Passing
            ``None`` leaves missing entries untouched.
        timeout: Optional timeout for query execution.

    Returns:
        Dict[str, Union[torch.Tensor, List]]: Feature column name -> tensor or list.

    Raises:
        ValueError: If ``kind`` is anything other than "torch".
    """
    if kind != "torch":
        raise ValueError(
            f"Unsupported tensor kind: {kind}. Only 'torch' is supported."
        )
    torch = get_torch()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    df = self.to_df(timeout=timeout)
    tensor_dict: Dict[str, Any] = {}
    for column in df.columns:
        series = df[column]
        # pandas rejects fillna(None), so only substitute when a real
        # replacement value was supplied.
        if default_value is not None:
            series = series.fillna(default_value)
        values = series.tolist()
        # Inspect every value, not just the first one: with the NaN default a
        # string column whose first entry is missing would otherwise look
        # numeric and make torch.tensor() fail on the remaining strings.
        is_numeric = bool(values) and all(
            isinstance(v, (bool, int, float)) for v in values
        )
        if is_numeric:
            tensor_dict[column] = torch.tensor(values, device=device)
        else:
            tensor_dict[column] = values
    return tensor_dict
def to_tensor(
    self,
    kind: str = "torch",
    default_value: Any = float("nan"),
) -> Dict[str, Union[TorchTensor, List[Any]]]:
    """
    Converts GetOnlineFeaturesResponse features into a dictionary of tensors or lists.

    - Numeric features (int, float, bool) -> torch.Tensor
    - Non-numeric features (e.g., strings) -> list[Any]

    A feature counts as numeric only when it has at least one row and every
    value left after missing-value substitution is an int, float or bool.
    Keys are emitted in the order they appear in the response metadata.

    Args:
        kind: Backend tensor type. Currently only "torch" is supported.
        default_value: Value to substitute for missing (None) entries.

    Returns:
        Dict[str, Union[torch.Tensor, List[Any]]]: Mapping of feature names to tensors or lists.

    Raises:
        ValueError: If ``kind`` is not "torch", or if torch fails to build a
            tensor from a numeric feature (the original error is chained).
    """
    if kind != "torch":
        raise ValueError(
            f"Unsupported tensor kind: {kind}. Only 'torch' is supported currently."
        )
    torch = get_torch()
    # The target device does not depend on the feature, so resolve it once.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    feature_dict = self.to_dict(include_event_timestamps=False)
    # dict.fromkeys de-duplicates like a set but keeps the metadata order, so
    # the resulting dictionary is deterministic across runs.
    feature_keys = dict.fromkeys(self.proto.metadata.feature_names.val)
    tensor_dict: Dict[str, Union[TorchTensor, List[Any]]] = {}
    for key in feature_keys:
        values = [default_value if v is None else v for v in feature_dict[key]]
        # Check every substituted value: looking only at the first one would
        # misclassify a string feature whose first row is missing (NaN).
        is_numeric = bool(values) and all(
            isinstance(v, (bool, int, float)) for v in values
        )
        if not is_numeric:
            # Return as-is for strings or unsupported types
            tensor_dict[key] = values
            continue
        try:
            tensor_dict[key] = torch.tensor(values, device=device)
        except Exception as e:
            raise ValueError(
                f"Failed to convert values for '{key}' to tensor: {e}"
            ) from e
    return tensor_dict
import importlib

TORCH_AVAILABLE = False
_torch = None
_torch_import_error = None


def _import_torch():
    """Try to import torch once, recording either the module or the failure."""
    global _torch, TORCH_AVAILABLE, _torch_import_error
    try:
        module = importlib.import_module("torch")
    except Exception as exc:
        # Broad on purpose: a broken CUDA install can surface as errors other
        # than ImportError while torch is being loaded.
        TORCH_AVAILABLE = False
        _torch_import_error = exc
    else:
        _torch = module
        TORCH_AVAILABLE = True


_import_torch()


def get_torch():
    """
    Hand back the imported torch module.

    Raises:
        ImportError: When the import attempt made at module load failed; the
            message embeds the original error and installation hints, and the
            original exception is chained as the cause.
    """
    if not TORCH_AVAILABLE:
        message_parts = [
            "Torch is not available or failed to import.\n",
            "Original error:\n",
            f"{_torch_import_error}\n\n",
            "If you are on a CPU-only system, make sure you install the CPU-only torch wheel:\n",
            " pip install torch==2.2.2+cpu torchvision==0.17.2+cpu -f https://download.pytorch.org/whl/torch_stable.html\n",
            "Or check your CUDA installation if using GPU torch.\n",
        ]
        raise ImportError("".join(message_parts)) from _torch_import_error
    return _torch
assert tensor_result["customer_id"] == ["5", "5"] + + # Feature values + assert tensor_result["lon"] == ["1.0", "1.0"] # String -> not tensor + assert torch.equal(tensor_result["avg_orders_day"], torch.tensor([1.0, 1.0])) + assert tensor_result["name"] == ["John", "John"] + assert torch.equal(tensor_result["trips"], torch.tensor([7, 7], device=device)) + # Ensure features are still in result when keys not found result = store.get_online_features( features=["customer_driver_combined:trips"], @@ -138,6 +171,15 @@ def test_get_online_features() -> None: assert "trips" in result + result = store.get_online_features( + features=["customer_driver_combined:trips"], + entity_rows=[{"driver_id": 0, "customer_id": 0}], + full_feature_names=False, + ).to_tensor() + + assert "trips" in result + assert isinstance(result["trips"], torch.Tensor) + with pytest.raises(KeyError) as excinfo: _ = store.get_online_features( features=["driver_locations:lon"], diff --git a/sdk/python/tests/unit/test_offline_server.py b/sdk/python/tests/unit/test_offline_server.py index e82e2fa6872..74752f6952d 100644 --- a/sdk/python/tests/unit/test_offline_server.py +++ b/sdk/python/tests/unit/test_offline_server.py @@ -17,6 +17,7 @@ ) from feast.offline_server import OfflineServer, _init_auth_manager from feast.repo_config import RepoConfig +from feast.torch_wrapper import get_torch from tests.utils.cli_repo_creator import CliRunner PROJECT_NAME = "test_remote_offline" @@ -115,7 +116,9 @@ def test_remote_offline_store_apis(): fs = remote_feature_store(server) _test_get_historical_features_returns_data(fs) + _test_get_historical_features_to_tensor(fs) _test_get_historical_features_returns_nan(fs) + _test_get_historical_features_to_tensor_with_nan(fs) _test_offline_write_batch(str(temp_dir), fs) _test_write_logged_features(str(temp_dir), fs) _test_pull_latest_from_table_or_query(str(temp_dir), fs) @@ -187,6 +190,44 @@ def _test_get_historical_features_returns_data(fs: FeatureStore): 
def _test_get_historical_features_to_tensor(fs: FeatureStore):
    """Check that to_tensor() on a historical retrieval yields 3-row, NaN-free tensors."""
    event_timestamps = [
        datetime(2021, 4, 12, 10, 59, 42),
        datetime(2021, 4, 12, 8, 12, 10),
        datetime(2021, 4, 12, 16, 40, 26),
    ]
    entity_df = pd.DataFrame(
        {
            "driver_id": [1001, 1002, 1003],
            "event_timestamp": event_timestamps,
            "label_driver_reported_satisfaction": [1, 5, 3],
            "val_to_add": [1, 2, 3],
            "val_to_add_2": [10, 20, 30],
        }
    )

    feature_refs = [
        "driver_hourly_stats:conv_rate",
        "driver_hourly_stats:acc_rate",
        "driver_hourly_stats:avg_daily_trips",
        "transformed_conv_rate:conv_rate_plus_val1",
        "transformed_conv_rate:conv_rate_plus_val2",
    ]

    tensor_data = fs.get_historical_features(entity_df, feature_refs).to_tensor()

    assertpy.assert_that(tensor_data).is_not_none()
    assertpy.assert_that(tensor_data["driver_id"].shape[0]).is_equal_to(3)
    torch = get_torch()
    # Only tensor-valued columns are checked; list-valued columns are skipped.
    for column in tensor_data.values():
        if not isinstance(column, torch.Tensor):
            continue
        assertpy.assert_that(column.shape[0]).is_equal_to(3)
        for element in column:
            scalar = element.item()
            assertpy.assert_that(scalar).is_instance_of((float, int))
            assertpy.assert_that(scalar).is_not_nan()
def _test_get_historical_features_to_tensor_with_nan(fs: FeatureStore):
    """Check that entities without matching feature rows come back as a NaN tensor.

    The result must be a real tensor with one entry per entity row, so the
    NaN check cannot pass vacuously on an empty or list-valued result.
    """
    entity_df = pd.DataFrame.from_dict(
        {
            "driver_id": [9991, 9992],  # IDs with no matching features
            "event_timestamp": [
                datetime(2021, 4, 12, 10, 59, 42),
                datetime(2021, 4, 12, 10, 59, 42),
            ],
        }
    )
    features = ["driver_hourly_stats:conv_rate"]
    job = fs.get_historical_features(entity_df, features)
    tensor_data = job.to_tensor()
    assert "conv_rate" in tensor_data
    values = tensor_data["conv_rate"]
    torch = get_torch()
    # conv_rate is a float feature, so the column itself must be a tensor.
    assert torch.is_tensor(values)
    # One entry per requested entity row; without this an empty tensor would
    # satisfy the all-NaN check below.
    assertpy.assert_that(values.shape[0]).is_equal_to(len(entity_df))
    # Missing values should be NaN.
    assertpy.assert_that(torch.isnan(values).all().item()).is_true()