1+ from types import FunctionType
12from typing import Any , Dict , List
23
4+ import dill
35import pandas as pd
46import pyarrow
57import pyarrow .substrait as substrait # type: ignore # noqa
1618
1719
1820class SubstraitTransformation :
19- def __init__ (self , substrait_plan : bytes ):
21+ def __init__ (self , substrait_plan : bytes , ibis_function : FunctionType ):
2022 """
2123 Creates an SubstraitTransformation object.
2224
2325 Args:
2426 substrait_plan: The user-provided substrait plan.
27+ ibis_function: The user-provided ibis function.
2528 """
2629 self .substrait_plan = substrait_plan
30+ self .ibis_function = ibis_function
2731
2832 def transform (self , df : pd .DataFrame ) -> pd .DataFrame :
2933 def table_provider (names , schema : pyarrow .Schema ):
@@ -34,13 +38,22 @@ def table_provider(names, schema: pyarrow.Schema):
3438 ).read_all ()
3539 return table .to_pandas ()
3640
37- def transform_arrow (self , pa_table : pyarrow .Table ) -> pyarrow .Table :
41+ def transform_ibis (self , table ):
42+ return self .ibis_function (table )
43+
44+ def transform_arrow (
45+ self , pa_table : pyarrow .Table , features : List [Field ] = []
46+ ) -> pyarrow .Table :
3847 def table_provider (names , schema : pyarrow .Schema ):
3948 return pa_table .select (schema .names )
4049
4150 table : pyarrow .Table = pyarrow .substrait .run_query (
4251 self .substrait_plan , table_provider = table_provider
4352 ).read_all ()
53+
54+ if features :
55+ table = table .select ([f .name for f in features ])
56+
4457 return table
4558
4659 def infer_features (self , random_input : Dict [str , List [Any ]]) -> List [Field ]:
@@ -55,6 +68,7 @@ def infer_features(self, random_input: Dict[str, List[Any]]) -> List[Field]:
5568 ),
5669 )
5770 for f , dt in zip (output_df .columns , output_df .dtypes )
71+ if f not in random_input
5872 ]
5973
6074 def __eq__ (self , other ):
@@ -66,18 +80,26 @@ def __eq__(self, other):
6680 if not super ().__eq__ (other ):
6781 return False
6882
69- return self .substrait_plan == other .substrait_plan
83+ return (
84+ self .substrait_plan == other .substrait_plan
85+ and self .ibis_function .__code__ .co_code
86+ == other .ibis_function .__code__ .co_code
87+ )
7088
7189 def to_proto (self ) -> SubstraitTransformationProto :
72- return SubstraitTransformationProto (substrait_plan = self .substrait_plan )
90+ return SubstraitTransformationProto (
91+ substrait_plan = self .substrait_plan ,
92+ ibis_function = dill .dumps (self .ibis_function , recurse = True ),
93+ )
7394
7495 @classmethod
7596 def from_proto (
7697 cls ,
7798 substrait_transformation_proto : SubstraitTransformationProto ,
7899 ):
79100 return SubstraitTransformation (
80- substrait_plan = substrait_transformation_proto .substrait_plan
101+ substrait_plan = substrait_transformation_proto .substrait_plan ,
102+ ibis_function = dill .loads (substrait_transformation_proto .ibis_function ),
81103 )
82104
83105 @classmethod
@@ -91,7 +113,7 @@ def from_ibis(cls, user_function, sources):
91113 input_fields = []
92114
93115 for s in sources :
94- fields = s .projection .features if isinstance (s , FeatureView ) else s .features
116+ fields = s .projection .features if isinstance (s , FeatureView ) else s .schema
95117
96118 input_fields .extend (
97119 [
@@ -108,5 +130,6 @@ def from_ibis(cls, user_function, sources):
108130 expr = user_function (ibis .table (input_fields , "t" ))
109131
110132 return SubstraitTransformation (
111- substrait_plan = compiler .compile (expr ).SerializeToString ()
133+ substrait_plan = compiler .compile (expr ).SerializeToString (),
134+ ibis_function = user_function ,
112135 )
0 commit comments