Skip to content

Commit a09b7e5

Browse files
committed
fix fv offline
Signed-off-by: HaoXuAI <sduxuhao@gmail.com>
1 parent 62e9b3d commit a09b7e5

File tree

8 files changed

+81
-16
lines changed

8 files changed

+81
-16
lines changed

protos/feast/core/FeatureView.proto

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,11 @@ message FeatureViewSpec {
7474
DataSource stream_source = 9;
7575

7676
// Whether these features should be served online or not
77+
// This is also used to determine whether the features should be written to the online store
7778
bool online = 8;
79+
80+
// Whether these features should be written to the offline store
81+
bool offline = 13;
7882
}
7983

8084
message FeatureViewMeta {

sdk/python/feast/feature_view.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def __copy__(self):
236236
schema=self.schema,
237237
tags=self.tags,
238238
online=self.online,
239+
offline=self.offline,
239240
)
240241

241242
# This is deliberately set outside of the FV initialization as we do not have the Entity objects.
@@ -258,6 +259,7 @@ def __eq__(self, other):
258259
sorted(self.entities) != sorted(other.entities)
259260
or self.ttl != other.ttl
260261
or self.online != other.online
262+
or self.offline != other.offline
261263
or self.batch_source != other.batch_source
262264
or self.stream_source != other.stream_source
263265
or sorted(self.entity_columns) != sorted(other.entity_columns)

sdk/python/feast/infra/compute_engines/local/nodes.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from datetime import datetime, timedelta
2-
from typing import Optional
2+
from typing import Optional, Union
33

44
import pyarrow as pa
55

6+
from feast import BatchFeatureView, StreamFeatureView
67
from feast.data_source import DataSource
78
from feast.infra.compute_engines.dag.context import ExecutionContext
89
from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue
@@ -208,7 +209,9 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue:
208209

209210

210211
class LocalOutputNode(LocalNode):
211-
def __init__(self, name: str, feature_view):
212+
def __init__(
213+
self, name: str, feature_view: Union[BatchFeatureView, StreamFeatureView]
214+
):
212215
super().__init__(name)
213216
self.feature_view = feature_view
214217

sdk/python/feast/infra/compute_engines/spark/feature_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def build_transformation_node(self, input_node):
7373
return node
7474

7575
def build_output_nodes(self, input_node):
76-
node = SparkWriteNode("output", input_node, self.feature_view)
76+
node = SparkWriteNode("output", self.feature_view)
77+
node.add_input(input_node)
7778
self.nodes.append(node)
7879
return node
7980

sdk/python/feast/infra/compute_engines/spark/nodes.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,9 @@ class SparkWriteNode(DAGNode):
271271
def __init__(
272272
self,
273273
name: str,
274-
input_node: DAGNode,
275274
feature_view: Union[BatchFeatureView, StreamFeatureView],
276275
):
277276
super().__init__(name)
278-
self.add_input(input_node)
279277
self.feature_view = feature_view
280278

281279
def execute(self, context: ExecutionContext) -> DAGValue:
@@ -286,8 +284,6 @@ def execute(self, context: ExecutionContext) -> DAGValue:
286284

287285
# ✅ 1. Write to online or offline store (if enabled)
288286
if self.feature_view.online or self.feature_view.offline:
289-
print("Spark DF count:", spark_df.count())
290-
print("Num partitions:", spark_df.rdd.getNumPartitions())
291287
spark_df.mapInArrow(
292288
lambda x: map_in_arrow(x, serialized_artifacts), spark_df.schema
293289
).count()

sdk/python/feast/infra/compute_engines/spark/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def map_in_arrow(
3737
offline_store,
3838
repo_config,
3939
) = serialized_artifacts.unserialize()
40+
print("write_feature_view", feature_view)
4041

