Skip to content
Prev Previous commit
Next Next commit
feat: added batching to feature server /push to offline store (#5683)

Signed-off-by: Jacob Weinhold <29459386+jfw-ppi@users.noreply.github.com>
  • Loading branch information
jfw-ppi committed Jan 1, 2026
commit 274740539ee80bb089f58468a5b0db2abb52c9f9
111 changes: 78 additions & 33 deletions sdk/python/feast/feature_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from datetime import datetime
from importlib import resources as importlib_resources
from types import SimpleNamespace
from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Union
from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Set, Union

import pandas as pd
import psutil
Expand Down Expand Up @@ -395,7 +395,7 @@ async def retrieve_online_documents(
return response_dict

@app.post("/push", dependencies=[Depends(inject_user_details)])
async def push(request: PushFeaturesRequest) -> None:
async def push(request: PushFeaturesRequest) -> Response:
df = pd.DataFrame(request.df)
actions = []
if request.to == "offline":
Expand Down Expand Up @@ -470,6 +470,8 @@ async def _push_with_to(push_to: PushMode) -> None:
needs_online = to in (PushMode.ONLINE, PushMode.ONLINE_AND_OFFLINE)
needs_offline = to in (PushMode.OFFLINE, PushMode.ONLINE_AND_OFFLINE)

status_code = status.HTTP_200_OK

if offline_batcher is None or not needs_offline:
await _push_with_to(to)
else:
Expand All @@ -482,6 +484,9 @@ async def _push_with_to(push_to: PushMode) -> None:
allow_registry_cache=request.allow_registry_cache,
transform_on_write=request.transform_on_write,
)
status_code = status.HTTP_202_ACCEPTED

return Response(status_code=status_code)

async def _get_feast_object(
feature_view_name: str, allow_registry_cache: bool
Expand Down Expand Up @@ -851,6 +856,7 @@ def __init__(self, store: "feast.FeatureStore", cfg: Any):
list
)
self._last_flush: DefaultDict[_OfflineBatchKey, float] = defaultdict(time.time)
self._inflight: Set[_OfflineBatchKey] = set()

self._lock = threading.Lock()
self._stop_event = threading.Event()
Expand Down Expand Up @@ -889,24 +895,25 @@ def enqueue(
with self._lock:
self._buffers[key].append(df)
total_rows = sum(len(d) for d in self._buffers[key])
should_flush = total_rows >= self._cfg.batch_size

if should_flush:
# Size-based flush
if total_rows >= self._cfg.batch_size:
logger.debug(
"OfflineWriteBatcher size threshold reached for %s: %s rows",
key,
total_rows,
)
self._flush_locked(key)
logger.debug(
"OfflineWriteBatcher size threshold reached for %s: %s rows",
key,
total_rows,
)
self._flush(key)

def flush_all(self) -> None:
"""
Flush all buffers synchronously. Intended for graceful shutdown.
"""
with self._lock:
keys = list(self._buffers.keys())
for key in keys:
self._flush_locked(key)
for key in keys:
self._flush(key)

def shutdown(self, timeout: float = 5.0) -> None:
"""
Expand Down Expand Up @@ -942,6 +949,7 @@ def _run(self) -> None:
now = time.time()
try:
with self._lock:
keys_to_flush: List[_OfflineBatchKey] = []
for key, dfs in list(self._buffers.items()):
if not dfs:
continue
Expand All @@ -955,38 +963,75 @@ def _run(self) -> None:
key,
age,
)
self._flush_locked(key)
keys_to_flush.append(key)
for key in keys_to_flush:
self._flush(key)
except Exception:
logger.exception("Error in OfflineWriteBatcher background loop")

logger.debug("OfflineWriteBatcher background loop exiting")

def _flush_locked(self, key: _OfflineBatchKey) -> None:
def _drain_locked(self, key: _OfflineBatchKey) -> Optional[List[pd.DataFrame]]:
"""
Flush a single buffer; caller must hold self._lock.
Drain a single buffer; caller must hold self._lock.
"""
if key in self._inflight:
return None

dfs = self._buffers.get(key)
if not dfs:
return
return None

batch_df = pd.concat(dfs, ignore_index=True)
self._buffers[key].clear()
self._last_flush[key] = time.time()
self._buffers[key] = []
self._inflight.add(key)
return dfs

logger.debug(
"Flushing offline batch for push_source=%s with %s rows",
key.push_source_name,
len(batch_df),
)
def _flush(self, key: _OfflineBatchKey) -> None:
"""
Flush a single buffer. Extracts data under lock, then does I/O without lock.
"""
while True:
with self._lock:
dfs = self._drain_locked(key)

# NOTE: offline writes are currently synchronous only, so we call directly
try:
self._store.push(
push_source_name=key.push_source_name,
df=batch_df,
allow_registry_cache=key.allow_registry_cache,
to=PushMode.OFFLINE,
transform_on_write=key.transform_on_write,
if not dfs:
return

batch_df = pd.concat(dfs, ignore_index=True)

# NOTE: offline writes are currently synchronous only, so we call directly
try:
self._store.push(
push_source_name=key.push_source_name,
df=batch_df,
allow_registry_cache=key.allow_registry_cache,
to=PushMode.OFFLINE,
transform_on_write=key.transform_on_write,
)
except Exception:
logger.exception("Error flushing offline batch for %s", key)
with self._lock:
self._buffers[key] = dfs + self._buffers[key]
self._inflight.discard(key)
return

logger.debug(
"Flushing offline batch for push_source=%s with %s rows",
key.push_source_name,
len(batch_df),
)

with self._lock:
self._last_flush[key] = time.time()
self._inflight.discard(key)
pending_rows = sum(len(d) for d in self._buffers.get(key, []))
should_flush = pending_rows >= self._cfg.batch_size

if not should_flush:
return

logger.debug(
"OfflineWriteBatcher size threshold reached for %s: %s rows",
key,
pending_rows,
)
except Exception:
logger.exception("Error flushing offline batch for %s", key)
14 changes: 8 additions & 6 deletions sdk/python/tests/unit/test_feature_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def _enable_offline_batching_config(
fs, enabled: bool = True, batch_size: int = 1, batch_interval_seconds: int = 60
):
"""
Attach a minimal feature_server.offline_push_batching config
Attach a minimal feature_server.offline_push_batching config
to a mocked FeatureStore.
"""
if not hasattr(fs, "config") or fs.config is None:
Expand Down Expand Up @@ -301,7 +301,9 @@ def test_push_batched_matrix(

# use a multi-row payload to ensure we test non-trivial dfs
resp = client.post("/push", json=push_body_many(push_mode, count=2, id_start=100))
assert resp.status_code == 200
needs_offline = push_mode in (PushMode.OFFLINE, PushMode.ONLINE_AND_OFFLINE)
expected_status = 202 if batching_enabled and needs_offline else 200
assert resp.status_code == expected_status

# Collect calls
sync_calls = fs.push.call_args_list
Expand Down Expand Up @@ -391,19 +393,19 @@ def test_offline_batches_are_separated_by_flags(mock_fs_factory):

# 1) Default flags: allow_registry_cache=True, transform_on_write=True
resp1 = client.post("/push", json=body_base)
assert resp1.status_code == 200
assert resp1.status_code == 202

# 2) Different allow_registry_cache
body_allow_false = dict(body_base)
body_allow_false["allow_registry_cache"] = False
resp2 = client.post("/push", json=body_allow_false)
assert resp2.status_code == 200
assert resp2.status_code == 202

# 3) Different transform_on_write
body_transform_false = dict(body_base)
body_transform_false["transform_on_write"] = False
resp3 = client.post("/push", json=body_transform_false)
assert resp3.status_code == 200
assert resp3.status_code == 202

# Immediately after: no flush expected yet (interval-based)
assert fs.push.call_count == 0
Expand Down Expand Up @@ -447,7 +449,7 @@ def test_offline_batcher_interval_flush(mock_fs_factory):
resp = client.post(
"/push", json=push_body_many(PushMode.OFFLINE, count=2, id_start=500)
)
assert resp.status_code == 200
assert resp.status_code == 202

# Immediately after: no sync push yet (buffer only)
assert fs.push.call_count == 0
Expand Down
Loading