Skip to content

Commit 05522ce

Browse files
committed
feat: Fix Map/Dict support and implement schema validation
Signed-off-by: ntkathole <nikhilkathole2683@gmail.com>
1 parent 16696b8 commit 05522ce

File tree

15 files changed

+564
-40
lines changed

15 files changed

+564
-40
lines changed

sdk/python/feast/batch_feature_view.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def __init__(
9797
feature_transformation: Optional[Transformation] = None,
9898
batch_engine: Optional[Dict[str, Any]] = None,
9999
aggregations: Optional[List[Aggregation]] = None,
100+
enable_validation: bool = False,
100101
):
101102
if not flags_helper.is_test():
102103
warnings.warn(
@@ -136,6 +137,7 @@ def __init__(
136137
source=source, # type: ignore[arg-type]
137138
sink_source=sink_source,
138139
mode=mode,
140+
enable_validation=enable_validation,
139141
)
140142

141143
def get_feature_transformation(self) -> Optional[Transformation]:
@@ -169,6 +171,7 @@ def batch_feature_view(
169171
description: str = "",
170172
owner: str = "",
171173
schema: Optional[List[Field]] = None,
174+
enable_validation: bool = False,
172175
):
173176
"""
174177
Creates a BatchFeatureView object with the given user-defined function (UDF) as the transformation.
@@ -199,6 +202,7 @@ def decorator(user_function):
199202
schema=schema,
200203
udf=user_function,
201204
udf_string=udf_string,
205+
enable_validation=enable_validation,
202206
)
203207
functools.update_wrapper(wrapper=batch_feature_view_obj, wrapped=user_function)
204208
return batch_feature_view_obj

sdk/python/feast/feature_view.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ class FeatureView(BaseFeatureView):
107107
owner: str
108108
materialization_intervals: List[Tuple[datetime, datetime]]
109109
mode: Optional[Union["TransformationMode", str]]
110+
enable_validation: bool
110111

111112
def __init__(
112113
self,
@@ -123,6 +124,7 @@ def __init__(
123124
tags: Optional[Dict[str, str]] = None,
124125
owner: str = "",
125126
mode: Optional[Union["TransformationMode", str]] = None,
127+
enable_validation: bool = False,
126128
):
127129
"""
128130
Creates a FeatureView object.
@@ -148,11 +150,14 @@ def __init__(
148150
primary maintainer.
149151
mode (optional): The transformation mode for feature transformations. Only meaningful
150152
when transformations are applied. Choose from TransformationMode enum values.
153+
enable_validation (optional): If True, enables schema validation during materialization
154+
to check that data conforms to the declared feature types. Default is False.
151155
152156
Raises:
153157
ValueError: A field mapping conflicts with an Entity or a Feature.
154158
"""
155159
self.name = name
160+
self.enable_validation = enable_validation
156161
self.entities = [e.name for e in entities] if entities else [DUMMY_ENTITY_NAME]
157162
self.ttl = ttl
158163
schema = schema or []
@@ -457,13 +462,17 @@ def to_proto_spec(
457462
else self.mode
458463
)
459464

465+
tags = dict(self.tags) if self.tags else {}
466+
if self.enable_validation:
467+
tags["feast:enable_validation"] = "true"
468+
460469
return FeatureViewSpecProto(
461470
name=self.name,
462471
entities=self.entities,
463472
entity_columns=[field.to_proto() for field in self.entity_columns],
464473
features=[feature.to_proto() for feature in self.features],
465474
description=self.description,
466-
tags=self.tags,
475+
tags=tags,
467476
owner=self.owner,
468477
ttl=(ttl_duration if ttl_duration is not None else None),
469478
online=self.online,
@@ -642,6 +651,13 @@ def _from_proto_internal(
642651
f"Entities: {feature_view.entities} vs Entity Columns: {feature_view.entity_columns}"
643652
)
644653

654+
# Restore enable_validation from well-known tag.
655+
proto_tags = dict(feature_view_proto.spec.tags)
656+
feature_view.enable_validation = (
657+
proto_tags.pop("feast:enable_validation", "false").lower() == "true"
658+
)
659+
feature_view.tags = proto_tags
660+
645661
# FeatureViewProjections are not saved in the FeatureView proto.
646662
# Create the default projection.
647663
feature_view.projection = FeatureViewProjection.from_feature_view_definition(

sdk/python/feast/infra/compute_engines/local/feature_builder.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import Union
23

34
from feast.aggregation import aggregation_specs_to_agg_ops
@@ -16,6 +17,9 @@
1617
LocalValidationNode,
1718
)
1819
from feast.infra.registry.base_registry import BaseRegistry
20+
from feast.types import from_feast_to_pyarrow_type
21+
22+
logger = logging.getLogger(__name__)
1923

2024

2125
class LocalFeatureBuilder(FeatureBuilder):
@@ -88,7 +92,24 @@ def build_transformation_node(self, view, input_nodes):
8892
return node
8993

9094
def build_validation_node(self, view, input_node):
91-
validation_config = view.validation_config
95+
validation_config = getattr(view, "validation_config", None) or {}
96+
97+
if not validation_config.get("columns") and hasattr(view, "features"):
98+
columns = {}
99+
for feature in view.features:
100+
try:
101+
columns[feature.name] = from_feast_to_pyarrow_type(feature.dtype)
102+
except (ValueError, KeyError):
103+
logger.debug(
104+
"Could not resolve PyArrow type for feature '%s' "
105+
"(dtype=%s), skipping type check for this column.",
106+
feature.name,
107+
feature.dtype,
108+
)
109+
columns[feature.name] = None
110+
if columns:
111+
validation_config = {**validation_config, "columns": columns}
112+
92113
node = LocalValidationNode(
93114
"validate", validation_config, self.backend, inputs=[input_node]
94115
)

sdk/python/feast/infra/compute_engines/local/nodes.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from datetime import datetime, timedelta
23
from typing import List, Optional, Union
34

@@ -19,6 +20,8 @@
1920
)
2021
from feast.utils import _convert_arrow_to_proto
2122

23+
logger = logging.getLogger(__name__)
24+
2225
ENTITY_TS_ALIAS = "__entity_event_timestamp"
2326

2427

@@ -236,15 +239,52 @@ def __init__(
236239

237240
def execute(self, context: ExecutionContext) -> ArrowTableValue:
238241
input_table = self.get_single_table(context).data
239-
df = self.backend.from_arrow(input_table)
240-
# Placeholder for actual validation logic
242+
241243
if self.validation_config:
242-
print(f"[Validation: {self.name}] Passed.")
243-
result = self.backend.to_arrow(df)
244-
output = ArrowTableValue(result)
244+
self._validate_schema(input_table)
245+
246+
output = ArrowTableValue(input_table)
245247
context.node_outputs[self.name] = output
246248
return output
247249

250+
def _validate_schema(self, table: pa.Table):
251+
"""Validate that the input table conforms to the expected schema.
252+
253+
Checks that all expected columns are present and that their types
254+
are compatible with the declared Feast types. Logs warnings for
255+
type mismatches but only raises on missing columns.
256+
"""
257+
expected_columns = self.validation_config.get("columns", {})
258+
if not expected_columns:
259+
logger.debug(
260+
"[Validation: %s] No column schema to validate against.",
261+
self.name,
262+
)
263+
return
264+
265+
actual_columns = set(table.column_names)
266+
expected_names = set(expected_columns.keys())
267+
268+
missing = expected_names - actual_columns
269+
if missing:
270+
raise ValueError(
271+
f"[Validation: {self.name}] Missing expected columns: {missing}. "
272+
f"Actual columns: {sorted(actual_columns)}"
273+
)
274+
275+
for col_name, expected_type in expected_columns.items():
276+
actual_type = table.schema.field(col_name).type
277+
if expected_type is not None and actual_type != expected_type:
278+
logger.warning(
279+
"[Validation: %s] Column '%s' type mismatch: expected %s, got %s",
280+
self.name,
281+
col_name,
282+
expected_type,
283+
actual_type,
284+
)
285+
286+
logger.debug("[Validation: %s] Schema validation passed.", self.name)
287+
248288

249289
class LocalOutputNode(LocalNode):
250290
def __init__(

sdk/python/feast/infra/compute_engines/ray/feature_builder.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
RayJoinNode,
1818
RayReadNode,
1919
RayTransformationNode,
20+
RayValidationNode,
2021
RayWriteNode,
2122
)
23+
from feast.types import from_feast_to_pyarrow_type
2224

2325
if TYPE_CHECKING:
2426
from feast.infra.compute_engines.ray.config import RayComputeEngineConfig
@@ -174,11 +176,29 @@ def build_output_nodes(self, view, final_node):
174176

175177
def build_validation_node(self, view, input_node):
176178
"""Build the validation node for feature validation."""
177-
# TODO: Implement validation logic
178-
logger.warning(
179-
"Feature validation is not yet implemented for Ray compute engine."
179+
expected_columns = {}
180+
if hasattr(view, "features"):
181+
for feature in view.features:
182+
try:
183+
expected_columns[feature.name] = from_feast_to_pyarrow_type(
184+
feature.dtype
185+
)
186+
except (ValueError, KeyError):
187+
logger.debug(
188+
"Could not resolve PyArrow type for feature '%s' "
189+
"(dtype=%s), skipping type check for this column.",
190+
feature.name,
191+
feature.dtype,
192+
)
193+
expected_columns[feature.name] = None
194+
195+
node = RayValidationNode(
196+
f"{view.name}:validate",
197+
expected_columns=expected_columns,
198+
inputs=[input_node],
180199
)
181-
return input_node
200+
self.nodes.append(node)
201+
return node
182202

183203
def _build(self, view, input_nodes: Optional[List[DAGNode]]) -> DAGNode:
184204
has_physical_source = (hasattr(view, "batch_source") and view.batch_source) or (

sdk/python/feast/infra/compute_engines/ray/nodes.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,3 +847,59 @@ def write_batch_with_serialized_artifacts(batch: pd.DataFrame) -> pd.DataFrame:
847847
),
848848
},
849849
)
850+
851+
852+
class RayValidationNode(DAGNode):
853+
"""
854+
Ray node for validating feature data against the declared schema.
855+
856+
Checks that all expected columns are present and logs warnings for
857+
type mismatches. Validation runs once on the first batch to avoid
858+
per-batch overhead; the full dataset is passed through unchanged.
859+
"""
860+
861+
def __init__(
862+
self,
863+
name: str,
864+
expected_columns: Dict[str, Optional[pa.DataType]],
865+
inputs: Optional[List[DAGNode]] = None,
866+
):
867+
super().__init__(name, inputs=inputs)
868+
self.expected_columns = expected_columns
869+
870+
def execute(self, context: ExecutionContext) -> DAGValue:
871+
input_value = self.get_single_input_value(context)
872+
dataset = input_value.data
873+
874+
if not self.expected_columns:
875+
context.node_outputs[self.name] = input_value
876+
return input_value
877+
878+
expected_names = set(self.expected_columns.keys())
879+
880+
schema = dataset.schema()
881+
actual_columns = set(schema.names)
882+
883+
missing = expected_names - actual_columns
884+
if missing:
885+
raise ValueError(
886+
f"[Validation: {self.name}] Missing expected columns: {missing}. "
887+
f"Actual columns: {sorted(actual_columns)}"
888+
)
889+
890+
for col_name, expected_type in self.expected_columns.items():
891+
if expected_type is None:
892+
continue
893+
actual_field = schema.field(col_name)
894+
if actual_field.type != expected_type:
895+
logger.warning(
896+
"[Validation: %s] Column '%s' type mismatch: expected %s, got %s",
897+
self.name,
898+
col_name,
899+
expected_type,
900+
actual_field.type,
901+
)
902+
903+
logger.debug("[Validation: %s] Schema validation passed.", self.name)
904+
context.node_outputs[self.name] = input_value
905+
return input_value

sdk/python/feast/infra/compute_engines/spark/feature_builder.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import Union
23

34
from pyspark.sql import SparkSession
@@ -12,9 +13,13 @@
1213
SparkJoinNode,
1314
SparkReadNode,
1415
SparkTransformationNode,
16+
SparkValidationNode,
1517
SparkWriteNode,
1618
)
1719
from feast.infra.registry.base_registry import BaseRegistry
20+
from feast.types import from_feast_to_pyarrow_type
21+
22+
logger = logging.getLogger(__name__)
1823

1924

2025
class SparkFeatureBuilder(FeatureBuilder):
@@ -115,4 +120,26 @@ def build_output_nodes(self, view, input_node):
115120
return node
116121

117122
def build_validation_node(self, view, input_node):
118-
pass
123+
expected_columns = {}
124+
if hasattr(view, "features"):
125+
for feature in view.features:
126+
try:
127+
expected_columns[feature.name] = from_feast_to_pyarrow_type(
128+
feature.dtype
129+
)
130+
except (ValueError, KeyError):
131+
logger.debug(
132+
"Could not resolve PyArrow type for feature '%s' "
133+
"(dtype=%s), skipping type check for this column.",
134+
feature.name,
135+
feature.dtype,
136+
)
137+
expected_columns[feature.name] = None
138+
139+
node = SparkValidationNode(
140+
f"{view.name}:validate",
141+
expected_columns=expected_columns,
142+
inputs=[input_node],
143+
)
144+
self.nodes.append(node)
145+
return node

Comments: 0 commit comments (0)