|
1 | 1 | import logging |
2 | 2 | import os |
3 | 3 | import tempfile |
| 4 | +from datetime import datetime, timedelta |
4 | 5 | from textwrap import dedent |
5 | 6 |
|
6 | 7 | import pytest |
7 | 8 |
|
8 | | -from feast import FeatureView, OnDemandFeatureView, StreamFeatureView |
| 9 | +from feast import ( |
| 10 | + Entity, |
| 11 | + FeatureView, |
| 12 | + Field, |
| 13 | + FileSource, |
| 14 | + OnDemandFeatureView, |
| 15 | + StreamFeatureView, |
| 16 | +) |
9 | 17 | from feast.feature_store import FeatureStore |
| 18 | +from feast.infra.online_stores.remote import ( |
| 19 | + RemoteOnlineStoreConfig, |
| 20 | +) |
10 | 21 | from feast.permissions.action import AuthzedAction |
11 | 22 | from feast.permissions.permission import Permission |
12 | 23 | from feast.permissions.policy import RoleBasedPolicy |
| 24 | +from feast.protos.feast.types.entity_key_pb2 import EntityKey as EntityKeyProto |
| 25 | +from feast.protos.feast.types.value_pb2 import Value as ValueProto |
| 26 | +from feast.types import Float32, Int64 |
13 | 27 | from tests.utils.auth_permissions_util import ( |
14 | 28 | PROJECT_NAME, |
15 | 29 | default_store, |
@@ -265,3 +279,156 @@ def _overwrite_remote_client_feature_store_yaml( |
265 | 279 |
|
266 | 280 | with open(repo_config, "w") as repo_config_file: |
267 | 281 | repo_config_file.write(config_content) |
| 282 | + |
| 283 | + |
| 284 | +@pytest.mark.integration |
| 285 | +@pytest.mark.rbac_remote_integration_test |
| 286 | +@pytest.mark.parametrize( |
| 287 | + "tls_mode", [("True", "True"), ("True", "False"), ("False", "")], indirect=True |
| 288 | +) |
| 289 | +def test_remote_online_store_read_write(auth_config, tls_mode): |
| 290 | + with ( |
| 291 | + tempfile.TemporaryDirectory() as remote_server_tmp_dir, |
| 292 | + tempfile.TemporaryDirectory() as remote_client_tmp_dir, |
| 293 | + ): |
| 294 | + permissions_list = [ |
| 295 | + Permission( |
| 296 | + name="online_list_fv_perm", |
| 297 | + types=FeatureView, |
| 298 | + policy=RoleBasedPolicy(roles=["reader"]), |
| 299 | + actions=[AuthzedAction.READ_ONLINE], |
| 300 | + ), |
| 301 | + Permission( |
| 302 | + name="online_list_odfv_perm", |
| 303 | + types=OnDemandFeatureView, |
| 304 | + policy=RoleBasedPolicy(roles=["reader"]), |
| 305 | + actions=[AuthzedAction.READ_ONLINE], |
| 306 | + ), |
| 307 | + Permission( |
| 308 | + name="online_list_sfv_perm", |
| 309 | + types=StreamFeatureView, |
| 310 | + policy=RoleBasedPolicy(roles=["reader"]), |
| 311 | + actions=[AuthzedAction.READ_ONLINE], |
| 312 | + ), |
| 313 | + Permission( |
| 314 | + name="online_write_fv_perm", |
| 315 | + types=FeatureView, |
| 316 | + policy=RoleBasedPolicy(roles=["writer"]), |
| 317 | + actions=[AuthzedAction.WRITE_ONLINE], |
| 318 | + ), |
| 319 | + Permission( |
| 320 | + name="online_write_odfv_perm", |
| 321 | + types=OnDemandFeatureView, |
| 322 | + policy=RoleBasedPolicy(roles=["writer"]), |
| 323 | + actions=[AuthzedAction.WRITE_ONLINE], |
| 324 | + ), |
| 325 | + Permission( |
| 326 | + name="online_write_sfv_perm", |
| 327 | + types=StreamFeatureView, |
| 328 | + policy=RoleBasedPolicy(roles=["writer"]), |
| 329 | + actions=[AuthzedAction.WRITE_ONLINE], |
| 330 | + ), |
| 331 | + ] |
| 332 | + server_store, server_url, registry_path = ( |
| 333 | + _create_server_store_spin_feature_server( |
| 334 | + temp_dir=remote_server_tmp_dir, |
| 335 | + auth_config=auth_config, |
| 336 | + permissions_list=permissions_list, |
| 337 | + tls_mode=tls_mode, |
| 338 | + ) |
| 339 | + ) |
| 340 | + assert None not in (server_store, server_url, registry_path) |
| 341 | + |
| 342 | + client_store = _create_remote_client_feature_store( |
| 343 | + temp_dir=remote_client_tmp_dir, |
| 344 | + server_registry_path=str(registry_path), |
| 345 | + feature_server_url=server_url, |
| 346 | + auth_config=auth_config, |
| 347 | + tls_mode=tls_mode, |
| 348 | + ) |
| 349 | + assert client_store is not None |
| 350 | + |
| 351 | + # Define a simple FeatureView for testing write operations |
| 352 | + driver = Entity(name="driver", description="driver id") |
| 353 | + |
| 354 | + driver_hourly_stats_source = FileSource( |
| 355 | + path="", # Path is not used for online writes in this context |
| 356 | + timestamp_field="event_timestamp", |
| 357 | + created_timestamp_column="created", |
| 358 | + ) |
| 359 | + |
| 360 | + driver_hourly_stats_fv = FeatureView( |
| 361 | + name="driver_hourly_stats", |
| 362 | + entities=[driver], |
| 363 | + ttl=timedelta(days=1), |
| 364 | + features=[ |
| 365 | + Field(name="conv_rate", dtype=Float32), |
| 366 | + Field(name="acc_rate", dtype=Float32), |
| 367 | + Field(name="avg_daily_trips", dtype=Int64), |
| 368 | + ], |
| 369 | + online_store=RemoteOnlineStoreConfig(), # Ensure this FV uses the remote online store |
| 370 | + source=driver_hourly_stats_source, |
| 371 | + tags={}, |
| 372 | + ) |
| 373 | + |
| 374 | + # Apply the feature view to the client store |
| 375 | + client_store.apply([driver, driver_hourly_stats_fv]) |
| 376 | + |
| 377 | + # Prepare data for online write |
| 378 | + entity_key_1 = EntityKeyProto( |
| 379 | + join_keys=["driver_id"], entity_values=[ValueProto(int66=1001)] |
| 380 | + ) |
| 381 | + entity_key_2 = EntityKeyProto( |
| 382 | + join_keys=["driver_id"], entity_values=[ValueProto(int66=1002)] |
| 383 | + ) |
| 384 | + |
| 385 | + feature_values_1 = { |
| 386 | + "conv_rate": ValueProto(float_val=0.8), |
| 387 | + "acc_rate": ValueProto(float_val=0.95), |
| 388 | + "avg_daily_trips": ValueProto(int64_val=50), |
| 389 | + } |
| 390 | + |
| 391 | + feature_values_2 = { |
| 392 | + "conv_rate": ValueProto(float_val=0.7), |
| 393 | + "acc_rate": ValueProto(float_val=0.9), |
| 394 | + "avg_daily_trips": ValueProto(int64_val=45), |
| 395 | + } |
| 396 | + |
| 397 | + now = datetime.utcnow() |
| 398 | + |
| 399 | + data = [ |
| 400 | + (entity_key_1, feature_values_1, now, now), |
| 401 | + (entity_key_2, feature_values_2, now, now), |
| 402 | + ] |
| 403 | + |
| 404 | + # Perform the online write |
| 405 | + client_store.online_write_batch( |
| 406 | + config=client_store.repo_config, |
| 407 | + table=driver_hourly_stats_fv, |
| 408 | + data=data, |
| 409 | + progress=None, |
| 410 | + ) |
| 411 | + |
| 412 | + # Verify the data by reading it back |
| 413 | + # read_entity_keys = [entity_key_1, entity_key_2] |
| 414 | + read_features = [ |
| 415 | + "driver_hourly_stats:conv_rate", |
| 416 | + "driver_hourly_stats:acc_rate", |
| 417 | + "driver_hourly_stats:avg_daily_trips", |
| 418 | + ] |
| 419 | + |
| 420 | + online_features = client_store.get_online_features( |
| 421 | + features=read_features, |
| 422 | + entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}], |
| 423 | + ).to_dict() |
| 424 | + |
| 425 | + # Assertions for read data |
| 426 | + assert online_features is not None |
| 427 | + assert len(online_features["driver_id"]) == 2 |
| 428 | + assert online_features["driver_id"] == [1001, 1002] |
| 429 | + assert online_features["driver_hourly_stats:conv_rate"] == [0.8, 0.7] |
| 430 | + assert online_features["driver_hourly_stats:acc_rate"] == [0.95, 0.9] |
| 431 | + assert online_features["driver_hourly_stats:avg_daily_trips"] == [50, 45] |
| 432 | + |
| 433 | + # Clean up the applied feature view from the server store to avoid interference with other tests |
| 434 | + server_store.teardown() |
0 commit comments