1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from typing import Any , Dict , List , Union
15+ from typing import TYPE_CHECKING , Any , Dict , List , Union
1616
1717import pandas as pd
1818import pyarrow as pa
19- import torch
2019
2120from feast .feature_view import DUMMY_ENTITY_ID
2221from feast .protos .feast .serving .ServingService_pb2 import GetOnlineFeaturesResponse
22+ from feast .torch_wrapper import get_torch
2323from feast .type_map import feast_value_type_to_python_type
2424
25+ if TYPE_CHECKING :
26+ import torch
27+
28+ TorchTensor = torch .Tensor
29+ else :
30+ TorchTensor = Any
31+
2532TIMESTAMP_POSTFIX : str = "__ts"
2633
2734
@@ -94,7 +101,7 @@ def to_tensor(
94101 self ,
95102 kind : str = "torch" ,
96103 default_value : Any = float ("nan" ),
97- ) -> Dict [str , Union [torch . Tensor , List [Any ]]]:
104+ ) -> Dict [str , Union [TorchTensor , List [Any ]]]:
98105 """
99106 Converts GetOnlineFeaturesResponse features into a dictionary of tensors or lists.
100107
@@ -112,17 +119,18 @@ def to_tensor(
112119 raise ValueError (
113120 f"Unsupported tensor kind: { kind } . Only 'torch' is supported currently."
114121 )
115-
122+ torch = get_torch ()
116123 feature_dict = self .to_dict (include_event_timestamps = False )
117124 feature_keys = set (self .proto .metadata .feature_names .val )
118- tensor_dict : Dict [str , Union [torch . Tensor , List [Any ]]] = {}
125+ tensor_dict : Dict [str , Union [TorchTensor , List [Any ]]] = {}
119126 for key in feature_keys :
120127 raw_values = feature_dict [key ]
121128 values = [v if v is not None else default_value for v in raw_values ]
122129 first_valid = next ((v for v in values if v is not None ), None )
123130 if isinstance (first_valid , (int , float , bool )):
124131 try :
125- tensor_dict [key ] = torch .tensor (values )
132+ device = "cuda" if torch .cuda .is_available () else "cpu"
133+ tensor_dict [key ] = torch .tensor (values , device = device )
126134 except Exception as e :
127135 raise ValueError (
128136 f"Failed to convert values for '{ key } ' to tensor: { e } "
0 commit comments