Skip to content

Commit 43b2897

Browse files
committed
fix: Handle PyTorch for CPU-only systems
Signed-off-by: ntkathole <nikhilkathole2683@gmail.com>
1 parent 77e7d6f commit 43b2897

File tree

6 files changed

+73
-14
lines changed

6 files changed

+73
-14
lines changed

.github/workflows/unit_tests.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ jobs:
3030
uses: astral-sh/setup-uv@v5
3131
with:
3232
enable-cache: true
33+
- name: Install torch (platform-specific)
34+
run: |
35+
if [[ "$RUNNER_OS" == "Linux" ]]; then
36+
pip install torch==2.2.2+cpu torchvision==0.17.2+cpu \
37+
-f https://download.pytorch.org/whl/torch_stable.html
38+
fi
3339
- name: Install dependencies
3440
run: make install-python-dependencies-ci
3541
- name: Test Python

sdk/python/feast/infra/offline_stores/offline_store.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929

3030
import pandas as pd
3131
import pyarrow
32-
import torch
3332

3433
from feast import flags_helper
3534
from feast.data_source import DataSource
@@ -40,6 +39,7 @@
4039
from feast.on_demand_feature_view import OnDemandFeatureView
4140
from feast.repo_config import RepoConfig
4241
from feast.saved_dataset import SavedDatasetStorage
42+
from feast.torch_wrapper import get_torch
4343

