From df993dfaf4c6899c60c8991f16794337e1bd3689 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Mon, 24 Mar 2025 20:28:46 -0700 Subject: [PATCH 1/6] add spark transformation Signed-off-by: HaoXuAI --- sdk/python/feast/batch_feature_view.py | 6 +- .../feast/infra/compute_engines/__init__.py | 0 .../infra/compute_engines/spark/__init__.py | 0 .../infra/compute_engines/spark/config.py | 19 ++++ .../infra/compute_engines/spark/utils.py | 19 ++++ sdk/python/feast/stream_feature_view.py | 5 +- sdk/python/feast/transformation/base.py | 4 +- sdk/python/feast/transformation/factory.py | 1 + sdk/python/feast/transformation/mode.py | 1 + .../transformation/spark_transformation.py | 88 ++++++++++++++- .../test_pandas_transformation.py | 23 ++++ .../test_spark_transformation.py | 105 ++++++++++++++++++ 12 files changed, 260 insertions(+), 11 deletions(-) create mode 100644 sdk/python/feast/infra/compute_engines/__init__.py create mode 100644 sdk/python/feast/infra/compute_engines/spark/__init__.py create mode 100644 sdk/python/feast/infra/compute_engines/spark/config.py create mode 100644 sdk/python/feast/infra/compute_engines/spark/utils.py create mode 100644 sdk/python/tests/unit/transformation/test_pandas_transformation.py create mode 100644 sdk/python/tests/unit/transformation/test_spark_transformation.py diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py index c66af0db18e..0db615f2cd5 100644 --- a/sdk/python/feast/batch_feature_view.py +++ b/sdk/python/feast/batch_feature_view.py @@ -2,7 +2,7 @@ import warnings from datetime import datetime, timedelta from types import FunctionType -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union, Callable, Any import dill @@ -61,7 +61,7 @@ class BatchFeatureView(FeatureView): owner: str timestamp_field: str materialization_intervals: List[Tuple[datetime, datetime]] - udf: Optional[FunctionType] + udf: Optional[Callable[[Any], Any]] udf_string: Optional[str] feature_transformation: Transformation @@ -78,7 +78,7 @@ def __init__( description: str = "", owner: str = "", schema: Optional[List[Field]] = None, - udf: Optional[FunctionType] = None, + udf: Optional[Callable[[Any], Any]], udf_string: Optional[str] = "", feature_transformation: Optional[Transformation] = None, ): diff --git a/sdk/python/feast/infra/compute_engines/__init__.py b/sdk/python/feast/infra/compute_engines/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/feast/infra/compute_engines/spark/__init__.py b/sdk/python/feast/infra/compute_engines/spark/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/feast/infra/compute_engines/spark/config.py b/sdk/python/feast/infra/compute_engines/spark/config.py new file mode 100644 index 00000000000..bef6efaf962 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/spark/config.py @@ -0,0 +1,19 @@ +from typing import Optional, Dict + +from feast.repo_config import FeastConfigBaseModel +from pydantic import StrictStr + + +class SparkComputeConfig(FeastConfigBaseModel): + type: StrictStr = "spark" + """ Spark Compute type selector""" + + spark_conf: Optional[Dict[str, str]] = None + """ Configuration overlay for the spark session """ + # sparksession is not serializable and we dont want to pass it around as an argument + + staging_location: Optional[StrictStr] = None + """ Remote path for batch materialization jobs""" + + region: Optional[StrictStr] = None + """ AWS Region if applicable for s3-based staging locations""" diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py new file mode 100644 index 00000000000..ab1d3fd6455 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/spark/utils.py @@ -0,0 +1,19 @@ +from typing import Optional, Dict + +from pyspark import SparkConf +from pyspark.sql import SparkSession + + +def get_or_create_new_spark_session( + spark_config: Optional[Dict[str, str]] = None +) -> SparkSession: + spark_session = SparkSession.getActiveSession() + if not spark_session: + spark_builder = SparkSession.builder + if spark_config: + spark_builder = spark_builder.config( + conf=SparkConf().setAll([(k, v) for k, v in spark_config.items()]) + ) + + spark_session = spark_builder.getOrCreate() + return spark_session diff --git a/sdk/python/feast/stream_feature_view.py b/sdk/python/feast/stream_feature_view.py index 42802993226..3f4e54937b8 100644 --- a/sdk/python/feast/stream_feature_view.py +++ b/sdk/python/feast/stream_feature_view.py @@ -151,8 +151,9 @@ def get_feature_transformation(self) -> Optional[Transformation]: if self.mode in ( TransformationMode.PANDAS, TransformationMode.PYTHON, - TransformationMode.SPARK, - ) or self.mode in ("pandas", "python", "spark"): + TransformationMode.SPARK_SQL, + TransformationMode.SPARK + ) or self.mode in ("pandas", "python", "spark_sql", "spark"): return Transformation( mode=self.mode, udf=self.udf, udf_string=self.udf_string or "" ) diff --git a/sdk/python/feast/transformation/base.py b/sdk/python/feast/transformation/base.py index 7489e16be97..b02be0a6708 100644 --- a/sdk/python/feast/transformation/base.py +++ b/sdk/python/feast/transformation/base.py @@ -81,7 +81,7 @@ def __init__( description: str = "", owner: str = "", ): - self.mode = mode if isinstance(mode, str) else mode.value + self.mode = mode self.udf = udf self.udf_string = udf_string self.name = name @@ -99,7 +99,7 @@ def to_proto(self) -> Union[UserDefinedFunctionProto, SubstraitTransformationPro def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "Transformation": return Transformation(mode=self.mode, udf=self.udf, udf_string=self.udf_string) - def transform(self, inputs: Any) -> Any: + def transform(self, *inputs: Any) -> Any: raise NotImplementedError def transform_arrow(self, *args, **kwargs) -> Any: diff --git a/sdk/python/feast/transformation/factory.py b/sdk/python/feast/transformation/factory.py index 5097d71353a..50c3c665764 100644 --- a/sdk/python/feast/transformation/factory.py +++ b/sdk/python/feast/transformation/factory.py @@ -5,6 +5,7 @@ "pandas": "feast.transformation.pandas_transformation.PandasTransformation", "substrait": "feast.transformation.substrait_transformation.SubstraitTransformation", "sql": "feast.transformation.sql_transformation.SQLTransformation", + "spark_sql": "feast.transformation.spark_transformation.SparkTransformation", "spark": "feast.transformation.spark_transformation.SparkTransformation", } diff --git a/sdk/python/feast/transformation/mode.py b/sdk/python/feast/transformation/mode.py index 4bd5ddbe7a3..2b453477b3a 100644 --- a/sdk/python/feast/transformation/mode.py +++ b/sdk/python/feast/transformation/mode.py @@ -4,6 +4,7 @@ class TransformationMode(Enum): PYTHON = "python" PANDAS = "pandas" + SPARK_SQL = "spark_sql" SPARK = "spark" SQL = "sql" SUBSTRAIT = "substrait" diff --git a/sdk/python/feast/transformation/spark_transformation.py b/sdk/python/feast/transformation/spark_transformation.py index d288cf58b08..55040d8e114 100644 --- a/sdk/python/feast/transformation/spark_transformation.py +++ b/sdk/python/feast/transformation/spark_transformation.py @@ -1,11 +1,91 @@ -from typing import Any +from typing import Any, Union, Dict, Optional, cast + +import pandas as pd +import pyspark.sql from feast.transformation.base import Transformation +from feast.transformation.mode import TransformationMode +from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session class SparkTransformation(Transformation): - def transform(self, inputs: Any) -> Any: - pass - def infer_features(self, *args, **kwargs) -> Any: + def __new__(cls, + mode: Union[TransformationMode, str], + udf: Any, + udf_string: str, + spark_config: Dict[str, Any] = {}, + name: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + description: str = "", + owner: str = "", + *args, + **kwargs) -> "SparkTransformation": + instance = super(SparkTransformation, cls).__new__( + cls, + mode=mode, + spark_config=spark_config, + udf=udf, + udf_string=udf_string, + name=name, + tags=tags, + description=description, + owner=owner, + ) + return cast(SparkTransformation, instance) + + def __init__(self, + mode: Union[TransformationMode, str], + udf: Any, + udf_string: str, + spark_config: Dict[str, Any] = {}, + name: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + description: str = "", + owner: str = "", + *args, + **kwargs): + super().__init__( + mode=mode, + udf=udf, + name=name, + udf_string=udf_string, + tags=tags, + description=description, + owner=owner, + ) + self.spark_session = get_or_create_new_spark_session(spark_config) + + def transform(self, + *inputs: Union[str, pd.DataFrame], + ) -> pd.DataFrame: + if self.mode == TransformationMode.SPARK_SQL: + return self._transform_spark_sql(*inputs) + else: + return self._transform_spark_udf(*inputs) + + @staticmethod + def _create_temp_view_for_dataframe(df: pyspark.sql.DataFrame, + name: str): + df_temp_view = f"feast_transformation_temp_view_{name}" + df.createOrReplaceTempView(df_temp_view) + return df_temp_view + + def _transform_spark_sql(self, + *inputs: Union[pyspark.sql.DataFrame, str] + ) -> pd.DataFrame: + inputs_str = [ + self._create_temp_view_for_dataframe(v, f"index_{i}") + if isinstance(v, pyspark.sql.DataFrame) else v + for i, v in enumerate(inputs) + ] + return self.spark_session.sql(self.udf(*inputs_str)) + + def _transform_spark_udf(self, + *inputs: Any) -> pd.DataFrame: + return self.udf(*inputs) + + def infer_features(self, + *args, + **kwargs) -> Any: pass diff --git a/sdk/python/tests/unit/transformation/test_pandas_transformation.py b/sdk/python/tests/unit/transformation/test_pandas_transformation.py new file mode 100644 index 00000000000..8a937d76c30 --- /dev/null +++ b/sdk/python/tests/unit/transformation/test_pandas_transformation.py @@ -0,0 +1,23 @@ +from feast.transformation.pandas_transformation import PandasTransformation +import pandas as pd + + +def pandas_udf(features_df: pd.DataFrame) -> pd.DataFrame: + df = pd.DataFrame() + df["output1"] = features_df["feature1"] + df["output2"] = features_df["feature2"] + return df + + +def test_init_pandas_transformation(): + transformation = PandasTransformation( + udf=pandas_udf, + udf_string="udf1" + ) + features_df = pd.DataFrame.from_dict({ + "feature1": [1, 2], + "feature2": [2, 3] + }) + transformed_df = transformation.transform(features_df) + assert transformed_df["output1"].values[0] == 1 + assert transformed_df["output2"].values[1] == 3 diff --git a/sdk/python/tests/unit/transformation/test_spark_transformation.py b/sdk/python/tests/unit/transformation/test_spark_transformation.py new file mode 100644 index 00000000000..1b9565095a8 --- /dev/null +++ b/sdk/python/tests/unit/transformation/test_spark_transformation.py @@ -0,0 +1,105 @@ +import pytest +from pyspark.sql import SparkSession +from pyspark.sql.functions import col, regexp_replace +from unittest.mock import patch +from pyspark.testing.utils import assertDataFrameEqual + +from feast.transformation.spark_transformation import SparkTransformation +from feast.transformation.mode import TransformationMode +from feast.transformation.base import Transformation + + +def get_sample_df(spark): + sample_data = [{"name": "John D.", "age": 30}, + {"name": "Alice G.", "age": 25}, + {"name": "Bob T.", "age": 35}, + {"name": "Eve A.", "age": 28}] + df = spark.createDataFrame(sample_data) + return df + + +def get_expected_df(spark): + expected_data = [{"name": "John D.", "age": 30}, + {"name": "Alice G.", "age": 25}, + {"name": "Bob T.", "age": 35}, + {"name": "Eve A.", "age": 28}] + + expected_df = spark.createDataFrame(expected_data) + return expected_df + + +def remove_extra_spaces(df, + column_name): + df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), "\\s+", " ")) + return df_transformed + + +def remove_extra_spaces_sql(df, + column_name): + sql = f""" + SELECT + age, + regexp_replace({column_name}, '\\s+', ' ') as {column_name} + FROM {df} + """ + return sql + + +@pytest.fixture +def spark_fixture(): + spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate() + yield spark + + +@patch( + "feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session" +) +def test_spark_transformation(spark_fixture): + spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate() + df = get_sample_df(spark) + + spark_transformation = Transformation( + mode=TransformationMode.SPARK, + udf=remove_extra_spaces, + udf_string="remove extra spaces", + ) + + transformed_df = spark_transformation.transform(df, "name") + expected_df = get_expected_df(spark) + assertDataFrameEqual(transformed_df, expected_df) + + +@patch( + "feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session" +) +def test_spark_transformation_init_transformation(spark_fixture): + spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate() + df = get_sample_df(spark) + + spark_transformation = SparkTransformation( + mode=TransformationMode.SPARK, + udf=remove_extra_spaces, + udf_string="remove extra spaces", + ) + + transformed_df = spark_transformation.transform(df, "name") + expected_df = get_expected_df(spark) + assertDataFrameEqual(transformed_df, expected_df) + + +@patch( + "feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session" +) +def test_spark_transformation_sql(spark_fixture): + spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate() + df = get_sample_df(spark) + + spark_transformation = SparkTransformation( + mode=TransformationMode.SPARK_SQL, + udf=remove_extra_spaces_sql, + udf_string="remove extra spaces", + ) + + transformed_df = spark_transformation.transform(df, "name") + expected_df = get_expected_df(spark) + assertDataFrameEqual(transformed_df, expected_df) From da67d7d8fed42ec78cadc10c7bb70d03feca21c2 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Mon, 24 Mar 2025 20:29:42 -0700 Subject: [PATCH 2/6] fix lint Signed-off-by: HaoXuAI --- sdk/python/feast/batch_feature_view.py | 3 +- .../infra/compute_engines/spark/config.py | 5 +- .../infra/compute_engines/spark/utils.py | 4 +- sdk/python/feast/stream_feature_view.py | 2 +- .../transformation/spark_transformation.py | 79 ++++++++++--------- .../test_pandas_transformation.py | 13 +-- .../test_spark_transformation.py | 49 ++++++------ 7 files changed, 75 insertions(+), 80 deletions(-) diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py index 0db615f2cd5..57d5aa1b07e 100644 --- a/sdk/python/feast/batch_feature_view.py +++ b/sdk/python/feast/batch_feature_view.py @@ -1,8 +1,7 @@ import functools import warnings from datetime import datetime, timedelta -from types import FunctionType -from typing import Dict, List, Optional, Tuple, Union, Callable, Any +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import dill diff --git a/sdk/python/feast/infra/compute_engines/spark/config.py b/sdk/python/feast/infra/compute_engines/spark/config.py index bef6efaf962..070cf204dce 100644 --- a/sdk/python/feast/infra/compute_engines/spark/config.py +++ b/sdk/python/feast/infra/compute_engines/spark/config.py @@ -1,8 +1,9 @@ -from typing import Optional, Dict +from typing import Dict, Optional -from feast.repo_config import FeastConfigBaseModel from pydantic import StrictStr +from feast.repo_config import FeastConfigBaseModel + class SparkComputeConfig(FeastConfigBaseModel): type: StrictStr = "spark" diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py index ab1d3fd6455..262876f9dbb 100644 --- a/sdk/python/feast/infra/compute_engines/spark/utils.py +++ b/sdk/python/feast/infra/compute_engines/spark/utils.py @@ -1,11 +1,11 @@ -from typing import Optional, Dict +from typing import Dict, Optional from pyspark import SparkConf from pyspark.sql import SparkSession def get_or_create_new_spark_session( - spark_config: Optional[Dict[str, str]] = None + spark_config: Optional[Dict[str, str]] = None, ) -> SparkSession: spark_session = SparkSession.getActiveSession() if not spark_session: diff --git a/sdk/python/feast/stream_feature_view.py b/sdk/python/feast/stream_feature_view.py index 3f4e54937b8..2f134001a5a 100644 --- a/sdk/python/feast/stream_feature_view.py +++ b/sdk/python/feast/stream_feature_view.py @@ -152,7 +152,7 @@ def get_feature_transformation(self) -> Optional[Transformation]: TransformationMode.PANDAS, TransformationMode.PYTHON, TransformationMode.SPARK_SQL, - TransformationMode.SPARK + TransformationMode.SPARK, ) or self.mode in ("pandas", "python", "spark_sql", "spark"): return Transformation( mode=self.mode, udf=self.udf, udf_string=self.udf_string or "" diff --git a/sdk/python/feast/transformation/spark_transformation.py b/sdk/python/feast/transformation/spark_transformation.py index 55040d8e114..ec94c2fe716 100644 --- a/sdk/python/feast/transformation/spark_transformation.py +++ b/sdk/python/feast/transformation/spark_transformation.py @@ -1,26 +1,27 @@ -from typing import Any, Union, Dict, Optional, cast +from typing import Any, Dict, Optional, Union, cast import pandas as pd import pyspark.sql +from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session from feast.transformation.base import Transformation from feast.transformation.mode import TransformationMode -from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session class SparkTransformation(Transformation): - - def __new__(cls, - mode: Union[TransformationMode, str], - udf: Any, - udf_string: str, - spark_config: Dict[str, Any] = {}, - name: Optional[str] = None, - tags: Optional[Dict[str, str]] = None, - description: str = "", - owner: str = "", - *args, - **kwargs) -> "SparkTransformation": + def __new__( + cls, + mode: Union[TransformationMode, str], + udf: Any, + udf_string: str, + spark_config: Dict[str, Any] = {}, + name: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + description: str = "", + owner: str = "", + *args, + **kwargs, + ) -> "SparkTransformation": instance = super(SparkTransformation, cls).__new__( cls, mode=mode, @@ -34,17 +35,19 @@ def __new__(cls, ) return cast(SparkTransformation, instance) - def __init__(self, - mode: Union[TransformationMode, str], - udf: Any, - udf_string: str, - spark_config: Dict[str, Any] = {}, - name: Optional[str] = None, - tags: Optional[Dict[str, str]] = None, - description: str = "", - owner: str = "", - *args, - **kwargs): + def __init__( + self, + mode: Union[TransformationMode, str], + udf: Any, + udf_string: str, + spark_config: Dict[str, Any] = {}, + name: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + description: str = "", + owner: str = "", + *args, + **kwargs, + ): super().__init__( mode=mode, udf=udf, @@ -56,36 +59,34 @@ def __init__(self, ) self.spark_session = get_or_create_new_spark_session(spark_config) - def transform(self, - *inputs: Union[str, pd.DataFrame], - ) -> pd.DataFrame: + def transform( + self, + *inputs: Union[str, pd.DataFrame], + ) -> pd.DataFrame: if self.mode == TransformationMode.SPARK_SQL: return self._transform_spark_sql(*inputs) else: return self._transform_spark_udf(*inputs) @staticmethod - def _create_temp_view_for_dataframe(df: pyspark.sql.DataFrame, - name: str): + def _create_temp_view_for_dataframe(df: pyspark.sql.DataFrame, name: str): df_temp_view = f"feast_transformation_temp_view_{name}" df.createOrReplaceTempView(df_temp_view) return df_temp_view - def _transform_spark_sql(self, - *inputs: Union[pyspark.sql.DataFrame, str] - ) -> pd.DataFrame: + def _transform_spark_sql( + self, *inputs: Union[pyspark.sql.DataFrame, str] + ) -> pd.DataFrame: inputs_str = [ self._create_temp_view_for_dataframe(v, f"index_{i}") - if isinstance(v, pyspark.sql.DataFrame) else v + if isinstance(v, pyspark.sql.DataFrame) + else v for i, v in enumerate(inputs) ] return self.spark_session.sql(self.udf(*inputs_str)) - def _transform_spark_udf(self, - *inputs: Any) -> pd.DataFrame: + def _transform_spark_udf(self, *inputs: Any) -> pd.DataFrame: return self.udf(*inputs) - def infer_features(self, - *args, - **kwargs) -> Any: + def infer_features(self, *args, **kwargs) -> Any: pass diff --git a/sdk/python/tests/unit/transformation/test_pandas_transformation.py b/sdk/python/tests/unit/transformation/test_pandas_transformation.py index 8a937d76c30..d20204ceb93 100644 --- a/sdk/python/tests/unit/transformation/test_pandas_transformation.py +++ b/sdk/python/tests/unit/transformation/test_pandas_transformation.py @@ -1,6 +1,7 @@ -from feast.transformation.pandas_transformation import PandasTransformation import pandas as pd +from feast.transformation.pandas_transformation import PandasTransformation + def pandas_udf(features_df: pd.DataFrame) -> pd.DataFrame: df = pd.DataFrame() @@ -10,14 +11,8 @@ def pandas_udf(features_df: pd.DataFrame) -> pd.DataFrame: def test_init_pandas_transformation(): - transformation = PandasTransformation( - udf=pandas_udf, - udf_string="udf1" - ) - features_df = pd.DataFrame.from_dict({ - "feature1": [1, 2], - "feature2": [2, 3] - }) + transformation = PandasTransformation(udf=pandas_udf, udf_string="udf1") + features_df = pd.DataFrame.from_dict({"feature1": [1, 2], "feature2": [2, 3]}) transformed_df = transformation.transform(features_df) assert transformed_df["output1"].values[0] == 1 assert transformed_df["output2"].values[1] == 3 diff --git a/sdk/python/tests/unit/transformation/test_spark_transformation.py b/sdk/python/tests/unit/transformation/test_spark_transformation.py index 1b9565095a8..26426457c51 100644 --- a/sdk/python/tests/unit/transformation/test_spark_transformation.py +++ b/sdk/python/tests/unit/transformation/test_spark_transformation.py @@ -1,41 +1,46 @@ +from unittest.mock import patch + import pytest from pyspark.sql import SparkSession from pyspark.sql.functions import col, regexp_replace -from unittest.mock import patch from pyspark.testing.utils import assertDataFrameEqual -from feast.transformation.spark_transformation import SparkTransformation -from feast.transformation.mode import TransformationMode from feast.transformation.base import Transformation +from feast.transformation.mode import TransformationMode +from feast.transformation.spark_transformation import SparkTransformation def get_sample_df(spark): - sample_data = [{"name": "John D.", "age": 30}, - {"name": "Alice G.", "age": 25}, - {"name": "Bob T.", "age": 35}, - {"name": "Eve A.", "age": 28}] + sample_data = [ + {"name": "John D.", "age": 30}, + {"name": "Alice G.", "age": 25}, + {"name": "Bob T.", "age": 35}, + {"name": "Eve A.", "age": 28}, + ] df = spark.createDataFrame(sample_data) return df def get_expected_df(spark): - expected_data = [{"name": "John D.", "age": 30}, - {"name": "Alice G.", "age": 25}, - {"name": "Bob T.", "age": 35}, - {"name": "Eve A.", "age": 28}] + expected_data = [ + {"name": "John D.", "age": 30}, + {"name": "Alice G.", "age": 25}, + {"name": "Bob T.", "age": 35}, + {"name": "Eve A.", "age": 28}, + ] expected_df = spark.createDataFrame(expected_data) return expected_df -def remove_extra_spaces(df, - column_name): - df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), "\\s+", " ")) +def remove_extra_spaces(df, column_name): + df_transformed = df.withColumn( + column_name, regexp_replace(col(column_name), "\\s+", " ") + ) return df_transformed -def remove_extra_spaces_sql(df, - column_name): +def remove_extra_spaces_sql(df, column_name): sql = f""" SELECT age, @@ -51,9 +56,7 @@ def spark_fixture(): yield spark -@patch( - "feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session" -) +@patch("feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session") def test_spark_transformation(spark_fixture): spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate() df = get_sample_df(spark) @@ -69,9 +72,7 @@ def test_spark_transformation(spark_fixture): assertDataFrameEqual(transformed_df, expected_df) -@patch( - "feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session" -) +@patch("feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session") def test_spark_transformation_init_transformation(spark_fixture): spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate() df = get_sample_df(spark) @@ -87,9 +88,7 @@ def test_spark_transformation_init_transformation(spark_fixture): assertDataFrameEqual(transformed_df, expected_df) -@patch( - "feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session" -) +@patch("feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session") def test_spark_transformation_sql(spark_fixture): spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate() df = get_sample_df(spark) From 0992fb284586b9417e95c8e04e0c3d0b56aa8503 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Mon, 24 Mar 2025 20:38:43 -0700 Subject: [PATCH 3/6] fix lint issue Signed-off-by: HaoXuAI --- .../tests/utils/ssl_certifcates_util.py | 34 +++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/sdk/python/tests/utils/ssl_certifcates_util.py b/sdk/python/tests/utils/ssl_certifcates_util.py index 53a56e04f3d..53b9df3973c 100644 --- a/sdk/python/tests/utils/ssl_certifcates_util.py +++ b/sdk/python/tests/utils/ssl_certifcates_util.py @@ -8,7 +8,7 @@ from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric import dh, dsa, ec, rsa from cryptography.x509 import load_pem_x509_certificate from cryptography.x509.oid import NameOID @@ -126,13 +126,33 @@ def create_ca_trust_store( private_key = serialization.load_pem_private_key( private_key_data, password=None, backend=default_backend() ) - # Check the public/private key match - if ( - private_key.public_key().public_numbers() - != public_cert.public_key().public_numbers() + private_pub = private_key.public_key() + cert_pub = public_cert.public_key() + + if isinstance( + private_pub, + ( + rsa.RSAPublicKey, + dsa.DSAPublicKey, + ec.EllipticCurvePublicKey, + dh.DHPublicKey, + ), + ) and isinstance( + cert_pub, + ( + rsa.RSAPublicKey, + dsa.DSAPublicKey, + ec.EllipticCurvePublicKey, + dh.DHPublicKey, + ), ): - raise ValueError( - "Public certificate does not match the private key." + if private_pub.public_numbers() != cert_pub.public_numbers(): + raise ValueError( + "Public certificate does not match the private key." + ) + else: + logger.warning( + "Key type does not support public_numbers(). Skipping strict public key match." ) # Step 4: Add the public certificate to the new trust store From 2577fd31aab1b3a5d36e13377c931c831eb3f38f Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Mon, 24 Mar 2025 21:07:00 -0700 Subject: [PATCH 4/6] fix unit test Signed-off-by: HaoXuAI --- .../tests/unit/transformation/test_spark_transformation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/python/tests/unit/transformation/test_spark_transformation.py b/sdk/python/tests/unit/transformation/test_spark_transformation.py index 26426457c51..8ee9d22bf28 100644 --- a/sdk/python/tests/unit/transformation/test_spark_transformation.py +++ b/sdk/python/tests/unit/transformation/test_spark_transformation.py @@ -44,7 +44,7 @@ def remove_extra_spaces_sql(df, column_name): sql = f""" SELECT age, - regexp_replace({column_name}, '\\s+', ' ') as {column_name} + regexp_replace({column_name}, '\\\\s+', ' ') as {column_name} FROM {df} """ return sql @@ -96,7 +96,7 @@ def test_spark_transformation_sql(spark_fixture): spark_transformation = SparkTransformation( mode=TransformationMode.SPARK_SQL, udf=remove_extra_spaces_sql, - udf_string="remove extra spaces", + udf_string="remove extra spaces sql", ) transformed_df = spark_transformation.transform(df, "name") From a92da11141249ff9c5fe65ce944b5f01189de127 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Mon, 24 Mar 2025 21:49:12 -0700 Subject: [PATCH 5/6] add doc Signed-off-by: HaoXuAI --- .../transformation/spark_transformation.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/sdk/python/feast/transformation/spark_transformation.py b/sdk/python/feast/transformation/spark_transformation.py index ec94c2fe716..3f1a1371c3e 100644 --- a/sdk/python/feast/transformation/spark_transformation.py +++ b/sdk/python/feast/transformation/spark_transformation.py @@ -9,6 +9,26 @@ class SparkTransformation(Transformation): + """ + SparkTransformation can be used to define a transformation using a Spark UDF or SQL query. + The current spark session will be used or a new one will be created if not available. + E.g.: + spark_transformation = SparkTransformation( + mode=TransformationMode.SPARK, + udf=remove_extra_spaces, + udf_string="remove extra spaces", + ) + OR + spark_transformation = Transformation( + mode=TransformationMode.SPARK_SQL, + udf=remove_extra_spaces_sql, + udf_string="remove extra spaces sql", + ) + OR + @transformation(mode=TransformationMode.SPARK) + def remove_extra_spaces_udf(df: pd.DataFrame) -> pd.DataFrame: + return df.assign(name=df['name'].str.replace('\s+', ' ')) + """ def __new__( cls, mode: Union[TransformationMode, str], @@ -22,6 +42,18 @@ def __new__( *args, **kwargs, ) -> "SparkTransformation": + """ + Creates a SparkTransformation + Args: + mode: (required) The mode of the transformation. Choose one from TransformationMode.SPARK or TransformationMode.SPARK_SQL. + udf: (required) The user-defined transformation function. + udf_string: (required) The string representation of the udf. The dill get source doesn't + spark_config: (optional) The spark configuration to use for the transformation. + name: (optional) The name of the transformation. + tags: (optional) Metadata tags for the transformation. + description: (optional) A description of the transformation. + owner: (optional) The owner of the transformation. + """ instance = super(SparkTransformation, cls).__new__( cls, mode=mode, From d8b77323e9e2c4f8f9ab1fa7e1f5d4f678ac84f0 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Mon, 24 Mar 2025 22:51:46 -0700 Subject: [PATCH 6/6] add doc Signed-off-by: HaoXuAI --- sdk/python/feast/transformation/spark_transformation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdk/python/feast/transformation/spark_transformation.py b/sdk/python/feast/transformation/spark_transformation.py index 3f1a1371c3e..84d4c010c17 100644 --- a/sdk/python/feast/transformation/spark_transformation.py +++ b/sdk/python/feast/transformation/spark_transformation.py @@ -9,7 +9,7 @@ class SparkTransformation(Transformation): - """ + r""" SparkTransformation can be used to define a transformation using a Spark UDF or SQL query. The current spark session will be used or a new one will be created if not available. E.g.: @@ -29,6 +29,7 @@ class SparkTransformation(Transformation): def remove_extra_spaces_udf(df: pd.DataFrame) -> pd.DataFrame: return df.assign(name=df['name'].str.replace('\s+', ' ')) """ + def __new__( cls, mode: Union[TransformationMode, str],