|
15 | 15 | from feast.on_demand_feature_view import on_demand_feature_view |
16 | 16 | from feast.repo_config import RepoConfig |
17 | 17 | from feast.types import Float32, Float64, Int64, String, UnixTimestamp |
| 18 | +from feast.value_type import ValueType |
18 | 19 | from tests.utils.data_source_test_creator import prep_file_source |
19 | 20 |
|
20 | 21 |
|
@@ -216,6 +217,78 @@ def test_feature_view_inference_respects_basic_inference(): |
216 | 217 | assert len(feature_view_2.entity_columns) == 2 |
217 | 218 |
|
218 | 219 |
|
| 220 | +def test_feature_view_inference_on_entity_value_types(): |
| 221 | + """ |
| 222 | + Tests that feature view inference correctly uses the entity `value_type` attribute. |
| 223 | + """ |
| 224 | + entity1 = Entity( |
| 225 | + name="test1", join_keys=["id_join_key"], value_type=ValueType.INT64 |
| 226 | + ) |
| 227 | + file_source = FileSource(path="some path") |
| 228 | + feature_view_1 = FeatureView( |
| 229 | + name="test1", |
| 230 | + entities=[entity1], |
| 231 | + schema=[Field(name="int64_col", dtype=Int64)], |
| 232 | + source=file_source, |
| 233 | + ) |
| 234 | + |
| 235 | + assert len(feature_view_1.schema) == 1 |
| 236 | + assert len(feature_view_1.features) == 1 |
| 237 | + assert len(feature_view_1.entity_columns) == 0 |
| 238 | + |
| 239 | + update_feature_views_with_inferred_features_and_entities( |
| 240 | + [feature_view_1], |
| 241 | + [entity1], |
| 242 | + RepoConfig( |
| 243 | + provider="local", project="test", entity_key_serialization_version=2 |
| 244 | + ), |
| 245 | + ) |
| 246 | + |
| 247 | + # The schema is only used as a parameter, as is therefore not updated during inference. |
| 248 | + assert len(feature_view_1.schema) == 1 |
| 249 | + |
| 250 | + # Since there is already a feature specified, additional features are not inferred. |
| 251 | + assert len(feature_view_1.features) == 1 |
| 252 | + |
| 253 | + # The single entity column is inferred correctly and has the expected type. |
| 254 | + assert len(feature_view_1.entity_columns) == 1 |
| 255 | + assert feature_view_1.entity_columns[0].dtype == Int64 |
| 256 | + |
| 257 | + |
| 258 | +def test_conflicting_entity_value_types(): |
| 259 | + """ |
| 260 | + Tests that an error is thrown when the entity value types conflict. |
| 261 | + """ |
| 262 | + entity1 = Entity( |
| 263 | + name="test1", join_keys=["id_join_key"], value_type=ValueType.INT64 |
| 264 | + ) |
| 265 | + file_source = FileSource(path="some path") |
| 266 | + |
| 267 | + with pytest.raises(ValueError): |
| 268 | + _ = FeatureView( |
| 269 | + name="test1", |
| 270 | + entities=[entity1], |
| 271 | + schema=[ |
| 272 | + Field(name="int64_col", dtype=Int64), |
| 273 | + Field( |
| 274 | + name="id_join_key", dtype=Float64 |
| 275 | + ), # Conflicts with the defined entity |
| 276 | + ], |
| 277 | + source=file_source, |
| 278 | + ) |
| 279 | + |
| 280 | + # There should be no error here. |
| 281 | + _ = FeatureView( |
| 282 | + name="test1", |
| 283 | + entities=[entity1], |
| 284 | + schema=[ |
| 285 | + Field(name="int64_col", dtype=Int64), |
| 286 | + Field(name="id_join_key", dtype=Int64), # Conflicts with the defined entity |
| 287 | + ], |
| 288 | + source=file_source, |
| 289 | + ) |
| 290 | + |
| 291 | + |
219 | 292 | def test_feature_view_inference_on_entity_columns(simple_dataset_1): |
220 | 293 | """ |
221 | 294 | Tests that feature view inference correctly infers entity columns. |
|
0 commit comments