Skip to content

Commit d41becf

Browse files
authored
feat: Make online_write_batch_size configurable in MaterializationConfig (#6268)
* feat: make online_write_batch_size configurable in MaterializationConfig
  Signed-off-by: cutoutsy <cutoutsy@gmail.com>
* refactor: simplify batch write logic and extend to spark/ray engines
  Signed-off-by: cutoutsy <cutoutsy@gmail.com>
---------
Signed-off-by: cutoutsy <cutoutsy@gmail.com>
1 parent d0c8984 commit d41becf

File tree

5 files changed

+131
-35
lines changed

5 files changed

+131
-35
lines changed

sdk/python/feast/infra/compute_engines/local/nodes.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -374,16 +374,25 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue:
374374
for entity in self.feature_view.entity_columns
375375
}
376376

377-
rows_to_write = _convert_arrow_to_proto(
378-
input_table, self.feature_view, join_key_to_value_type
377+
batch_size = (
378+
context.repo_config.materialization_config.online_write_batch_size
379379
)
380-
381-
online_store.online_write_batch(
382-
config=context.repo_config,
383-
table=self.feature_view,
384-
data=rows_to_write,
385-
progress=lambda x: None,
380+
# Single batch if None (backward compatible), otherwise use configured batch_size
381+
batches = (
382+
[input_table]
383+
if batch_size is None
384+
else input_table.to_batches(max_chunksize=batch_size)
386385
)
386+
for batch in batches:
387+
rows_to_write = _convert_arrow_to_proto(
388+
batch, self.feature_view, join_key_to_value_type
389+
)
390+
online_store.online_write_batch(
391+
config=context.repo_config,
392+
table=self.feature_view,
393+
data=rows_to_write,
394+
progress=lambda x: None,
395+
)
387396

388397
if self.feature_view.offline:
389398
offline_store = context.offline_store

sdk/python/feast/infra/compute_engines/ray/utils.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,32 @@ def write_to_online_store(
4545
for entity in feature_view.entity_columns
4646
}
4747

48-
rows_to_write = _convert_arrow_to_proto(
49-
arrow_table, feature_view, join_key_to_value_type
48+
batch_size = repo_config.materialization_config.online_write_batch_size
49+
# Single batch if None (backward compatible), otherwise use configured batch_size
50+
batches = (
51+
[arrow_table]
52+
if batch_size is None
53+
else arrow_table.to_batches(max_chunksize=batch_size)
5054
)
5155

52-
if rows_to_write:
53-
online_store.online_write_batch(
54-
config=repo_config,
55-
table=feature_view,
56-
data=rows_to_write,
57-
progress=lambda x: None,
56+
total_rows = 0
57+
for batch in batches:
58+
rows_to_write = _convert_arrow_to_proto(
59+
batch, feature_view, join_key_to_value_type
5860
)
61+
62+
if rows_to_write:
63+
online_store.online_write_batch(
64+
config=repo_config,
65+
table=feature_view,
66+
data=rows_to_write,
67+
progress=lambda x: None,
68+
)
69+
total_rows += len(rows_to_write)
70+
71+
if total_rows > 0:
5972
logger.debug(
60-
f"Successfully wrote {len(rows_to_write)} rows to online store for {feature_view.name}"
73+
f"Successfully wrote {total_rows} rows to online store for {feature_view.name}"
6174
)
6275
else:
6376
logger.warning(f"No rows to write for {feature_view.name}")

sdk/python/feast/infra/compute_engines/spark/utils.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,24 @@ def map_in_arrow(
4747
for entity in feature_view.entity_columns
4848
}
4949

50-
rows_to_write = _convert_arrow_to_proto(
51-
table, feature_view, join_key_to_value_type
52-
)
53-
54-
online_store.online_write_batch(
55-
config=repo_config,
56-
table=feature_view,
57-
data=rows_to_write,
58-
progress=lambda x: None,
50+
batch_size = repo_config.materialization_config.online_write_batch_size
51+
# Single batch if None (backward compatible), otherwise use configured batch_size
52+
sub_batches = (
53+
[table]
54+
if batch_size is None
55+
else table.to_batches(max_chunksize=batch_size)
5956
)
57+
for sub_batch in sub_batches:
58+
rows_to_write = _convert_arrow_to_proto(
59+
sub_batch, feature_view, join_key_to_value_type
60+
)
61+
62+
online_store.online_write_batch(
63+
config=repo_config,
64+
table=feature_view,
65+
data=rows_to_write,
66+
progress=lambda x: None,
67+
)
6068
if mode == "offline":
6169
offline_store.offline_write_batch(
6270
config=repo_config,
@@ -95,15 +103,23 @@ def map_in_pandas(iterator, serialized_artifacts: SerializedArtifacts):
95103
for entity in feature_view.entity_columns
96104
}
97105

98-
rows_to_write = _convert_arrow_to_proto(
99-
table, feature_view, join_key_to_value_type
100-
)
101-
online_store.online_write_batch(
102-
repo_config,
103-
feature_view,
104-
rows_to_write,
105-
lambda x: None,
106+
batch_size = repo_config.materialization_config.online_write_batch_size
107+
# Single batch if None (backward compatible), otherwise use configured batch_size
108+
sub_batches = (
109+
[table]
110+
if batch_size is None
111+
else table.to_batches(max_chunksize=batch_size)
106112
)
113+
for sub_batch in sub_batches:
114+
rows_to_write = _convert_arrow_to_proto(
115+
sub_batch, feature_view, join_key_to_value_type
116+
)
117+
online_store.online_write_batch(
118+
repo_config,
119+
feature_view,
120+
rows_to_write,
121+
lambda x: None,
122+
)
107123

108124
yield pd.DataFrame(
109125
[pd.Series(range(1, 2))]

sdk/python/feast/repo_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,12 @@ class MaterializationConfig(BaseModel):
214214
""" bool: If true, feature retrieval jobs will only pull the latest feature values for each entity.
215215
If false, feature retrieval jobs will pull all feature values within the specified time range. """
216216

217+
online_write_batch_size: Optional[int] = Field(default=None, gt=0)
218+
""" int: Number of rows to write to online store per batch during materialization.
219+
If None (default), all rows are written in a single batch for backward compatibility.
220+
Set to a positive integer (e.g., 10000) to enable batched writes.
221+
Supported compute engines: local, spark, ray. """
222+
217223

218224
class OpenLineageConfig(FeastBaseModel):
219225
"""Configuration for OpenLineage integration.

sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
LocalOutputNode,
1616
LocalTransformationNode,
1717
)
18+
from feast.repo_config import MaterializationConfig
1819

1920
backend = PandasBackend()
2021
now = pd.Timestamp.utcnow()
@@ -37,9 +38,11 @@
3738

3839
def create_context(node_outputs):
3940
# Setup execution context
41+
repo_config = MagicMock()
42+
repo_config.materialization_config = MaterializationConfig()
4043
return ExecutionContext(
4144
project="test_proj",
42-
repo_config=MagicMock(),
45+
repo_config=repo_config,
4346
offline_store=MagicMock(),
4447
online_store=MagicMock(),
4548
entity_defs=MagicMock(),
@@ -214,3 +217,52 @@ def test_local_output_node():
214217
node.inputs[0].name = "source"
215218
result = node.execute(context)
216219
assert result.num_rows == 4
220+
221+
222+
def test_local_output_node_online_write_default_batch():
223+
"""Test that online_write_batch is called once when batch_size is None (default)."""
224+
# Create a feature view with online=True
225+
feature_view = MagicMock()
226+
feature_view.online = True
227+
feature_view.offline = False
228+
feature_view.entity_columns = []
229+
230+
# Create context with default materialization config (batch_size=None)
231+
context = create_context(
232+
node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))}
233+
)
234+
235+
node = LocalOutputNode("output", feature_view)
236+
node.add_input(MagicMock())
237+
node.inputs[0].name = "source"
238+
239+
node.execute(context)
240+
241+
# Verify online_write_batch was called exactly once (all rows in single batch)
242+
assert context.online_store.online_write_batch.call_count == 1
243+
244+
245+
def test_local_output_node_online_write_batched():
246+
"""Test that online_write_batch is called multiple times when batch_size is configured."""
247+
# Create a feature view with online=True
248+
feature_view = MagicMock()
249+
feature_view.online = True
250+
feature_view.offline = False
251+
feature_view.entity_columns = []
252+
253+
# Create context with batch_size=2 (sample_df has 4 rows, so expect 2 batches)
254+
context = create_context(
255+
node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))}
256+
)
257+
context.repo_config.materialization_config = MaterializationConfig(
258+
online_write_batch_size=2
259+
)
260+
261+
node = LocalOutputNode("output", feature_view)
262+
node.add_input(MagicMock())
263+
node.inputs[0].name = "source"
264+
265+
node.execute(context)
266+
267+
# Verify online_write_batch was called twice (4 rows / batch_size 2 = 2 batches)
268+
assert context.online_store.online_write_batch.call_count == 2

0 commit comments

Comments (0)