# Copyright 2025 The Feast Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import sys
import threading
import time
import traceback
from collections import defaultdict
from contextlib import asynccontextmanager
from datetime import datetime
from importlib import resources as importlib_resources
from types import SimpleNamespace
from typing import (
    Any,
    DefaultDict,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Union,
)

import pandas as pd
from dateutil import parser
from fastapi import (
    Depends,
    FastAPI,
    Request,
    Response,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from fastapi.concurrency import run_in_threadpool
from fastapi.logger import logger
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from google.protobuf.json_format import MessageToDict
from pydantic import BaseModel, field_validator

import feast
from feast import metrics as feast_metrics
from feast import proto_json, utils
from feast.constants import DEFAULT_FEATURE_SERVER_REGISTRY_TTL
from feast.data_source import PushMode
from feast.errors import (
    FeastError,
)
from feast.feast_object import FeastObject
from feast.feature_view_utils import get_feature_view_from_feature_store
from feast.permissions.action import WRITE, AuthzedAction
from feast.permissions.security_manager import assert_permissions
from feast.permissions.server.rest import inject_user_details
from feast.permissions.server.utils import (
    ServerType,
    init_auth_manager,
    init_security_manager,
    str_to_auth_manager_type,
)


# NOTE: the request/response models below are documented with ``#`` comments
# rather than docstrings so the generated OpenAPI schema stays unchanged.


# Request body of POST /write-to-online-store.
# TODO: deprecate this in favor of push features
class WriteToFeatureStoreRequest(BaseModel):
    feature_view_name: str
    df: dict
    allow_registry_cache: bool = True
    transform_on_write: bool = True


# Request body of POST /push. ``to`` must be one of "online", "offline" or
# "online_and_offline"; the handler rejects anything else.
class PushFeaturesRequest(BaseModel):
    push_source_name: str
    df: dict
    allow_registry_cache: bool = True
    to: str = "online"
    transform_on_write: bool = True


# Request body of POST /materialize. ``start_ts`` and ``end_ts`` are required
# by the handler unless ``disable_event_timestamp`` is set.
class MaterializeRequest(BaseModel):
    start_ts: Optional[str] = None
    end_ts: Optional[str] = None
    feature_views: Optional[List[str]] = None
    disable_event_timestamp: bool = False
    full_feature_names: bool = False


# Request body of POST /materialize-incremental.
class MaterializeIncrementalRequest(BaseModel):
    end_ts: str
    feature_views: Optional[List[str]] = None
    full_feature_names: bool = False


# Request body of POST /get-online-features. When ``feature_service`` is set
# it takes precedence over ``features`` (see ``_get_features``).
class GetOnlineFeaturesRequest(BaseModel):
    entities: Dict[str, List[Any]]
    feature_service: Optional[str] = None
    features: List[str] = []
    full_feature_names: bool = False
    include_feature_view_version_metadata: bool = False


# Request body of POST /retrieve-online-documents. ``query_string`` and
# ``include_feature_view_version_metadata`` are only forwarded when
# ``api_version`` is 2.
class GetOnlineDocumentsRequest(BaseModel):
    feature_service: Optional[str] = None
    features: List[str] = []
    full_feature_names: bool = False
    include_feature_view_version_metadata: bool = False
    top_k: Optional[int] = None
    query: Optional[List[float]] = None
    query_string: Optional[str] = None
    api_version: Optional[int] = 1


# One entry of ``OnlineFeaturesResponse.results``.
class FeatureVectorResponse(BaseModel):
    values: List[Any] = []
    statuses: List[str] = []
    event_timestamps: List[str] = []


# ``metadata`` section of ``OnlineFeaturesResponse``.
class OnlineFeaturesMetadataResponse(BaseModel):
    feature_names: List[str] = []

    @field_validator("feature_names", mode="before")
    @classmethod
    def _unwrap_feature_list(cls, v: Any) -> Any:
        """Accept both the proto_json-patched flat list and the raw protobuf
        ``{"val": [...]}`` dict produced by ``MessageToDict`` when the
        monkey-patch is absent or ineffective."""
        if isinstance(v, dict) and "val" in v:
            return v["val"]
        return v


# Response model shared by /get-online-features and /retrieve-online-documents.
class OnlineFeaturesResponse(BaseModel):
    metadata: Optional[OnlineFeaturesMetadataResponse] = None
    results: List[FeatureVectorResponse] = []
    status: Optional[bool] = None


# A single chat message in a POST /chat request.
class ChatMessage(BaseModel):
    role: str
    content: str


# Request body of POST /chat.
class ChatRequest(BaseModel):
    messages: List[ChatMessage]


def _resolve_feature_counts(
    features: Union[List[str], "feast.FeatureService"],
) -> Tuple[str, str]:
    """Return (feature_count, feature_view_count) from the resolved features.

    ``features`` is either a list of ``"feature_view:feature"`` strings or
    a ``FeatureService`` with ``feature_view_projections``.

    Both counts are returned as strings. Any other input type yields
    ``("0", "0")``.
    """
    from feast.feature_service import FeatureService

    if isinstance(features, FeatureService):
        projections = features.feature_view_projections
        fv_count = len(projections)
        feat_count = sum(len(p.features) for p in projections)
    elif isinstance(features, list):
        feat_count = len(features)
        # Everything before ":" is the feature view reference; anything after
        # "@" in that reference is dropped so it does not split one view into
        # several. References without ":" do not count towards the view total.
        fv_names = {ref.split(":")[0].split("@")[0] for ref in features if ":" in ref}
        fv_count = len(fv_names)
    else:
        feat_count = 0
        fv_count = 0
    return str(feat_count), str(fv_count)


async def _get_features(
    request: Union[GetOnlineFeaturesRequest, GetOnlineDocumentsRequest],
    store: "feast.FeatureStore",
):
    """Resolve and authorize the features referenced by ``request``.

    If ``request.feature_service`` is set, the feature service is looked up
    (registry cache allowed), checked for ``READ_ONLINE`` permission and
    returned. Otherwise every feature view and on-demand feature view behind
    ``request.features`` is checked for ``READ_ONLINE`` and the raw list of
    feature references is returned.

    Registry lookups run in the threadpool so they do not block the event loop.
    """
    if request.feature_service:
        feature_service = await run_in_threadpool(
            store.get_feature_service, request.feature_service, allow_cache=True
        )
        assert_permissions(
            resource=feature_service, actions=[AuthzedAction.READ_ONLINE]
        )
        features = feature_service  # type: ignore
    else:
        all_feature_views, all_on_demand_feature_views = await run_in_threadpool(
            utils._get_feature_views_to_use,
            store.registry,
            store.project,
            request.features,
            allow_cache=True,
            hide_dummy_entity=False,
        )
        for feature_view in all_feature_views:
            assert_permissions(
                resource=feature_view, actions=[AuthzedAction.READ_ONLINE]
            )
        for od_feature_view in all_on_demand_feature_views:
            assert_permissions(
                resource=od_feature_view, actions=[AuthzedAction.READ_ONLINE]
            )
        features = request.features  # type: ignore
    return features


async def load_static_artifacts(app: FastAPI, store):
    """
    Load static artifacts (models, lookup tables, etc.) into app.state.

    This function can be extended to load various types of static artifacts:
    - Small ML models (scikit-learn, small neural networks)
    - Lookup tables and reference data
    - Configuration parameters
    - Pre-computed embeddings

    Note: Not recommended for large language models - use dedicated model
    serving solutions (vLLM, TGI, etc.) for those.

    The loader looks for ``static_artifacts.py`` in ``store.repo_path`` (or the
    current working directory when that is unset) and, if the module defines
    ``load_artifacts``, calls it with ``app`` (awaiting it when it is a
    coroutine function). Any exception is logged as a warning and swallowed.
    """
    try:
        # Import here to avoid loading heavy dependencies unless needed
        import importlib.util
        import inspect
        from pathlib import Path

        # Look for static artifacts loading in the feature repository
        # This allows templates and users to define their own artifact loading
        repo_path = Path(store.repo_path) if store.repo_path else Path.cwd()
        artifacts_file = repo_path / "static_artifacts.py"

        if artifacts_file.exists():
            # Load and execute custom static artifacts loading
            spec = importlib.util.spec_from_file_location(
                "static_artifacts", artifacts_file
            )
            if spec and spec.loader:
                artifacts_module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(artifacts_module)

                # Look for load_artifacts function
                if hasattr(artifacts_module, "load_artifacts"):
                    load_func = artifacts_module.load_artifacts
                    if inspect.iscoroutinefunction(load_func):
                        await load_func(app)
                    else:
                        load_func(app)
                    logger.info("Loaded static artifacts from static_artifacts.py")
    except Exception as e:
        # Non-fatal error - feature server should still start
        logger.warning(f"Failed to load static artifacts: {e}")


def get_app(
    store: "feast.FeatureStore",
    registry_ttl_sec: int = DEFAULT_FEATURE_SERVER_REGISTRY_TTL,
):
    """
    Creates a FastAPI app that can be used to start a feature server.

    Args:
        store: The FeatureStore to use for serving features
        registry_ttl_sec: The TTL in seconds for the registry cache

    Returns:
        A FastAPI app

    Example:
        ```python
        from feast import FeatureStore

        store = FeatureStore(repo_path="feature_repo")
        app = get_app(store)
        ```

    The app provides the following endpoints:
    - `/get-online-features`: Get online features
    - `/retrieve-online-documents`: Retrieve online documents
    - `/push`: Push features to the feature store
    - `/write-to-online-store`: Write to the online store
    - `/health`: Health check
    - `/materialize`: Materialize features
    - `/materialize-incremental`: Materialize features incrementally
    - `/chat`: Chat UI
    - `/ws/chat`: WebSocket endpoint for chat

    MCP Support:
    - If MCP is enabled in feature server configuration, MCP endpoints will be added automatically
    """
    proto_json.patch()
    # Asynchronously refresh registry, notifying shutdown and canceling the active timer if the app is shutting down
    registry_proto = None
    shutting_down = False
    active_timer: Optional[threading.Timer] = None

    # --- Offline write batching config and batcher ---
    fs_cfg = getattr(store.config, "feature_server", None)
    batching_cfg = None
    if fs_cfg is not None:
        enabled = getattr(fs_cfg, "offline_push_batching_enabled", False)
        batch_size = getattr(fs_cfg, "offline_push_batching_batch_size", None)
        batch_interval_seconds = getattr(
            fs_cfg, "offline_push_batching_batch_interval_seconds", None
        )

        if enabled is True:
            # ``bool`` is a subclass of ``int``; reject it explicitly so that
            # ``True``/``False`` are not accepted as numeric settings.
            size_ok = isinstance(batch_size, int) and not isinstance(batch_size, bool)
            interval_ok = isinstance(batch_interval_seconds, int) and not isinstance(
                batch_interval_seconds, bool
            )
            if size_ok and interval_ok:
                batching_cfg = SimpleNamespace(
                    enabled=True,
                    batch_size=batch_size,
                    batch_interval_seconds=batch_interval_seconds,
                )
            else:
                logger.warning(
                    "Offline write batching enabled but missing or invalid numeric values; "
                    "disabling batching (batch_size=%r, batch_interval_seconds=%r)",
                    batch_size,
                    batch_interval_seconds,
                )

    offline_batcher: Optional[OfflineWriteBatcher] = None
    if batching_cfg is not None and batching_cfg.enabled is True:
        offline_batcher = OfflineWriteBatcher(store=store, cfg=batching_cfg)
        logger.debug("Offline write batching is ENABLED")
    else:
        logger.debug("Offline write batching is DISABLED")

    def stop_refresh():
        """Mark the app as shutting down and cancel the pending refresh timer."""
        nonlocal shutting_down
        shutting_down = True
        if active_timer:
            active_timer.cancel()

    def async_refresh():
        """Refresh the registry and re-arm a timer for the next refresh.

        Does nothing once ``stop_refresh`` has been called. When
        ``registry_ttl_sec`` is falsy the registry is refreshed once and no
        timer is scheduled.
        """
        nonlocal registry_proto, active_timer
        if shutting_down:
            return
        store.refresh_registry()
        registry_proto = store.registry.proto()
        if registry_ttl_sec:
            active_timer = threading.Timer(registry_ttl_sec, async_refresh)
            active_timer.start()

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Start-up / shutdown hook: load artifacts, init store, manage refresh."""
        # Load static artifacts before initializing store
        await load_static_artifacts(app, store)
        await store.initialize()
        async_refresh()
        try:
            yield
        finally:
            stop_refresh()
            if offline_batcher is not None:
                offline_batcher.shutdown()
            await store.close()

    app = FastAPI(lifespan=lifespan)

    # NOTE: route handlers are documented with ``#`` comments instead of
    # docstrings because FastAPI publishes handler docstrings as the
    # operation description in the OpenAPI schema.

    # Read online features for the requested entities, using the provider's
    # async read path when it advertises one and the threadpool otherwise.
    @app.post(
        "/get-online-features",
        dependencies=[Depends(inject_user_details)],
        response_model=OnlineFeaturesResponse,
    )
    async def get_online_features(request: GetOnlineFeaturesRequest) -> Any:
        with feast_metrics.track_request_latency(
            "/get-online-features",
        ) as metrics_ctx:
            features = await _get_features(request, store)
            feat_count, fv_count = _resolve_feature_counts(features)
            metrics_ctx.feature_count = feat_count
            metrics_ctx.feature_view_count = fv_count

            # The entity count is the length of the first entity column; an
            # empty ``entities`` mapping counts as zero.
            entity_count = len(next(iter(request.entities.values()), []))
            feast_metrics.track_online_features_entities(entity_count)

            read_params = dict(
                features=features,
                entity_rows=request.entities,
                full_feature_names=request.full_feature_names,
                include_feature_view_version_metadata=request.include_feature_view_version_metadata,
            )

            if store._get_provider().async_supported.online.read:
                response = await store.get_online_features_async(**read_params)  # type: ignore
            else:
                response = await run_in_threadpool(
                    lambda: store.get_online_features(**read_params)  # type: ignore
                )

            response_dict = await run_in_threadpool(
                MessageToDict,
                response.proto,
                preserving_proto_field_name=True,
                float_precision=18,
            )
            return response_dict

    # Retrieve online documents; ``api_version == 2`` routes to the v2
    # retrieval API, every other value to the original one.
    @app.post(
        "/retrieve-online-documents",
        dependencies=[Depends(inject_user_details)],
        response_model=OnlineFeaturesResponse,
    )
    async def retrieve_online_documents(
        request: GetOnlineDocumentsRequest,
    ) -> Any:
        with feast_metrics.track_request_latency("/retrieve-online-documents"):
            logger.warning(
                "This endpoint is in alpha and will be moved to /get-online-features when stable."
            )
            features = await _get_features(request, store)

            read_params = dict(
                features=features,
                query=request.query,
                top_k=request.top_k,
            )
            if request.api_version == 2 and request.query_string is not None:
                read_params["query_string"] = request.query_string

            if request.api_version == 2:
                read_params["include_feature_view_version_metadata"] = (
                    request.include_feature_view_version_metadata
                )
                response = await run_in_threadpool(
                    lambda: store.retrieve_online_documents_v2(**read_params)  # type: ignore
                )
            else:
                response = await run_in_threadpool(
                    lambda: store.retrieve_online_documents(**read_params)  # type: ignore
                )

            response_dict = await run_in_threadpool(
                MessageToDict,
                response.proto,
                preserving_proto_field_name=True,
                float_precision=18,
            )
            return response_dict

    # Push a dataframe to a push source. Responds 200 when the write was
    # performed inline and 202 when the offline part was handed to the batcher.
    @app.post("/push", dependencies=[Depends(inject_user_details)])
    async def push(request: PushFeaturesRequest) -> Response:
        with feast_metrics.track_request_latency("/push"):
            df = pd.DataFrame(request.df)
            actions = []
            if request.to == "offline":
                to = PushMode.OFFLINE
                actions = [AuthzedAction.WRITE_OFFLINE]
            elif request.to == "online":
                to = PushMode.ONLINE
                actions = [AuthzedAction.WRITE_ONLINE]
            elif request.to == "online_and_offline":
                to = PushMode.ONLINE_AND_OFFLINE
                actions = WRITE
            else:
                raise ValueError(
                    f"{request.to} is not a supported push format. Please specify one of these ['online', 'offline', 'online_and_offline']."
                )

            from feast.data_source import PushSource

            all_fvs = store.list_feature_views(
                allow_cache=request.allow_registry_cache
            ) + store.list_stream_feature_views(
                allow_cache=request.allow_registry_cache
            )
            fvs_with_push_sources = {
                fv
                for fv in all_fvs
                if (
                    fv.stream_source is not None
                    and isinstance(fv.stream_source, PushSource)
                    and fv.stream_source.name == request.push_source_name
                )
            }

            for feature_view in fvs_with_push_sources:
                assert_permissions(resource=feature_view, actions=actions)

            async def _push_with_to(push_to: PushMode) -> None:
                """
                Helper for performing a single push operation.

                NOTE:
                - Feast providers **do not currently support async offline writes**.
                - Therefore:
                    * ONLINE and ONLINE_AND_OFFLINE → may be async, depending on provider.async_supported.online.write
                    * OFFLINE → always synchronous, but executed via run_in_threadpool when called from HTTP handlers.
                - The OfflineWriteBatcher handles offline writes directly in its own background thread, but the offline store writes are currently synchronous only.
                """
                push_source_name = request.push_source_name
                allow_registry_cache = request.allow_registry_cache
                transform_on_write = request.transform_on_write

                # Async currently only applies to online store writes (ONLINE / ONLINE_AND_OFFLINE paths) as theres no async for offline store
                if push_to in (PushMode.ONLINE, PushMode.ONLINE_AND_OFFLINE) and (
                    store._get_provider().async_supported.online.write
                ):
                    await store.push_async(
                        push_source_name=push_source_name,
                        df=df,
                        allow_registry_cache=allow_registry_cache,
                        to=push_to,
                        transform_on_write=transform_on_write,
                    )
                else:
                    await run_in_threadpool(
                        lambda: store.push(
                            push_source_name=push_source_name,
                            df=df,
                            allow_registry_cache=allow_registry_cache,
                            to=push_to,
                            transform_on_write=transform_on_write,
                        )
                    )

            needs_online = to in (PushMode.ONLINE, PushMode.ONLINE_AND_OFFLINE)
            needs_offline = to in (PushMode.OFFLINE, PushMode.ONLINE_AND_OFFLINE)

            status_code = status.HTTP_200_OK

            if offline_batcher is None or not needs_offline:
                await _push_with_to(to)
            else:
                if needs_online:
                    await _push_with_to(PushMode.ONLINE)

                # ``enqueue`` flushes synchronously (a blocking offline-store
                # write) when the size threshold is reached, so it must not
                # run on the event loop thread.
                await run_in_threadpool(
                    offline_batcher.enqueue,
                    push_source_name=request.push_source_name,
                    df=df,
                    allow_registry_cache=request.allow_registry_cache,
                    transform_on_write=request.transform_on_write,
                )
                status_code = status.HTTP_202_ACCEPTED

            feast_metrics.track_push(request.push_source_name, request.to)
            return Response(status_code=status_code)

    async def _get_feast_object(
        feature_view_name: str, allow_registry_cache: bool
    ) -> FeastObject:
        """Look up a feature view by name in the threadpool (for permission checks)."""
        return await run_in_threadpool(
            get_feature_view_from_feature_store,
            store,
            feature_view_name,
            allow_registry_cache,
        )

    # Write a dataframe directly to the online store for one feature view.
    @app.post("/write-to-online-store", dependencies=[Depends(inject_user_details)])
    async def write_to_online_store(request: WriteToFeatureStoreRequest) -> None:
        df = pd.DataFrame(request.df)
        feature_view_name = request.feature_view_name
        allow_registry_cache = request.allow_registry_cache
        resource = await _get_feast_object(feature_view_name, allow_registry_cache)
        assert_permissions(resource=resource, actions=[AuthzedAction.WRITE_ONLINE])
        await run_in_threadpool(
            store.write_to_online_store,
            feature_view_name=feature_view_name,
            df=df,
            allow_registry_cache=allow_registry_cache,
            transform_on_write=request.transform_on_write,
        )

    # 200 once a registry proto has been loaded by ``async_refresh``, else 503.
    @app.get("/health")
    async def health():
        return (
            Response(status_code=status.HTTP_200_OK)
            if registry_proto
            else Response(status_code=status.HTTP_503_SERVICE_UNAVAILABLE)
        )

    @app.post("/chat")
    async def chat(request: ChatRequest):
        # Process the chat request
        # For now, just return dummy text
        return {"response": "This is a dummy response from the Feast feature server."}

    @app.get("/chat")
    async def chat_ui():
        # Serve the chat UI
        static_dir_ref = importlib_resources.files(__spec__.parent) / "static/chat"  # type: ignore[name-defined, arg-type]
        with importlib_resources.as_file(static_dir_ref) as static_dir:
            # Read explicitly as UTF-8 so the page does not depend on the
            # platform's default encoding.
            with open(
                os.path.join(static_dir, "index.html"), encoding="utf-8"
            ) as f:
                content = f.read()
        return Response(content=content, media_type="text/html")

    # Materialize the given feature views (or all, when none are listed)
    # between ``start_ts`` and ``end_ts``.
    @app.post("/materialize", dependencies=[Depends(inject_user_details)])
    async def materialize(request: MaterializeRequest) -> None:
        with feast_metrics.track_request_latency("/materialize"):
            for feature_view in request.feature_views or []:
                resource = await _get_feast_object(feature_view, True)
                assert_permissions(
                    resource=resource,
                    actions=[AuthzedAction.WRITE_ONLINE],
                )

            if request.disable_event_timestamp:
                # Event timestamps are ignored: use the widest window,
                # from the epoch up to now.
                now = datetime.now()
                start_date = datetime(1970, 1, 1)
                end_date = now
            else:
                if not request.start_ts or not request.end_ts:
                    raise ValueError(
                        "start_ts and end_ts are required when disable_event_timestamp is False"
                    )
                start_date = utils.make_tzaware(parser.parse(request.start_ts))
                end_date = utils.make_tzaware(parser.parse(request.end_ts))

            await run_in_threadpool(
                store.materialize,
                start_date,
                end_date,
                request.feature_views,
                disable_event_timestamp=request.disable_event_timestamp,
                full_feature_names=request.full_feature_names,
            )

    # Incrementally materialize the given feature views up to ``end_ts``.
    @app.post("/materialize-incremental", dependencies=[Depends(inject_user_details)])
    async def materialize_incremental(request: MaterializeIncrementalRequest) -> None:
        with feast_metrics.track_request_latency("/materialize-incremental"):
            for feature_view in request.feature_views or []:
                resource = await _get_feast_object(feature_view, True)
                assert_permissions(
                    resource=resource,
                    actions=[AuthzedAction.WRITE_ONLINE],
                )
            await run_in_threadpool(
                store.materialize_incremental,
                utils.make_tzaware(parser.parse(request.end_ts)),
                request.feature_views,
                full_feature_names=request.full_feature_names,
            )

    @app.exception_handler(Exception)
    async def rest_exception_handler(request: Request, exc: Exception):
        """Map unhandled exceptions to JSON responses.

        ``FeastError`` instances supply their own status code and detail; any
        other exception becomes a 500 whose body is ``str(exc)``.
        """
        # Print the original exception on the server side
        logger.exception(traceback.format_exc())

        if isinstance(exc, FeastError):
            return JSONResponse(
                status_code=exc.http_status_code(),
                content=exc.to_error_detail(),
            )
        else:
            return JSONResponse(
                status_code=500,
                content=str(exc),
            )

    # Chat WebSocket connection manager
    class ConnectionManager:
        """Tracks the currently accepted chat WebSocket connections."""

        def __init__(self):
            self.active_connections: List[WebSocket] = []

        async def connect(self, websocket: WebSocket):
            """Accept the handshake and register the connection."""
            await websocket.accept()
            self.active_connections.append(websocket)

        def disconnect(self, websocket: WebSocket):
            """Unregister the connection; a no-op if it is not registered."""
            if websocket in self.active_connections:
                self.active_connections.remove(websocket)

        async def send_message(self, message: str, websocket: WebSocket):
            """Send one text frame to ``websocket``."""
            await websocket.send_text(message)

    manager = ConnectionManager()

    MAX_WS_CONNECTIONS = 5
    MAX_MESSAGE_SIZE = 4096
    MAX_MESSAGES_PER_MINUTE = 60
    WS_READ_TIMEOUT_SEC = 60

    # Chat WebSocket: enforces a connection cap, a per-message size limit, a
    # per-connection rate limit and a read timeout.
    @app.websocket("/ws/chat")
    async def websocket_endpoint(websocket: WebSocket):
        if len(manager.active_connections) >= MAX_WS_CONNECTIONS:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            return

        await manager.connect(websocket)
        message_timestamps: List[float] = []
        try:
            while True:
                try:
                    message = await asyncio.wait_for(
                        websocket.receive_text(), timeout=WS_READ_TIMEOUT_SEC
                    )
                except asyncio.TimeoutError:
                    await websocket.close(code=status.WS_1001_GOING_AWAY)
                    return

                if len(message) > MAX_MESSAGE_SIZE:
                    await websocket.close(code=status.WS_1009_MESSAGE_TOO_BIG)
                    return

                # Sliding one-minute window of accepted message timestamps.
                now = time.time()
                cutoff = now - 60
                message_timestamps = [ts for ts in message_timestamps if ts >= cutoff]
                if len(message_timestamps) >= MAX_MESSAGES_PER_MINUTE:
                    await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
                    return
                message_timestamps.append(now)

                # Process the received message (currently unused but kept for future implementation)
                # For now, just return dummy text
                response = f"You sent: '{message}'. This is a dummy response from the Feast feature server."

                # Stream the response word by word
                words = response.split()
                for word in words:
                    await manager.send_message(word + " ", websocket)
                    await asyncio.sleep(0.1)  # Add a small delay between words
        except WebSocketDisconnect:
            # Client went away; clean-up happens in ``finally``.
            pass
        finally:
            # Always release the slot. Without this, sockets closed by the
            # server (timeout, oversize, rate limit) or torn down by an
            # unexpected error stayed registered and permanently consumed one
            # of the MAX_WS_CONNECTIONS slots.
            manager.disconnect(websocket)

    # Mount static files
    static_dir_ref = importlib_resources.files(__spec__.parent) / "static"  # type: ignore[name-defined, arg-type]
    with importlib_resources.as_file(static_dir_ref) as static_dir:
        app.mount("/static", StaticFiles(directory=static_dir), name="static")

    # Add MCP support if enabled in feature server configuration
    _add_mcp_support_if_enabled(app, store)

    return app


def _add_mcp_support_if_enabled(app, store: "feast.FeatureStore"):
    """Add MCP support to the FastAPI app if enabled in configuration.

    MCP is attempted only when the feature server config has ``type == "mcp"``
    and a truthy ``mcp_enabled``. A failed import of the MCP module and any
    other exception are logged and swallowed, except
    ``McpTransportNotSupportedError``, which is re-raised.
    """
    mcp_transport_not_supported_error = None
    try:
        # Check if MCP is enabled in feature server config
        if (
            store.config.feature_server
            and hasattr(store.config.feature_server, "type")
            and store.config.feature_server.type == "mcp"
            and getattr(store.config.feature_server, "mcp_enabled", False)
        ):
            try:
                from feast.infra.mcp_servers.mcp_server import (
                    McpTransportNotSupportedError,
                    add_mcp_support_to_app,
                )

                # Remember the class so the outer handler can recognise it.
                mcp_transport_not_supported_error = McpTransportNotSupportedError
            except ImportError as e:
                logger.error(f"Error checking/adding MCP support: {e}")
                return

            mcp_server = add_mcp_support_to_app(app, store, store.config.feature_server)

            if mcp_server:
                logger.info("MCP support has been enabled for the Feast feature server")
            else:
                logger.warning("MCP support was requested but could not be enabled")
        else:
            logger.debug("MCP support is not enabled in feature server configuration")
    except Exception as e:
        if mcp_transport_not_supported_error and isinstance(
            e, mcp_transport_not_supported_error
        ):
            raise
        logger.error(f"Error checking/adding MCP support: {e}")
        # Don't fail the entire server if MCP fails to initialize


if sys.platform != "win32":
    import gunicorn.app.base

    class FeastServeApplication(gunicorn.app.base.BaseApplication):
        """Gunicorn application wrapper that serves the Feast FastAPI app."""

        def __init__(
            self, store: "feast.FeatureStore", metrics_enabled: bool = False, **options
        ):
            # ``options`` must contain "registry_ttl_sec"; the remaining keys
            # are applied as Gunicorn settings in ``load_config``.
            self._app = get_app(
                store=store,
                registry_ttl_sec=options["registry_ttl_sec"],
            )
            self._options = options
            self._metrics_enabled = metrics_enabled
            super().__init__()

        def load_config(self):
            """Apply recognised options to the Gunicorn config."""
            for key, value in self._options.items():
                # Skip keys Gunicorn does not know and unset values.
                if key.lower() in self.cfg.settings and value is not None:
                    self.cfg.set(key.lower(), value)

            self.cfg.set("worker_class", "uvicorn_worker.UvicornWorker")
            if self._metrics_enabled:
                self.cfg.set("post_worker_init", _gunicorn_post_worker_init)
                self.cfg.set("child_exit", _gunicorn_child_exit)

        def load(self):
            """Return the ASGI app Gunicorn should serve."""
            return self._app


def _gunicorn_post_worker_init(worker):
    """Start per-worker resource monitoring after Gunicorn forks."""
    feast_metrics.init_worker_monitoring()


def _gunicorn_child_exit(server, worker):
    """Clean up Prometheus metric files for a dead worker."""
    feast_metrics.mark_process_dead(worker.pid)


def start_server(
    store: "feast.FeatureStore",
    host: str,
    port: int,
    no_access_log: bool,
    workers: int,
    worker_connections: int,
    max_requests: int,
    max_requests_jitter: int,
    keep_alive_timeout: int,
    registry_ttl_sec: int,
    tls_key_path: str,
    tls_cert_path: str,
    metrics: bool,
):
    """Initialise metrics and auth, then run the feature server.

    Gunicorn (with Uvicorn workers) is used on every platform except Windows,
    where plain Uvicorn is used and the worker-related arguments are ignored.

    Raises:
        ValueError: If only one of ``tls_key_path`` / ``tls_cert_path`` is set.
    """
    if (tls_key_path and not tls_cert_path) or (not tls_key_path and tls_cert_path):
        raise ValueError(
            "Both key and cert file paths are required to start server in TLS mode."
        )

    # Metrics are active when requested by the caller or enabled in config.
    fs_cfg = getattr(store.config, "feature_server", None)
    metrics_cfg = getattr(fs_cfg, "metrics", None)
    metrics_from_config = getattr(metrics_cfg, "enabled", False)
    metrics_active = metrics or metrics_from_config
    uses_gunicorn = sys.platform != "win32"
    if metrics_active:
        flags = feast_metrics.build_metrics_flags(metrics_cfg)
        # Under Gunicorn, resource monitoring is started per worker by the
        # ``post_worker_init`` hook instead of here.
        feast_metrics.start_metrics_server(
            store,
            metrics_config=flags,
            start_resource_monitoring=not uses_gunicorn,
        )

    logger.debug("start_server called")
    auth_type = str_to_auth_manager_type(store.config.auth_config.type)
    logger.info(f"Auth type: {auth_type}")
    init_security_manager(auth_type=auth_type, fs=store)
    logger.debug("Security manager initialized successfully")
    init_auth_manager(
        auth_type=auth_type,
        server_type=ServerType.REST,
        auth_config=store.config.auth_config,
    )
    logger.debug("Auth manager initialized successfully")

    if uses_gunicorn:
        options = {
            "bind": f"{host}:{port}",
            "accesslog": None if no_access_log else "-",
            "workers": workers,
            "worker_connections": worker_connections,
            "max_requests": max_requests,
            "max_requests_jitter": max_requests_jitter,
            "keepalive": keep_alive_timeout,
            "registry_ttl_sec": registry_ttl_sec,
        }

        # Add SSL options if the paths exist
        if tls_key_path and tls_cert_path:
            options["keyfile"] = tls_key_path
            options["certfile"] = tls_cert_path
        FeastServeApplication(
            store=store, metrics_enabled=metrics_active, **options
        ).run()
    else:
        import uvicorn

        app = get_app(store, registry_ttl_sec)
        if tls_key_path and tls_cert_path:
            uvicorn.run(
                app,
                host=host,
                port=port,
                access_log=(not no_access_log),
                ssl_keyfile=tls_key_path,
                ssl_certfile=tls_cert_path,
            )
        else:
            uvicorn.run(app, host=host, port=port, access_log=(not no_access_log))


class _OfflineBatchKey(NamedTuple):
    # Buffers are grouped by this key so that each flush issues a single
    # ``store.push`` call with uniform arguments.
    push_source_name: str
    allow_registry_cache: bool
    transform_on_write: bool


class OfflineWriteBatcher:
    """
    In-process offline write batcher for /push requests.

    - Buffers DataFrames per (push_source_name, allow_registry_cache, transform_on_write)
    - Flushes when either:
        * total rows in a buffer >= batch_size, or
        * time since last flush >= batch_interval_seconds
    - Time-based flushes run in a dedicated background thread. Size-based
      flushes run synchronously in the thread that calls ``enqueue``, so
      callers on an event loop should invoke ``enqueue`` from a worker thread.
    """

    def __init__(self, store: "feast.FeatureStore", cfg: Any):
        """Create the batcher and start its background flusher thread.

        Args:
            store: Feature store whose ``push`` is used for offline writes.
            cfg: Object exposing ``batch_size`` and ``batch_interval_seconds``.
        """
        self._store = store
        self._cfg = cfg

        # Buffers and timestamps keyed by batch key
        self._buffers: DefaultDict[_OfflineBatchKey, List[pd.DataFrame]] = defaultdict(
            list
        )
        # First access for a key records "now" as its last-flush time.
        self._last_flush: DefaultDict[_OfflineBatchKey, float] = defaultdict(time.time)
        # Keys with a flush currently in progress; prevents concurrent
        # flushes of the same key.
        self._inflight: Set[_OfflineBatchKey] = set()

        self._lock = threading.Lock()
        self._stop_event = threading.Event()

        # Start background flusher thread
        self._thread = threading.Thread(
            target=self._run, name="offline_write_batcher", daemon=True
        )
        self._thread.start()

        logger.debug(
            "OfflineWriteBatcher initialized: batch_size=%s, batch_interval_seconds=%s",
            getattr(cfg, "batch_size", None),
            getattr(cfg, "batch_interval_seconds", None),
        )

    # ---------- Public API ----------

    def enqueue(
        self,
        push_source_name: str,
        df: pd.DataFrame,
        allow_registry_cache: bool,
        transform_on_write: bool,
    ) -> None:
        """
        Enqueue a dataframe for offline write, grouped by push source + flags.

        Appending is cheap, but when the buffered row count for the key reaches
        ``batch_size`` the buffer is flushed synchronously in the calling
        thread, which performs a blocking offline-store write. Do not call
        this directly on an event loop thread.
        """
        key = _OfflineBatchKey(
            push_source_name=push_source_name,
            allow_registry_cache=allow_registry_cache,
            transform_on_write=transform_on_write,
        )

        with self._lock:
            self._buffers[key].append(df)
            total_rows = sum(len(d) for d in self._buffers[key])
            should_flush = total_rows >= self._cfg.batch_size

        # Flush outside the lock: ``_flush`` acquires it itself.
        if should_flush:
            # Size-based flush
            logger.debug(
                "OfflineWriteBatcher size threshold reached for %s: %s rows",
                key,
                total_rows,
            )
            self._flush(key)

    def flush_all(self) -> None:
        """
        Flush all buffers synchronously. Intended for graceful shutdown.
        """
        with self._lock:
            keys = list(self._buffers.keys())
        for key in keys:
            self._flush(key)

    def shutdown(self, timeout: float = 5.0) -> None:
        """
        Stop the background thread and perform a best-effort flush.

        Args:
            timeout: Maximum number of seconds to wait for the background
                thread to exit before the final flush.
        """
        logger.debug("Shutting down OfflineWriteBatcher")
        self._stop_event.set()
        try:
            self._thread.join(timeout=timeout)
        except Exception:
            logger.exception("Error joining OfflineWriteBatcher thread")

        # Best-effort final flush
        try:
            self.flush_all()
        except Exception:
            logger.exception("Error during final OfflineWriteBatcher flush")

    # ---------- Internal helpers ----------

    def _run(self) -> None:
        """
        Background loop: periodically checks for buffers that should be flushed
        based on time since last flush.
        """
        # Wake up at least once per second-granular interval (minimum 1s).
        interval = max(1, int(getattr(self._cfg, "batch_interval_seconds", 30)))
        logger.debug(
            "OfflineWriteBatcher background loop started with check interval=%s",
            interval,
        )

        # ``wait`` returns True once the stop event is set, ending the loop.
        while not self._stop_event.wait(timeout=interval):
            now = time.time()
            try:
                with self._lock:
                    keys_to_flush: List[_OfflineBatchKey] = []
                    for key, dfs in list(self._buffers.items()):
                        if not dfs:
                            continue
                        last = self._last_flush[
                            key
                        ]  # this will also init the default timestamp
                        age = now - last
                        if age >= self._cfg.batch_interval_seconds:
                            logger.debug(
                                "OfflineWriteBatcher time threshold reached for %s: age=%s",
                                key,
                                age,
                            )
                            keys_to_flush.append(key)
                # Flush outside the lock: ``_flush`` acquires it itself.
                for key in keys_to_flush:
                    self._flush(key)
            except Exception:
                logger.exception("Error in OfflineWriteBatcher background loop")

        logger.debug("OfflineWriteBatcher background loop exiting")

    def _drain_locked(self, key: _OfflineBatchKey) -> Optional[List[pd.DataFrame]]:
        """
        Drain a single buffer; caller must hold self._lock.

        Returns the buffered frames and marks ``key`` as in flight, or returns
        ``None`` when the key is already being flushed or has nothing buffered.
        """
        if key in self._inflight:
            return None

        dfs = self._buffers.get(key)
        if not dfs:
            return None

        self._buffers[key] = []
        self._inflight.add(key)
        return dfs

    def _flush(self, key: _OfflineBatchKey) -> None:
        """
        Flush a single buffer. Extracts data under lock, then does I/O without lock.

        After a successful write the loop repeats while the rows that arrived
        during the write already meet ``batch_size``. On failure the drained
        frames are put back at the front of the buffer and the method returns.
        """
        while True:
            with self._lock:
                dfs = self._drain_locked(key)
            if not dfs:
                return

            # NOTE: offline writes are currently synchronous only, so we call directly
            try:
                # The concat lives inside the guarded region: if it raised
                # outside, the key would stay in ``_inflight`` forever and
                # the drained frames would be lost.
                batch_df = pd.concat(dfs, ignore_index=True)
                logger.debug(
                    "Flushing offline batch for push_source=%s with %s rows",
                    key.push_source_name,
                    len(batch_df),
                )
                self._store.push(
                    push_source_name=key.push_source_name,
                    df=batch_df,
                    allow_registry_cache=key.allow_registry_cache,
                    to=PushMode.OFFLINE,
                    transform_on_write=key.transform_on_write,
                )
            except Exception:
                logger.exception("Error flushing offline batch for %s", key)
                with self._lock:
                    # Re-queue ahead of anything enqueued in the meantime so
                    # the original order is preserved for the next attempt.
                    self._buffers[key] = dfs + self._buffers[key]
                    self._inflight.discard(key)
                return

            with self._lock:
                self._last_flush[key] = time.time()
                self._inflight.discard(key)
                pending_rows = sum(len(d) for d in self._buffers.get(key, []))
                should_flush = pending_rows >= self._cfg.batch_size

            if not should_flush:
                return

            logger.debug(
                "OfflineWriteBatcher size threshold reached for %s: %s rows",
                key,
                pending_rows,
            )