Skip to content

Commit be3d85c

Browse files
authored
feat: Spark Transformation (feast-dev#5185)
* add spark transformation Signed-off-by: HaoXuAI <sduxuhao@gmail.com> * fix lint Signed-off-by: HaoXuAI <sduxuhao@gmail.com> * fix lint issue Signed-off-by: HaoXuAI <sduxuhao@gmail.com> * fix unit test Signed-off-by: HaoXuAI <sduxuhao@gmail.com> * add doc Signed-off-by: HaoXuAI <sduxuhao@gmail.com> * add doc Signed-off-by: HaoXuAI <sduxuhao@gmail.com> --------- Signed-off-by: HaoXuAI <sduxuhao@gmail.com>
1 parent caa7c61 commit be3d85c

13 files changed

Lines changed: 314 additions & 17 deletions

File tree

sdk/python/feast/batch_feature_view.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import functools
22
import warnings
33
from datetime import datetime, timedelta
4-
from types import FunctionType
5-
from typing import Dict, List, Optional, Tuple, Union
4+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
65

76
import dill
87

@@ -61,7 +60,7 @@ class BatchFeatureView(FeatureView):
6160
owner: str
6261
timestamp_field: str
6362
materialization_intervals: List[Tuple[datetime, datetime]]
64-
udf: Optional[FunctionType]
63+
udf: Optional[Callable[[Any], Any]]
6564
udf_string: Optional[str]
6665
feature_transformation: Transformation
6766

@@ -78,7 +77,7 @@ def __init__(
7877
description: str = "",
7978
owner: str = "",
8079
schema: Optional[List[Field]] = None,
81-
udf: Optional[FunctionType] = None,
80+
udf: Optional[Callable[[Any], Any]],
8281
udf_string: Optional[str] = "",
8382
feature_transformation: Optional[Transformation] = None,
8483
):

sdk/python/feast/infra/compute_engines/__init__.py

Whitespace-only changes.

sdk/python/feast/infra/compute_engines/spark/__init__.py

Whitespace-only changes.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from typing import Dict, Optional
2+
3+
from pydantic import StrictStr
4+
5+
from feast.repo_config import FeastConfigBaseModel
6+
7+
8+
class SparkComputeConfig(FeastConfigBaseModel):
    """Feast config block selecting and configuring the Spark compute engine.

    Attribute docstrings below follow the repo's existing config-model style
    (pydantic may surface them; do not convert to plain comments).
    """

    type: StrictStr = "spark"
    """ Spark Compute type selector"""

    spark_conf: Optional[Dict[str, str]] = None
    """ Configuration overlay for the spark session """
    # sparksession is not serializable and we dont want to pass it around as an argument

    staging_location: Optional[StrictStr] = None
    """ Remote path for batch materialization jobs"""

    region: Optional[StrictStr] = None
    """ AWS Region if applicable for s3-based staging locations"""
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import Dict, Optional
2+
3+
from pyspark import SparkConf
4+
from pyspark.sql import SparkSession
5+
6+
7+
def get_or_create_new_spark_session(
    spark_config: Optional[Dict[str, str]] = None,
) -> SparkSession:
    """Return the active SparkSession, creating one if none exists.

    ``spark_config`` entries are applied only when a brand-new session is
    built; an already-active session is returned unchanged (its config is
    not overlaid).
    """
    active = SparkSession.getActiveSession()
    if active:
        return active

    builder = SparkSession.builder
    if spark_config:
        conf = SparkConf().setAll(list(spark_config.items()))
        builder = builder.config(conf=conf)
    return builder.getOrCreate()

sdk/python/feast/stream_feature_view.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,9 @@ def get_feature_transformation(self) -> Optional[Transformation]:
151151
if self.mode in (
152152
TransformationMode.PANDAS,
153153
TransformationMode.PYTHON,
154+
TransformationMode.SPARK_SQL,
154155
TransformationMode.SPARK,
155-
) or self.mode in ("pandas", "python", "spark"):
156+
) or self.mode in ("pandas", "python", "spark_sql", "spark"):
156157
return Transformation(
157158
mode=self.mode, udf=self.udf, udf_string=self.udf_string or ""
158159
)

sdk/python/feast/transformation/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(
8181
description: str = "",
8282
owner: str = "",
8383
):
84-
self.mode = mode if isinstance(mode, str) else mode.value
84+
self.mode = mode
8585
self.udf = udf
8686
self.udf_string = udf_string
8787
self.name = name
@@ -99,7 +99,7 @@ def to_proto(self) -> Union[UserDefinedFunctionProto, SubstraitTransformationPro
9999
def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "Transformation":
100100
return Transformation(mode=self.mode, udf=self.udf, udf_string=self.udf_string)
101101

102-
def transform(self, inputs: Any) -> Any:
102+
def transform(self, *inputs: Any) -> Any:
103103
raise NotImplementedError
104104

105105
def transform_arrow(self, *args, **kwargs) -> Any:

sdk/python/feast/transformation/factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"pandas": "feast.transformation.pandas_transformation.PandasTransformation",
66
"substrait": "feast.transformation.substrait_transformation.SubstraitTransformation",
77
"sql": "feast.transformation.sql_transformation.SQLTransformation",
8+
"spark_sql": "feast.transformation.spark_transformation.SparkTransformation",
89
"spark": "feast.transformation.spark_transformation.SparkTransformation",
910
}
1011