4444
if TYPE_CHECKING:
4545
from feast.saved_dataset import ValidationReference
@@ -160,13 +160,15 @@ def to_tensor(
160160
raise ValueError(
161161
f"Unsupported tensor kind: {kind}. Only 'torch' is supported."
162162
)
163+
torch = get_torch()
164+
device = "cuda" if torch.cuda.is_available() else "cpu"
163165
df = self.to_df(timeout=timeout)
164166
tensor_dict = {}
165167
for column in df.columns:
166168
values = df[column].fillna(default_value).tolist()
167169
first_non_null = next((v for v in values if v is not None), None)
168170
if isinstance(first_non_null, (int, float, bool)):
169-
tensor_dict[column] = torch.tensor(values)
171+
tensor_dict[column] = torch.tensor(values, device=device)
170172
else:
171173
tensor_dict[column] = values
172174
return tensor_dict

sdk/python/feast/online_response.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,23 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Dict, List, Union
15+
from typing import TYPE_CHECKING, Any, Dict, List, Union
1616

1717
import pandas as pd
1818
import pyarrow as pa
19-
import torch
2019

2120
from feast.feature_view import DUMMY_ENTITY_ID
2221
from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesResponse
22+
from feast.torch_wrapper import get_torch
2323
from feast.type_map import feast_value_type_to_python_type
2424

25+
if TYPE_CHECKING:
26+
import torch
27+
28+
TorchTensor = torch.Tensor
29+
else:
30+
TorchTensor = Any
31+
2532
TIMESTAMP_POSTFIX: str = "__ts"
2633

2734

@@ -94,7 +101,7 @@ def to_tensor(
94101
self,
95102
kind: str = "torch",
96103
default_value: Any = float("nan"),
97-
) -> Dict[str, Union[torch.Tensor, List[Any]]]:
104+
) -> Dict[str, Union[TorchTensor, List[Any]]]:
98105
"""
99106
Converts GetOnlineFeaturesResponse features into a dictionary of tensors or lists.
100107
@@ -112,17 +119,18 @@ def to_tensor(
112119
raise ValueError(
113120
f"Unsupported tensor kind: {kind}. Only 'torch' is supported currently."
114121
)
115-
122+
torch = get_torch()
116123
feature_dict = self.to_dict(include_event_timestamps=False)
117124
feature_keys = set(self.proto.metadata.feature_names.val)
118-
tensor_dict: Dict[str, Union[torch.Tensor, List[Any]]] = {}
125+
tensor_dict: Dict[str, Union[TorchTensor, List[Any]]] = {}
119126
for key in feature_keys:
120127
raw_values = feature_dict[key]
121128
values = [v if v is not None else default_value for v in raw_values]
122129
first_valid = next((v for v in values if v is not None), None)
123130
if isinstance(first_valid, (int, float, bool)):
124131
try:
125-
tensor_dict[key] = torch.tensor(values)
132+
device = "cuda" if torch.cuda.is_available() else "cpu"
133+
tensor_dict[key] = torch.tensor(values, device=device)
126134
except Exception as e:
127135
raise ValueError(
128136
f"Failed to convert values for '{key}' to tensor: {e}"

sdk/python/feast/torch_wrapper.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import importlib
2+
3+
TORCH_AVAILABLE = False
4+
_torch = None
5+
_torch_import_error = None
6+
7+
8+
def _import_torch():
9+
global _torch, TORCH_AVAILABLE, _torch_import_error
10+
try:
11+
_torch = importlib.import_module("torch")
12+
TORCH_AVAILABLE = True
13+
except Exception as e:
14+
# Catch import errors including CUDA lib missing
15+
TORCH_AVAILABLE = False
16+
_torch_import_error = e
17+
18+
19+
_import_torch()
20+
21+
22+
def get_torch():
23+
"""
24+
Return the torch module if available, else raise a friendly error.
25+
26+
This prevents crashing on import if CUDA libs are missing.
27+
"""
28+
if TORCH_AVAILABLE:
29+
return _torch
30+
else:
31+
error_message = (
32+
"Torch is not available or failed to import.\n"
33+
"Original error:\n"
34+
f"{_torch_import_error}\n\n"
35+
"If you are on a CPU-only system, make sure you install the CPU-only torch wheel:\n"
36+
" pip install torch==2.2.2+cpu torchvision==0.17.2+cpu -f https://download.pytorch.org/whl/torch_stable.html\n"
37+
"Or check your CUDA installation if using GPU torch.\n"
38+
)
39+
raise ImportError(error_message) from _torch_import_error

sdk/python/tests/unit/online_store/test_online_retrieval.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import pandas as pd
1111
import pytest
1212
import sqlite_vec
13-
import torch
1413
from pandas.testing import assert_frame_equal
1514

1615
from feast import FeatureStore, RepoConfig
@@ -19,6 +18,7 @@
1918
from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto
2019
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
2120
from feast.repo_config import RegistryConfig
21+
from feast.torch_wrapper import get_torch
2222
from feast.types import ValueType
2323
from feast.utils import _utc_now
2424
from tests.integration.feature_repos.universal.feature_views import TAGS
@@ -148,16 +148,19 @@ def test_get_online_features() -> None:
148148
assert "avg_orders_day" in tensor_result
149149
assert "name" in tensor_result
150150
assert "trips" in tensor_result
151-
152151
# Entity values
153-
assert torch.equal(tensor_result["driver_id"], torch.tensor([1, 1]))
152+
torch = get_torch()
153+
device = "cuda" if torch.cuda.is_available() else "cpu"
154+
assert torch.equal(
155+
tensor_result["driver_id"], torch.tensor([1, 1], device=device)
156+
)
154157
assert tensor_result["customer_id"] == ["5", "5"]
155158

156159
# Feature values
157160
assert tensor_result["lon"] == ["1.0", "1.0"] # String -> not tensor
158161
assert torch.equal(tensor_result["avg_orders_day"], torch.tensor([1.0, 1.0]))
159162
assert tensor_result["name"] == ["John", "John"]
160-
assert torch.equal(tensor_result["trips"], torch.tensor([7, 7]))
163+
assert torch.equal(tensor_result["trips"], torch.tensor([7, 7], device=device))
161164

162165
# Ensure features are still in result when keys not found
163166
result = store.get_online_features(

sdk/python/tests/unit/test_offline_server.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pyarrow as pa
88
import pyarrow.flight as flight
99
import pytest
10-
import torch
1110

1211
from feast import FeatureStore, FeatureView, FileSource
1312
from feast.errors import FeatureViewNotFoundException
@@ -18,6 +17,7 @@
1817
)
1918
from feast.offline_server import OfflineServer, _init_auth_manager
2019
from feast.repo_config import RepoConfig
20+
from feast.torch_wrapper import get_torch
2121
from tests.utils.cli_repo_creator import CliRunner
2222

2323
PROJECT_NAME = "test_remote_offline"
@@ -218,7 +218,7 @@ def _test_get_historical_features_to_tensor(fs: FeatureStore):
218218

219219
assertpy.assert_that(tensor_data).is_not_none()
220220
assertpy.assert_that(tensor_data["driver_id"].shape[0]).is_equal_to(3)
221-
221+
torch = get_torch()
222222
for key, values in tensor_data.items():
223223
if isinstance(values, torch.Tensor):
224224
assertpy.assert_that(values.shape[0]).is_equal_to(3)
@@ -280,6 +280,7 @@ def _test_get_historical_features_to_tensor_with_nan(fs: FeatureStore):
280280
assert "conv_rate" in tensor_data
281281
values = tensor_data["conv_rate"]
282282
# conv_rate is a float feature, missing values should be NaN
283+
torch = get_torch()
283284
for val in values:
284285
assert isinstance(val, torch.Tensor) or torch.is_tensor(val)
285286
assertpy.assert_that(torch.isnan(val).item()).is_true()

0 commit comments

Comments
 (0)