Skip to content

Commit e7482af

Browse files
davidheryanto and zhilingc authored
Update Python SDK so FeatureSet can import Schema from Tensorflow metadata (feast-dev#450)
* Add skeleton for update/get schema in FeatureSet
* Add update_schema method to FeatureSet
  - Update Field, Feature and Entity classes with fields from presence_constraints, shape_type and domain_info
* Update error message when domain ref is missing from top level schema
* Add more assertions in test_update_schema before updating schema
* Fix conflicting versions in package requirements
* Add export_schema method to export schema from FeatureSet
* Add exporting of Tensorflow metadata schema from FeatureSet.
  - Update documentation for properties in Field
  - Deduplication refactoring in FeatureSet
* Remove changes to mypy generated code
* Revert changes to package versions in requirements-ci and setup.py. They are not necessary for now, and reverting avoids unexpected breaking changes.
* Remove 'schema' param in 'from_proto' method in Entity and Feature. In the import_tfx_schema method, the domain info is first made inline, so there is no need to have schema level domain info when updating Feast Entity and Feature. Also added documentation to setter property methods in Field.py
* Fix rebase errors, apply black
* Remove unnecessary imports

Co-authored-by: zhilingc <zhiling.c@go-jek.com>
1 parent e4f8fe9 commit e7482af

File tree

13 files changed

+917
-3450
lines changed

13 files changed

+917
-3450
lines changed

sdk/python/feast/entity.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,26 @@ def to_proto(self) -> EntityProto:
2929
Returns EntitySpec object
3030
"""
3131
value_type = ValueTypeProto.ValueType.Enum.Value(self.dtype.name)
32-
return EntityProto(name=self.name, value_type=value_type)
32+
return EntityProto(
33+
name=self.name,
34+
value_type=value_type,
35+
presence=self.presence,
36+
group_presence=self.group_presence,
37+
shape=self.shape,
38+
value_count=self.value_count,
39+
domain=self.domain,
40+
int_domain=self.int_domain,
41+
float_domain=self.float_domain,
42+
string_domain=self.string_domain,
43+
bool_domain=self.bool_domain,
44+
struct_domain=self.struct_domain,
45+
natural_language_domain=self.natural_language_domain,
46+
image_domain=self.image_domain,
47+
mid_domain=self.mid_domain,
48+
url_domain=self.url_domain,
49+
time_domain=self.time_domain,
50+
time_of_day_domain=self.time_of_day_domain,
51+
)
3352

3453
@classmethod
3554
def from_proto(cls, entity_proto: EntityProto):
@@ -42,4 +61,8 @@ def from_proto(cls, entity_proto: EntityProto):
4261
Returns:
4362
Entity object
4463
"""
45-
return cls(name=entity_proto.name, dtype=ValueType(entity_proto.value_type))
64+
entity = cls(name=entity_proto.name, dtype=ValueType(entity_proto.value_type))
65+
entity.update_presence_constraints(entity_proto)
66+
entity.update_shape_type(entity_proto)
67+
entity.update_domain_info(entity_proto)
68+
return entity

sdk/python/feast/feature.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,41 @@ class Feature(Field):
2424
def to_proto(self) -> FeatureProto:
    """
    Build the Protocol Buffer representation of this Feature.

    Returns:
        FeatureProto populated with the name, value type, presence
        constraints, shape type and domain info of this Feature
    """
    value_type = ValueTypeProto.ValueType.Enum.Value(self.dtype.name)

    # Attributes that carry the same name on the Feature and on FeatureProto,
    # so they can be forwarded as keyword arguments unchanged.
    shared_attributes = (
        "presence",
        "group_presence",
        "shape",
        "value_count",
        "domain",
        "int_domain",
        "float_domain",
        "string_domain",
        "bool_domain",
        "struct_domain",
        "natural_language_domain",
        "image_domain",
        "mid_domain",
        "url_domain",
        "time_domain",
        "time_of_day_domain",
    )
    proto_kwargs = {attr: getattr(self, attr) for attr in shared_attributes}

    return FeatureProto(name=self.name, value_type=value_type, **proto_kwargs)
2847

2948
@classmethod
def from_proto(cls, feature_proto: FeatureProto):
    """
    Create a Feature from its Protocol Buffer representation.

    Args:
        feature_proto: FeatureSpec protobuf object

    Returns:
        Feature object
    """
    instance = cls(
        name=feature_proto.name, dtype=ValueType(feature_proto.value_type)
    )

    # Copy the schema-related settings from the proto, in the same order
    # as they are grouped on the Field: presence, shape, then domain.
    updaters = (
        instance.update_presence_constraints,
        instance.update_shape_type,
        instance.update_domain_info,
    )
    for update in updaters:
        update(feature_proto)

    return instance

sdk/python/feast/feature_set.py

Lines changed: 126 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,20 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
14+
import warnings
1615
from collections import OrderedDict
17-
from typing import Dict, List, Optional
16+
from typing import Dict
17+
from typing import List, Optional
1818

1919
import pandas as pd
2020
import pyarrow as pa
2121
from google.protobuf import json_format
2222
from google.protobuf.duration_pb2 import Duration
2323
from google.protobuf.json_format import MessageToJson
24+
from google.protobuf.message import Message
2425
from pandas.api.types import is_datetime64_ns_dtype
2526
from pyarrow.lib import TimestampType
27+
from tensorflow_metadata.proto.v0 import schema_pb2
2628

2729
from feast.core.FeatureSet_pb2 import FeatureSet as FeatureSetProto
2830
from feast.core.FeatureSet_pb2 import FeatureSetMeta as FeatureSetMetaProto
@@ -657,6 +659,93 @@ def is_valid(self):
657659
if len(self.entities) == 0:
658660
raise ValueError(f"No entities found in feature set {self.name}")
659661

662+
def import_tfx_schema(self, schema: schema_pb2.Schema):
    """
    Apply a Tensorflow metadata schema to the fields of this FeatureSet.

    Every feature in the schema whose name matches a field (feature or
    entity) of this FeatureSet has its presence_constraints, shape_type and
    domain_info copied onto that field. Schema features without a matching
    field are reported with a warning and otherwise ignored.

    Args:
        schema: Schema from Tensorflow metadata. It is first passed to
            _make_tfx_schema_domain_info_inline(), which modifies it
            in place.

    Returns:
        None

    """
    _make_tfx_schema_domain_info_inline(schema)

    for tfx_feature in schema.feature:
        if tfx_feature.name not in self._fields:
            warnings.warn(
                f"The provided schema contains feature name '{tfx_feature.name}' "
                f"that does not exist in the FeatureSet '{self.name}' in Feast"
            )
            continue

        target_field = self._fields[tfx_feature.name]
        target_field.update_presence_constraints(tfx_feature)
        target_field.update_shape_type(tfx_feature)
        target_field.update_domain_info(tfx_feature)
686+
687+
def export_tfx_schema(self) -> schema_pb2.Schema:
    """
    Create a Tensorflow metadata schema from a FeatureSet.

    One schema Feature is appended per field of this FeatureSet. Attributes
    that are None on the field are left unset on the schema Feature. The
    Feature "type" is derived from the field's dtype through
    to_tfx_schema_feature_type().

    Returns:
        Tensorflow metadata schema.

    """
    schema = schema_pb2.Schema()

    # List of attributes to copy from fields in the FeatureSet to feature in
    # Tensorflow metadata schema where the attribute name is the same.
    # The names must be the public ones: they are looked up on both the Feast
    # field and the protobuf Feature, and the latter has no underscore-prefixed
    # attributes.
    attributes_to_copy_from_field_to_feature = [
        "name",
        "presence",
        "group_presence",
        "shape",
        "value_count",
        "domain",
        "int_domain",
        "float_domain",
        "string_domain",
        "bool_domain",
        "struct_domain",
        "natural_language_domain",
        "image_domain",
        "mid_domain",
        "url_domain",
        "time_domain",
        "time_of_day_domain",
    ]

    for field in self._fields.values():
        feature = schema_pb2.Feature()
        for attr in attributes_to_copy_from_field_to_feature:
            field_value = getattr(field, attr)
            if field_value is None:
                # This corresponds to an unset member in the proto Oneof field.
                continue

            feature_value = getattr(feature, attr)
            if isinstance(feature_value, Message):
                # Proto message field to copy is an "embedded" field, so MergeFrom()
                # method must be used.
                feature_value.MergeFrom(field_value)
            elif isinstance(feature_value, (int, str, bool)):
                # Proto message field is a simple Python type, so setattr()
                # can be used.
                setattr(feature, attr, field_value)
            else:
                warnings.warn(
                    f"Attribute '{attr}' cannot be copied from Field "
                    f"'{field.name}' in FeatureSet '{self.name}' to a "
                    f"Feature in the Tensorflow metadata schema, because "
                    f"the type is neither a Protobuf message or Python "
                    f"int, str and bool"
                )

        # "type" attr is handled separately because the attribute name is different
        # ("dtype" in field and "type" in Feature) and "type" in Feature is only
        # a subset of "dtype".
        feature.type = field.dtype.to_tfx_schema_feature_type()
        schema.feature.append(feature)

    return schema
748+
660749
@classmethod
661750
def from_yaml(cls, yml: str):
662751
"""
@@ -855,6 +944,40 @@ def __hash__(self):
855944
return hash(repr(self))
856945

857946

947+
def _make_tfx_schema_domain_info_inline(schema: schema_pb2.Schema) -> None:
    """
    Copy top level domain info defined at schema level into inline definition.
    One use case is when importing domain info from Tensorflow metadata schema
    into Feast features. Feast features do not have access to schema level information
    so the domain info needs to be inline.

    The schema is modified in place: features that reference a schema level
    domain by name get that domain merged into their inline string_domain,
    float_domain or int_domain, and the schema level string, float and int
    domains are then removed. A feature whose reference matches no schema
    level domain is left unchanged and a warning is emitted.

    Args:
        schema: Tensorflow metadata schema

    Returns: None
    """
    # Inline attribute on the feature paired with the schema level domains of
    # that kind, keyed by domain name. The order (string, float, int) decides
    # which one wins if the same name is defined for more than one kind.
    schema_level_domains = (
        ("string_domain", {d.name: d for d in schema.string_domain}),
        ("float_domain", {d.name: d for d in schema.float_domain}),
        ("int_domain", {d.name: d for d in schema.int_domain}),
    )

    # Resolve the references while the schema level domains are still part of
    # the schema, so the copy never reads from an already removed sub-message.
    for feature in schema.feature:
        if feature.WhichOneof("domain_info") != "domain":
            continue

        domain_ref = feature.domain
        for inline_attr, domains_by_name in schema_level_domains:
            if domain_ref in domains_by_name:
                getattr(feature, inline_attr).MergeFrom(domains_by_name[domain_ref])
                break
        else:
            warnings.warn(
                f"Feature '{feature.name}' references domain '{domain_ref}' "
                f"that is not defined at the top level of the schema, so the "
                f"domain info cannot be made inline"
            )

    # Every resolvable reference has been copied inline, so it is safe to
    # remove the domains defined at schema level.
    del schema.string_domain[:]
    del schema.float_domain[:]
    del schema.int_domain[:]
979+
980+
858981
def _infer_pd_column_type(column, series, rows_to_sample):
859982
dtype = None
860983
sample_count = 0

0 commit comments

Comments
 (0)