diff --git a/python/feast_spark/pyspark/historical_feature_retrieval_job.py b/python/feast_spark/pyspark/historical_feature_retrieval_job.py
index 82cb9579..6d3b9794 100644
--- a/python/feast_spark/pyspark/historical_feature_retrieval_job.py
+++ b/python/feast_spark/pyspark/historical_feature_retrieval_job.py
@@ -14,6 +14,7 @@
 from pyspark.sql.functions import (
     broadcast,
     col,
+    expr,
     monotonically_increasing_id,
     row_number,
 )
@@ -321,11 +322,16 @@ class FileDestination(NamedTuple):
 
 def _map_column(df: DataFrame, col_mapping: Dict[str, str]):
     source_to_alias_map = {v: k for k, v in col_mapping.items()}
-    projection = [
-        col(col_name).alias(source_to_alias_map.get(col_name, col_name))
-        for col_name in df.columns
-    ]
-    return df.select(projection)
+    projection = {}
+
+    for col_name in df.columns + [c for c in col_mapping if c not in df.columns]:
+        if col_name in source_to_alias_map:
+            # column rename
+            projection[source_to_alias_map[col_name]] = col(col_name)
+        else:
+            projection[col_name] = expr(col_mapping.get(col_name, col_name))
+
+    return df.select([c.alias(a) for a, c in projection.items()])
 
 
 def as_of_join(
diff --git a/python/tests/test_historical_feature_retrieval.py b/python/tests/test_historical_feature_retrieval.py
index e0e8d4d0..ca116a40 100644
--- a/python/tests/test_historical_feature_retrieval.py
+++ b/python/tests/test_historical_feature_retrieval.py
@@ -47,7 +47,7 @@ def large_entity_csv_file(pytestconfig, spark):
     file_path = os.path.join(temp_dir, "large_entity")
     entity_schema = StructType(
         [
-            StructField("customer_id", IntegerType()),
+            StructField("id", IntegerType()),
             StructField("event_timestamp", TimestampType()),
         ]
    )