sdk/python/feast/transformation/mode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
class TransformationMode(Enum):
    """Execution modes a feature Transformation can be declared with.

    The string values match the keys used by the transformation factory
    to resolve the concrete Transformation subclass.
    """

    PYTHON = "python"
    PANDAS = "pandas"
    SPARK_SQL = "spark_sql"
    SPARK = "spark"
    SQL = "sql"
    SUBSTRAIT = "substrait"
Lines changed: 117 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,125 @@
1-
from typing import Any
1+
from typing import Any, Dict, Optional, Union, cast
22

3+
import pandas as pd
4+
import pyspark.sql
5+
6+
from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session
37
from feast.transformation.base import Transformation
8+
from feast.transformation.mode import TransformationMode
49

510

611
class SparkTransformation(Transformation):
    r"""
    SparkTransformation can be used to define a transformation using a Spark UDF or SQL query.
    The current spark session will be used or a new one will be created if not available.
    E.g.:
        spark_transformation = SparkTransformation(
            mode=TransformationMode.SPARK,
            udf=remove_extra_spaces,
            udf_string="remove extra spaces",
        )
    OR
        spark_transformation = Transformation(
            mode=TransformationMode.SPARK_SQL,
            udf=remove_extra_spaces_sql,
            udf_string="remove extra spaces sql",
        )
    OR
        @transformation(mode=TransformationMode.SPARK)
        def remove_extra_spaces_udf(df: pd.DataFrame) -> pd.DataFrame:
            return df.assign(name=df['name'].str.replace('\s+', ' '))
    """

    def __new__(
        cls,
        mode: Union[TransformationMode, str],
        udf: Any,
        udf_string: str,
        spark_config: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        description: str = "",
        owner: str = "",
        *args,
        **kwargs,
    ) -> "SparkTransformation":
        """
        Creates a SparkTransformation.

        Args:
            mode: (required) The mode of the transformation. Choose one from
                TransformationMode.SPARK or TransformationMode.SPARK_SQL.
            udf: (required) The user-defined transformation function.
            udf_string: (required) The string representation of the udf,
                supplied explicitly because dill's source extraction is not
                always reliable.
            spark_config: (optional) The spark configuration to use for the
                transformation. Defaults to an empty config.
            name: (optional) The name of the transformation.
            tags: (optional) Metadata tags for the transformation.
            description: (optional) A description of the transformation.
            owner: (optional) The owner of the transformation.
        """
        # None (not a mutable `{}` literal) is the declared default to avoid
        # the shared-mutable-default pitfall; normalize here so the parent
        # class receives a dict exactly as before.
        spark_config = spark_config or {}
        instance = super(SparkTransformation, cls).__new__(
            cls,
            mode=mode,
            spark_config=spark_config,
            udf=udf,
            udf_string=udf_string,
            name=name,
            tags=tags,
            description=description,
            owner=owner,
        )
        return cast(SparkTransformation, instance)

    def __init__(
        self,
        mode: Union[TransformationMode, str],
        udf: Any,
        udf_string: str,
        spark_config: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        description: str = "",
        owner: str = "",
        *args,
        **kwargs,
    ):
        super().__init__(
            mode=mode,
            udf=udf,
            name=name,
            udf_string=udf_string,
            tags=tags,
            description=description,
            owner=owner,
        )
        # Reuses the active session when one exists; spark_config (possibly
        # None, treated as empty) only applies when a new session is created.
        self.spark_session = get_or_create_new_spark_session(spark_config)

    def transform(
        self,
        *inputs: Union[str, pd.DataFrame],
    ) -> pd.DataFrame:
        # NOTE(review): the SPARK_SQL branch returns the result of
        # `spark_session.sql(...)`, which is a pyspark.sql.DataFrame rather
        # than the annotated pd.DataFrame — confirm intended return type.
        if self.mode == TransformationMode.SPARK_SQL:
            return self._transform_spark_sql(*inputs)
        else:
            return self._transform_spark_udf(*inputs)

    @staticmethod
    def _create_temp_view_for_dataframe(df: pyspark.sql.DataFrame, name: str):
        # Registers the dataframe under a prefixed temp-view name so the
        # user's SQL can reference it; replaces any previous view of the
        # same name.
        df_temp_view = f"feast_transformation_temp_view_{name}"
        df.createOrReplaceTempView(df_temp_view)
        return df_temp_view

    def _transform_spark_sql(
        self, *inputs: Union[pyspark.sql.DataFrame, str]
    ) -> pd.DataFrame:
        # DataFrame inputs become temp-view names; string inputs are passed
        # through unchanged. The udf is expected to build a SQL string from
        # those names.
        inputs_str = [
            self._create_temp_view_for_dataframe(v, f"index_{i}")
            if isinstance(v, pyspark.sql.DataFrame)
            else v
            for i, v in enumerate(inputs)
        ]
        return self.spark_session.sql(self.udf(*inputs_str))

    def _transform_spark_udf(self, *inputs: Any) -> pd.DataFrame:
        # Delegates directly to the user udf; the return type is whatever
        # the udf produces.
        return self.udf(*inputs)

    def infer_features(self, *args, **kwargs) -> Any:
        # Intentionally a no-op: feature inference is not implemented for
        # Spark transformations.
        pass

0 commit comments

Comments
 (0)