diff --git a/sdk/python/feast/infra/compute_engines/spark/nodes.py b/sdk/python/feast/infra/compute_engines/spark/nodes.py index fa5a7bd6208..124ce65ff90 100644 --- a/sdk/python/feast/infra/compute_engines/spark/nodes.py +++ b/sdk/python/feast/infra/compute_engines/spark/nodes.py @@ -4,6 +4,7 @@ import pandas as pd from pyspark.sql import DataFrame, SparkSession, Window from pyspark.sql import functions as F +from pyspark.sql.pandas.types import from_arrow_schema from feast import BatchFeatureView, StreamFeatureView from feast.aggregation import Aggregation @@ -80,7 +81,14 @@ def execute(self, context: ExecutionContext) -> DAGValue: if isinstance(retrieval_job, SparkRetrievalJob): spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df() else: - spark_df = self.spark_session.createDataFrame(retrieval_job.to_arrow()) + arrow_table = retrieval_job.to_arrow() + if arrow_table.num_rows == 0: + spark_schema = from_arrow_schema(arrow_table.schema) + spark_df = self.spark_session.createDataFrame( + self.spark_session.sparkContext.emptyRDD(), schema=spark_schema + ) + else: + spark_df = self.spark_session.createDataFrame(arrow_table.to_pandas()) return DAGValue( data=spark_df,