|
7 | 7 |
|
8 | 8 | import pyarrow as pa |
9 | 9 | import pyarrow.flight as fl |
| 10 | +from google.protobuf.json_format import Parse |
10 | 11 |
|
11 | 12 | from feast import FeatureStore, FeatureView, utils |
12 | 13 | from feast.arrow_error_handler import arrow_server_error_handling_decorator |
| 14 | +from feast.data_source import DataSource |
13 | 15 | from feast.feature_logging import FeatureServiceLoggingSource |
14 | 16 | from feast.feature_view import DUMMY_ENTITY_NAME |
15 | 17 | from feast.infra.offline_stores.offline_utils import get_offline_store_from_config |
|
26 | 28 | init_security_manager, |
27 | 29 | str_to_auth_manager_type, |
28 | 30 | ) |
| 31 | +from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto |
29 | 32 | from feast.saved_dataset import SavedDatasetStorage |
30 | 33 |
|
31 | 34 | logger = logging.getLogger(__name__) |
@@ -138,6 +141,9 @@ def _call_api(self, api: str, command: dict, key: str): |
138 | 141 | elif api == OfflineServer.persist.__name__: |
139 | 142 | self.persist(command, key) |
140 | 143 | remove_data = True |
| 144 | + elif api == OfflineServer.validate_data_source.__name__: |
| 145 | + self.validate_data_source(command) |
| 146 | + remove_data = True |
141 | 147 | except Exception as e: |
142 | 148 | remove_data = True |
143 | 149 | logger.exception(e) |
@@ -224,6 +230,11 @@ def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket): |
224 | 230 | table = self.pull_all_from_table_or_query(command).to_arrow() |
225 | 231 | elif api == OfflineServer.pull_latest_from_table_or_query.__name__: |
226 | 232 | table = self.pull_latest_from_table_or_query(command).to_arrow() |
| 233 | + elif ( |
| 234 | + api |
| 235 | + == OfflineServer.get_table_column_names_and_types_from_data_source.__name__ |
| 236 | + ): |
| 237 | + table = self.get_table_column_names_and_types_from_data_source(command) |
227 | 238 | else: |
228 | 239 | raise NotImplementedError |
229 | 240 | except Exception as e: |
@@ -457,6 +468,41 @@ def persist(self, command: dict, key: str): |
457 | 468 | traceback.print_exc() |
458 | 469 | raise e |
459 | 470 |
|
| 471 | + @staticmethod |
| 472 | + def _extract_data_source_from_command(command) -> DataSource: |
| 473 | + data_source_proto_str = command["data_source_proto"] |
| 474 | + logger.debug(f"Extracted data_source_proto {data_source_proto_str}") |
| 475 | + data_source_proto = DataSourceProto() |
| 476 | + Parse(data_source_proto_str, data_source_proto) |
| 477 | + data_source = DataSource.from_proto(data_source_proto) |
| 478 | + logger.debug(f"Converted to DataSource {data_source}") |
| 479 | + return data_source |
| 480 | + |
| 481 | + def validate_data_source(self, command: dict): |
| 482 | + data_source = OfflineServer._extract_data_source_from_command(command) |
| 483 | + logger.debug(f"Validating data source {data_source.name}") |
| 484 | + assert_permissions(data_source, actions=[AuthzedAction.READ_OFFLINE]) |
| 485 | + |
| 486 | + self.offline_store.validate_data_source( |
| 487 | + config=self.store.config, |
| 488 | + data_source=data_source, |
| 489 | + ) |
| 490 | + |
| 491 | + def get_table_column_names_and_types_from_data_source(self, command: dict): |
| 492 | + data_source = OfflineServer._extract_data_source_from_command(command) |
| 493 | + logger.debug(f"Fetching table columns metadata from {data_source.name}") |
| 494 | + assert_permissions(data_source, actions=[AuthzedAction.READ_OFFLINE]) |
| 495 | + |
| 496 | + column_names_and_types = data_source.get_table_column_names_and_types( |
| 497 | + self.store.config |
| 498 | + ) |
| 499 | + |
| 500 | + column_names, types = zip(*column_names_and_types) |
| 501 | + logger.debug( |
| 502 | + f"DataSource {data_source.name} has columns {column_names} with types {types}" |
| 503 | + ) |
| 504 | + return pa.table({"name": column_names, "type": types}) |
| 505 | + |
460 | 506 |
|
461 | 507 | def remove_dummies(fv: FeatureView) -> FeatureView: |
462 | 508 | """ |
|
0 commit comments