Skip to content

Commit ac62a32

Browse files
authored
fix: Remote apply using offline store (#4559)
* remote apply using offline store
* passing data source proto to the offline server
* fixed linting, added permission asserts

Signed-off-by: Daniele Martinoli <dmartino@redhat.com>
1 parent ba05893 commit ac62a32

File tree

11 files changed

+282
-64
lines changed

11 files changed

+282
-64
lines changed

sdk/python/feast/feature_store.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -602,16 +602,23 @@ def _make_inferences(
602602

603603
# New feature views may reference previously applied entities.
604604
entities = self._list_entities()
605+
provider = self._get_provider()
605606
update_feature_views_with_inferred_features_and_entities(
606-
views_to_update, entities + entities_to_update, self.config
607+
provider,
608+
views_to_update,
609+
entities + entities_to_update,
610+
self.config,
607611
)
608612
update_feature_views_with_inferred_features_and_entities(
609-
sfvs_to_update, entities + entities_to_update, self.config
613+
provider,
614+
sfvs_to_update,
615+
entities + entities_to_update,
616+
self.config,
610617
)
611618
# We need to attach the time stamp fields to the underlying data sources
612619
# and cascade the dependencies
613620
update_feature_views_with_inferred_features_and_entities(
614-
odfvs_to_update, entities + entities_to_update, self.config
621+
provider, odfvs_to_update, entities + entities_to_update, self.config
615622
)
616623
# TODO(kevjumba): Update schema inference
617624
for sfv in sfvs_to_update:
@@ -1529,9 +1536,12 @@ def write_to_offline_store(
15291536
feature_view_name, allow_registry_cache=allow_registry_cache
15301537
)
15311538

1539+
provider = self._get_provider()
15321540
# Get columns of the batch source and the input dataframe.
15331541
column_names_and_types = (
1534-
feature_view.batch_source.get_table_column_names_and_types(self.config)
1542+
provider.get_table_column_names_and_types_from_data_source(
1543+
self.config, feature_view.batch_source
1544+
)
15351545
)
15361546
source_columns = [column for column, _ in column_names_and_types]
15371547
input_columns = df.columns.values.tolist()
@@ -1545,7 +1555,6 @@ def write_to_offline_store(
15451555
df = df.reindex(columns=source_columns)
15461556

15471557
table = pa.Table.from_pandas(df)
1548-
provider = self._get_provider()
15491558
provider.ingest_df_to_offline_store(feature_view, table)
15501559

15511560
def get_online_features(

sdk/python/feast/inference.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from feast.infra.offline_stores.file_source import FileSource
1414
from feast.infra.offline_stores.redshift_source import RedshiftSource
1515
from feast.infra.offline_stores.snowflake_source import SnowflakeSource
16+
from feast.infra.provider import Provider
1617
from feast.on_demand_feature_view import OnDemandFeatureView
1718
from feast.repo_config import RepoConfig
1819
from feast.stream_feature_view import StreamFeatureView
@@ -95,6 +96,7 @@ def update_data_sources_with_inferred_event_timestamp_col(
9596

9697

9798
def update_feature_views_with_inferred_features_and_entities(
99+
provider: Provider,
98100
fvs: Union[List[FeatureView], List[StreamFeatureView], List[OnDemandFeatureView]],
99101
entities: List[Entity],
100102
config: RepoConfig,
@@ -176,6 +178,7 @@ def update_feature_views_with_inferred_features_and_entities(
176178

177179
if run_inference_for_entities or run_inference_for_features:
178180
_infer_features_and_entities(
181+
provider,
179182
fv,
180183
join_keys,
181184
run_inference_for_features,
@@ -193,6 +196,7 @@ def update_feature_views_with_inferred_features_and_entities(
193196

194197

195198
def _infer_features_and_entities(
199+
provider: Provider,
196200
fv: Union[FeatureView, OnDemandFeatureView],
197201
join_keys: Set[Optional[str]],
198202
run_inference_for_features,
@@ -222,8 +226,10 @@ def _infer_features_and_entities(
222226
columns_to_exclude.remove(mapped_col)
223227
columns_to_exclude.add(original_col)
224228

225-
table_column_names_and_types = fv.batch_source.get_table_column_names_and_types(
226-
config
229+
table_column_names_and_types = (
230+
provider.get_table_column_names_and_types_from_data_source(
231+
config, fv.batch_source
232+
)
227233
)
228234

229235
for col_name, col_datatype in table_column_names_and_types:

sdk/python/feast/infra/offline_stores/offline_store.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,16 @@
1515
from abc import ABC
1616
from datetime import datetime
1717
from pathlib import Path
18-
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
18+
from typing import (
19+
TYPE_CHECKING,
20+
Any,
21+
Callable,
22+
Iterable,
23+
List,
24+
Optional,
25+
Tuple,
26+
Union,
27+
)
1928

2029
import pandas as pd
2130
import pyarrow
@@ -352,8 +361,8 @@ def offline_write_batch(
352361
"""
353362
raise NotImplementedError
354363

355-
@staticmethod
356364
def validate_data_source(
365+
self,
357366
config: RepoConfig,
358367
data_source: DataSource,
359368
):
@@ -365,3 +374,17 @@ def validate_data_source(
365374
data_source: DataSource object that needs to be validated
366375
"""
367376
data_source.validate(config=config)
377+
378+
def get_table_column_names_and_types_from_data_source(
379+
self,
380+
config: RepoConfig,
381+
data_source: DataSource,
382+
) -> Iterable[Tuple[str, str]]:
383+
"""
384+
Returns the list of column names and raw column types for a DataSource.
385+
386+
Args:
387+
config: Configuration object used to configure a feature store.
388+
data_source: DataSource object
389+
"""
390+
return data_source.get_table_column_names_and_types(config=config)

sdk/python/feast/infra/offline_stores/remote.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import uuid
44
from datetime import datetime
55
from pathlib import Path
6-
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
6+
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
77

88
import numpy as np
99
import pandas as pd
@@ -328,6 +328,57 @@ def offline_write_batch(
328328
entity_df=None,
329329
)
330330

331+
def validate_data_source(
332+
self,
333+
config: RepoConfig,
334+
data_source: DataSource,
335+
):
336+
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)
337+
338+
client = build_arrow_flight_client(
339+
config.offline_store.host, config.offline_store.port, config.auth_config
340+
)
341+
342+
api_parameters = {
343+
"data_source_proto": str(data_source),
344+
}
345+
logger.debug(f"validating DataSource {data_source.name}")
346+
_call_put(
347+
api=OfflineStore.validate_data_source.__name__,
348+
api_parameters=api_parameters,
349+
client=client,
350+
table=None,
351+
entity_df=None,
352+
)
353+
354+
def get_table_column_names_and_types_from_data_source(
355+
self, config: RepoConfig, data_source: DataSource
356+
) -> Iterable[Tuple[str, str]]:
357+
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)
358+
359+
client = build_arrow_flight_client(
360+
config.offline_store.host, config.offline_store.port, config.auth_config
361+
)
362+
363+
api_parameters = {
364+
"data_source_proto": str(data_source),
365+
}
366+
logger.debug(
367+
f"Calling {OfflineStore.get_table_column_names_and_types_from_data_source.__name__} with {api_parameters}"
368+
)
369+
table = _send_retrieve_remote(
370+
api=OfflineStore.get_table_column_names_and_types_from_data_source.__name__,
371+
api_parameters=api_parameters,
372+
client=client,
373+
table=None,
374+
entity_df=None,
375+
)
376+
377+
logger.debug(
378+
f"get_table_column_names_and_types_from_data_source for {data_source.name}: {table}"
379+
)
380+
return zip(table.column("name").to_pylist(), table.column("type").to_pylist())
381+
331382

332383
def _create_retrieval_metadata(feature_refs: List[str], entity_df: pd.DataFrame):
333384
entity_schema = _get_entity_schema(

sdk/python/feast/infra/passthrough_provider.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
from datetime import datetime, timedelta
2-
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
2+
from typing import (
3+
Any,
4+
Callable,
5+
Dict,
6+
Iterable,
7+
List,
8+
Mapping,
9+
Optional,
10+
Sequence,
11+
Tuple,
12+
Union,
13+
)
314

415
import pandas as pd
516
import pyarrow as pa
@@ -455,3 +466,10 @@ def validate_data_source(
455466
data_source: DataSource,
456467
):
457468
self.offline_store.validate_data_source(config=config, data_source=data_source)
469+
470+
def get_table_column_names_and_types_from_data_source(
471+
self, config: RepoConfig, data_source: DataSource
472+
) -> Iterable[Tuple[str, str]]:
473+
return self.offline_store.get_table_column_names_and_types_from_data_source(
474+
config=config, data_source=data_source
475+
)

sdk/python/feast/infra/provider.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
11
from abc import ABC, abstractmethod
22
from datetime import datetime
33
from pathlib import Path
4-
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
4+
from typing import (
5+
Any,
6+
Callable,
7+
Dict,
8+
Iterable,
9+
List,
10+
Mapping,
11+
Optional,
12+
Sequence,
13+
Tuple,
14+
Union,
15+
)
516

617
import pandas as pd
718
import pyarrow
@@ -405,6 +416,19 @@ def validate_data_source(
405416
"""
406417
pass
407418

419+
@abstractmethod
420+
def get_table_column_names_and_types_from_data_source(
421+
self, config: RepoConfig, data_source: DataSource
422+
) -> Iterable[Tuple[str, str]]:
423+
"""
424+
Returns the list of column names and raw column types for a DataSource.
425+
426+
Args:
427+
config: Configuration object used to configure a feature store.
428+
data_source: DataSource object
429+
"""
430+
pass
431+
408432

409433
def get_provider(config: RepoConfig) -> Provider:
410434
if "." not in config.provider:

sdk/python/feast/infra/registry/remote.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from feast.infra.infra_object import Infra
1616
from feast.infra.registry.base_registry import BaseRegistry
1717
from feast.on_demand_feature_view import OnDemandFeatureView
18-
from feast.permissions.auth.auth_type import AuthType
1918
from feast.permissions.auth_model import AuthConfig, NoAuthConfig
2019
from feast.permissions.client.grpc_client_auth_interceptor import (
2120
GrpcClientAuthHeaderInterceptor,
@@ -67,9 +66,8 @@ def __init__(
6766
):
6867
self.auth_config = auth_config
6968
self.channel = grpc.insecure_channel(registry_config.path)
70-
if self.auth_config.type != AuthType.NONE.value:
71-
auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config)
72-
self.channel = grpc.intercept_channel(self.channel, auth_header_interceptor)
69+
auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config)
70+
self.channel = grpc.intercept_channel(self.channel, auth_header_interceptor)
7371
self.stub = RegistryServer_pb2_grpc.RegistryServerStub(self.channel)
7472

7573
def close(self):

sdk/python/feast/offline_server.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77

88
import pyarrow as pa
99
import pyarrow.flight as fl
10+
from google.protobuf.json_format import Parse
1011

1112
from feast import FeatureStore, FeatureView, utils
1213
from feast.arrow_error_handler import arrow_server_error_handling_decorator
14+
from feast.data_source import DataSource
1315
from feast.feature_logging import FeatureServiceLoggingSource
1416
from feast.feature_view import DUMMY_ENTITY_NAME
1517
from feast.infra.offline_stores.offline_utils import get_offline_store_from_config
@@ -26,6 +28,7 @@
2628
init_security_manager,
2729
str_to_auth_manager_type,
2830
)
31+
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
2932
from feast.saved_dataset import SavedDatasetStorage
3033

3134
logger = logging.getLogger(__name__)
@@ -138,6 +141,9 @@ def _call_api(self, api: str, command: dict, key: str):
138141
elif api == OfflineServer.persist.__name__:
139142
self.persist(command, key)
140143
remove_data = True
144+
elif api == OfflineServer.validate_data_source.__name__:
145+
self.validate_data_source(command)
146+
remove_data = True
141147
except Exception as e:
142148
remove_data = True
143149
logger.exception(e)
@@ -224,6 +230,11 @@ def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
224230
table = self.pull_all_from_table_or_query(command).to_arrow()
225231
elif api == OfflineServer.pull_latest_from_table_or_query.__name__:
226232
table = self.pull_latest_from_table_or_query(command).to_arrow()
233+
elif (
234+
api
235+
== OfflineServer.get_table_column_names_and_types_from_data_source.__name__
236+
):
237+
table = self.get_table_column_names_and_types_from_data_source(command)
227238
else:
228239
raise NotImplementedError
229240
except Exception as e:
@@ -457,6 +468,41 @@ def persist(self, command: dict, key: str):
457468
traceback.print_exc()
458469
raise e
459470

471+
@staticmethod
472+
def _extract_data_source_from_command(command) -> DataSource:
473+
data_source_proto_str = command["data_source_proto"]
474+
logger.debug(f"Extracted data_source_proto {data_source_proto_str}")
475+
data_source_proto = DataSourceProto()
476+
Parse(data_source_proto_str, data_source_proto)
477+
data_source = DataSource.from_proto(data_source_proto)
478+
logger.debug(f"Converted to DataSource {data_source}")
479+
return data_source
480+
481+
def validate_data_source(self, command: dict):
482+
data_source = OfflineServer._extract_data_source_from_command(command)
483+
logger.debug(f"Validating data source {data_source.name}")
484+
assert_permissions(data_source, actions=[AuthzedAction.READ_OFFLINE])
485+
486+
self.offline_store.validate_data_source(
487+
config=self.store.config,
488+
data_source=data_source,
489+
)
490+
491+
def get_table_column_names_and_types_from_data_source(self, command: dict):
492+
data_source = OfflineServer._extract_data_source_from_command(command)
493+
logger.debug(f"Fetching table columns metadata from {data_source.name}")
494+
assert_permissions(data_source, actions=[AuthzedAction.READ_OFFLINE])
495+
496+
column_names_and_types = data_source.get_table_column_names_and_types(
497+
self.store.config
498+
)
499+
500+
column_names, types = zip(*column_names_and_types)
501+
logger.debug(
502+
f"DataSource {data_source.name} has columns {column_names} with types {types}"
503+
)
504+
return pa.table({"name": column_names, "type": types})
505+
460506

461507
def remove_dummies(fv: FeatureView) -> FeatureView:
462508
"""

0 commit comments

Comments (0)