4142
if feature_view.online:
4243
join_key_to_value_type = {
@@ -55,6 +56,7 @@ def map_in_arrow(
5556
progress=lambda x: None,
5657
)
5758
if feature_view.offline:
59+
print("offline_to_write", table)
5860
offline_store.offline_write_batch(
5961
config=repo_config,
6062
feature_view=feature_view,

sdk/python/tests/integration/compute_engines/spark/test_compute.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from tests.integration.feature_repos.universal.online_store.redis import (
3434
RedisOnlineStoreCreator,
3535
)
36-
from tests.utils.e2e_test_validation import _check_offline_and_online_features
3736

3837
now = datetime.now()
3938
today = datetime.today()
@@ -189,6 +188,20 @@ def transform_feature(df: DataFrame) -> DataFrame:
189188

190189
@pytest.mark.integration
191190
def test_spark_compute_engine_materialize():
191+
"""
192+
Test the SparkComputeEngine materialize method.
193+
For the current feature view driver_hourly_stats, the execution plan below applies:
194+
1. feature data from create_feature_dataset
195+
2. filter by start_time and end_time, that is, the last 2 days
196+
for the driver_id 1001, the data left is row 0
197+
for the driver_id 1002, the data left is row 2
198+
3. apply the transform_feature function to the data
199+
for all features, the value is multiplied by 2
200+
4. write the data to the online store and offline store
201+
202+
Returns:
203+
204+
"""
192205
spark_environment = create_spark_environment()
193206
fs = spark_environment.feature_store
194207
registry = fs.registry
@@ -213,7 +226,7 @@ def transform_feature(df: DataFrame) -> DataFrame:
213226
Field(name="driver_id", dtype=Int32),
214227
],
215228
online=True,
216-
offline=False,
229+
offline=True,
217230
source=data_source,
218231
)
219232

@@ -244,18 +257,62 @@ def tqdm_builder(length):
244257

245258
assert spark_materialize_job.status() == MaterializationJobStatus.SUCCEEDED
246259

247-
_check_offline_and_online_features(
260+
_check_online_features(
248261
fs=fs,
249-
fv=driver_stats_fv,
250-
driver_id=1,
251-
event_timestamp=now,
252-
expected_value=0.3,
262+
driver_id=1001,
263+
feature="driver_hourly_stats:conv_rate",
264+
expected_value=1.6,
253265
full_feature_names=True,
254-
check_offline_store=True,
266+
)
267+
268+
entity_df = create_entity_df()
269+
270+
_check_offline_features(
271+
fs=fs,
272+
feature="driver_hourly_stats:conv_rate",
273+
entity_df=entity_df,
274+
expected_value=1.6,
255275
)
256276
finally:
257277
spark_environment.teardown()
258278

259279

280+
def _check_online_features(
281+
fs,
282+
driver_id,
283+
feature,
284+
expected_value,
285+
full_feature_names: bool = True,
286+
):
287+
online_response = fs.get_online_features(
288+
features=[feature],
289+
entity_rows=[{"driver_id": driver_id}],
290+
full_feature_names=full_feature_names,
291+
).to_dict()
292+
293+
feature_ref = "__".join(feature.split(":"))
294+
295+
assert len(online_response["driver_id"]) == 1
296+
assert online_response["driver_id"][0] == driver_id
297+
assert abs(online_response[feature_ref][0] - expected_value) < 1e-6, (
298+
"Transformed result"
299+
)
300+
301+
302+
def _check_offline_features(
303+
fs,
304+
feature,
305+
entity_df,
306+
expected_value,
307+
):
308+
offline_df = fs.get_historical_features(
309+
entity_df=entity_df,
310+
features=[feature],
311+
).to_df()
312+
313+
assert len(offline_df) == 2
314+
assert offline_df["driver_id"].to_list() == [1001, 1002]
315+
316+
260317
if __name__ == "__main__":
261318
test_spark_compute_engine_get_historical_features()

sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def test_local_output_node():
194194
context = create_context(
195195
node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))}
196196
)
197-
node = LocalOutputNode("output")
197+
node = LocalOutputNode("output", MagicMock())
198198
node.add_input(MagicMock())
199199
node.inputs[0].name = "source"
200200
result = node.execute(context)

0 commit comments

Comments
 (0)