forked from feast-dev/feast
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathazure_provider.py
More file actions
72 lines (61 loc) · 2.34 KB
/
azure_provider.py
File metadata and controls
72 lines (61 loc) · 2.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from datetime import datetime
from typing import Callable
from tqdm import tqdm
from feast.feature_view import FeatureView
from feast.infra.passthrough_provider import PassthroughProvider
from feast.infra.registry.base_registry import BaseRegistry
from feast.repo_config import RepoConfig
from feast.utils import (
_convert_arrow_to_proto,
_get_column_names,
_run_pyarrow_field_mapping,
)
# Batch size passed to `table.to_batches(...)` when materializing a feature
# view; each resulting batch is converted and written to the online store.
DEFAULT_BATCH_SIZE = 10_000
class AzureProvider(PassthroughProvider):
    """Passthrough provider that materializes a feature view in-process.

    Rows are pulled from the offline store as an Arrow table and written to
    the online store batch by batch.
    """

    def materialize_single_feature_view(
        self,
        config: RepoConfig,
        feature_view: FeatureView,
        start_date: datetime,
        end_date: datetime,
        registry: BaseRegistry,
        project: str,
        tqdm_builder: Callable[[int], tqdm],
    ) -> None:
        """Load one feature view's latest rows into the online store.

        Args:
            config: Repo configuration handed to the offline store pull.
            feature_view: Feature view whose ``batch_source`` is read.
            start_date: Forwarded to the offline store as ``start_date``.
            end_date: Forwarded to the offline store as ``end_date``.
            registry: Registry used to look up the feature view's entities.
            project: Project name used for the entity lookups.
            tqdm_builder: Called with the table's row count; must return a
                context manager exposing ``update`` for progress reporting.
        """
        # TODO(kevjumba): untested
        resolved_entities = [
            registry.get_entity(name, project) for name in feature_view.entities
        ]

        column_names = _get_column_names(feature_view, resolved_entities)
        (
            join_key_columns,
            feature_name_columns,
            event_timestamp_column,
            created_timestamp_column,
        ) = column_names

        source = feature_view.batch_source
        retrieval_job = self.offline_store.pull_latest_from_table_or_query(
            config=config,
            data_source=source,
            join_key_columns=join_key_columns,
            feature_name_columns=feature_name_columns,
            timestamp_field=event_timestamp_column,
            created_timestamp_column=created_timestamp_column,
            start_date=start_date,
            end_date=end_date,
        )
        arrow_table = retrieval_job.to_arrow()

        # Rename source columns only when the batch source defines a mapping.
        mapping = source.field_mapping
        if mapping is not None:
            arrow_table = _run_pyarrow_field_mapping(arrow_table, mapping)

        key_types = {
            entity.join_key: entity.value_type for entity in resolved_entities
        }

        with tqdm_builder(arrow_table.num_rows) as progress_bar:
            for record_batch in arrow_table.to_batches(DEFAULT_BATCH_SIZE):
                proto_rows = _convert_arrow_to_proto(
                    record_batch, feature_view, key_types
                )
                # NOTE(review): the online write uses self.repo_config rather
                # than the `config` argument used for the offline pull above —
                # confirm the two are always the same configuration.
                self.online_write_batch(
                    self.repo_config,
                    feature_view,
                    proto_rows,
                    lambda x: progress_bar.update(x),
                )