|
7 | 7 | from feast import BatchFeatureView, StreamFeatureView |
8 | 8 | from feast.aggregation import Aggregation |
9 | 9 | from feast.data_source import DataSource |
| 10 | +from feast.infra.common.serde import SerializedArtifacts |
10 | 11 | from feast.infra.compute_engines.dag.context import ExecutionContext |
11 | 12 | from feast.infra.compute_engines.dag.model import DAGFormat |
12 | 13 | from feast.infra.compute_engines.dag.node import DAGNode |
13 | 14 | from feast.infra.compute_engines.dag.value import DAGValue |
14 | | -from feast.infra.materialization.contrib.spark.spark_materialization_engine import ( |
15 | | - _map_by_partition, |
16 | | - _SparkSerializedArtifacts, |
17 | | -) |
| 15 | +from feast.infra.compute_engines.spark.utils import map_in_arrow |
18 | 16 | from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( |
19 | 17 | SparkRetrievalJob, |
20 | 18 | _get_entity_schema, |
21 | 19 | ) |
| 20 | +from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( |
| 21 | + SparkSource, |
| 22 | +) |
22 | 23 | from feast.infra.offline_stores.offline_utils import ( |
23 | 24 | infer_event_timestamp_from_entity_df, |
24 | 25 | ) |
@@ -273,30 +274,41 @@ class SparkWriteNode(DAGNode): |
    def __init__(
        self,
        name: str,
        feature_view: Union[BatchFeatureView, StreamFeatureView],
    ):
        """DAG node that writes an upstream Spark DataFrame to the stores.

        Args:
            name: Unique name of this node within the DAG.
            feature_view: The batch or stream feature view whose online/offline
                flags control where `execute` writes the data.
        """
        super().__init__(name)
        self.feature_view = feature_view
282 | 281 |
|
283 | 282 | def execute(self, context: ExecutionContext) -> DAGValue: |
284 | 283 | spark_df: DataFrame = self.get_single_input_value(context).data |
285 | | - spark_serialized_artifacts = _SparkSerializedArtifacts.serialize( |
| 284 | + serialized_artifacts = SerializedArtifacts.serialize( |
286 | 285 | feature_view=self.feature_view, repo_config=context.repo_config |
287 | 286 | ) |
288 | 287 |
|
289 | | - # ✅ 1. Write to offline store (if enabled) |
290 | | - if self.feature_view.offline: |
291 | | - # TODO: Update _map_by_partition to be able to write to offline store |
292 | | - pass |
293 | | - |
294 | | - # ✅ 2. Write to online store (if enabled) |
| 288 | + # ✅ 1. Write to online store if online enabled |
295 | 289 | if self.feature_view.online: |
296 | | - spark_df.mapInPandas( |
297 | | - lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int" |
| 290 | + spark_df.mapInArrow( |
| 291 | + lambda x: map_in_arrow(x, serialized_artifacts, mode="online"), |
| 292 | + spark_df.schema, |
298 | 293 | ).count() |
299 | 294 |
|
| 295 | + # ✅ 2. Write to offline store if offline enabled |
| 296 | + if self.feature_view.offline: |
| 297 | + if not isinstance(self.feature_view.batch_source, SparkSource): |
| 298 | + spark_df.mapInArrow( |
| 299 | + lambda x: map_in_arrow(x, serialized_artifacts, mode="offline"), |
| 300 | + spark_df.schema, |
| 301 | + ).count() |
| 302 | + # Directly write spark df to spark offline store without using mapInArrow |
| 303 | + else: |
| 304 | + dest_path = self.feature_view.batch_source.path |
| 305 | + file_format = self.feature_view.batch_source.file_format |
| 306 | + if not dest_path or not file_format: |
| 307 | + raise ValueError( |
| 308 | + "Destination path and file format must be specified for SparkSource." |
| 309 | + ) |
| 310 | + spark_df.write.format(file_format).mode("append").save(dest_path) |
| 311 | + |
300 | 312 | return DAGValue( |
301 | 313 | data=spark_df, |
302 | 314 | format=DAGFormat.SPARK, |
|
0 commit comments