Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
545 changes: 540 additions & 5 deletions pixi.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ rag = [
]
ray = [
'ray>=2.47.0; python_version == "3.10"',
'codeflare-sdk>=0.31.1; python_version > "3.10"'
'codeflare-sdk>=0.31.1; python_version > "3.10"',
"datasets>=3.6.0",
]
redis = [
"redis>=4.2.2,<8",
Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/dbt/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def parse(self) -> None:
try:
from dbt_artifacts_parser.parser import parse_manifest

assert self._raw_manifest is not None
self._parsed_manifest = parse_manifest(manifest=self._raw_manifest)
except ImportError:
raise ImportError(
Expand Down
7 changes: 5 additions & 2 deletions sdk/python/feast/infra/compute_engines/ray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

import logging
from typing import TYPE_CHECKING, Callable, Dict, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -162,7 +162,10 @@ def _write_batch(batch: pa.Table) -> pa.Table:
# Ray's map_batches requires a positive integer or "default" for batch_size;
# None is not accepted. When no explicit batch size is configured, omit the
# argument entirely so Ray uses its own default partitioning heuristic.
map_batches_kwargs = {"batch_format": "pyarrow", "zero_copy_batch": True}
map_batches_kwargs: dict[str, Any] = {
"batch_format": "pyarrow",
"zero_copy_batch": True,
}
if batch_size is not None:
map_batches_kwargs["batch_size"] = batch_size

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,7 @@ def windowed_temporal_join(
)
combined_ds = entity_windowed.union(feature_windowed)
result_ds = combined_ds.map_batches(
self._apply_windowed_point_in_time_logic,
self._apply_windowed_point_in_time_logic, # type: ignore[arg-type]
batch_format="pandas",
fn_kwargs={
"timestamp_field": timestamp_field,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,14 @@ def load_ray_dataset_from_source(source: Any) -> Any:
return ray_wrapper.read_webdataset(path, **opts)

if reader_type == "huggingface":
from datasets import load_dataset

dataset_name = opts.get("dataset_name") or path
split = opts.get("split", "train")
# trust_remote_code was removed in datasets>=3.0; skip silently if present.
extra = {
k: v
for k, v in opts.items()
if k not in ("dataset_name", "split", "trust_remote_code")
}
hf_dataset = load_dataset(dataset_name, split=split, **extra)
return ray_wrapper.from_huggingface(hf_dataset)
return ray_wrapper.from_huggingface(dataset_name, split=split, **extra)

if reader_type == "mongo":
return ray_wrapper.read_mongo(
Expand Down
60 changes: 45 additions & 15 deletions sdk/python/feast/infra/ray_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,14 @@ def _connection_factory():

return ray.data.read_sql(sql, _connection_factory, **kwargs)

def from_huggingface(self, dataset: Any, **kwargs) -> Any:
"""Convert a HuggingFace datasets.Dataset to a Ray Dataset."""
return ray.data.from_huggingface(dataset, **kwargs)
def from_huggingface(
self, dataset_name: str, split: str = "train", **kwargs
) -> Any:
"""Load a HuggingFace dataset and convert to a Ray Dataset."""
from datasets import load_dataset

hf_dataset = load_dataset(dataset_name, split=split, **kwargs)
return ray.data.from_huggingface(hf_dataset)

def from_pandas(self, df: Any) -> Any:
"""Create dataset from pandas DataFrame using standard Ray."""
Expand All @@ -282,12 +287,13 @@ def __init__(
self,
cluster_name: str,
namespace: str,
auth_token: str,
auth_server: str,
auth_token: str = "",
auth_server: str = "",
skip_tls: bool = False,
enable_logging: bool = False,
num_gpus: float = 0,
worker_task_options: Optional[Dict[str, Any]] = None,
runtime_env: Optional[Dict[str, Any]] = None,
):
"""Initialize CodeFlare Ray wrapper with cluster connection parameters."""
self.cluster_name = cluster_name
Expand All @@ -298,14 +304,20 @@ def __init__(
self.enable_logging = enable_logging
self.num_gpus = num_gpus
self.worker_task_options = worker_task_options or {}
self.extra_runtime_env = runtime_env or {}
self.cluster = None

# Authenticate and setup Ray connection
self._authenticate_codeflare()
self._setup_ray_connection()

def _authenticate_codeflare(self):
"""Authenticate with CodeFlare SDK."""
"""Authenticate with CodeFlare SDK. Skipped for in-cluster pods with no explicit token."""
if not self.auth_token or not self.auth_server:
logger.info(
"No auth_token/auth_server provided; assuming in-cluster auth via service account"
)
return
try:
from codeflare_sdk import TokenAuthentication

Expand Down Expand Up @@ -339,6 +351,18 @@ def _setup_ray_connection(self):
"pip": ["feast"],
"env_vars": {"RAY_DISABLE_IMPORT_WARNING": "1"},
}
if self.extra_runtime_env:
extra_pip = self.extra_runtime_env.get("pip", [])
if extra_pip:
runtime_env["pip"] = list(
dict.fromkeys(runtime_env["pip"] + extra_pip)
)
extra_env_vars = self.extra_runtime_env.get("env_vars", {})
if extra_env_vars:
runtime_env["env_vars"].update(extra_env_vars)
for k, v in self.extra_runtime_env.items():
if k not in ("pip", "env_vars"):
runtime_env[k] = v

ray.shutdown()

Expand Down Expand Up @@ -527,23 +551,28 @@ def _factory():
remote_fn = _remote.options(**opts) if opts else _remote
return RemoteDatasetProxy(remote_fn.remote(sql, connection_url, kwargs))

def from_huggingface(self, dataset: Any, **kwargs) -> Any:
"""Convert a HuggingFace dataset - dispatched via @ray.remote to cluster workers.
def from_huggingface(
self, dataset_name: str, split: str = "train", **kwargs
) -> Any:
"""Load a HuggingFace dataset on a remote Ray worker and return a Ray Dataset.

Serialises the HuggingFace dataset and runs ray.data.from_huggingface()
on the cluster so the Ray Client driver is not involved in Ray Data ops.
The dataset is loaded directly on the worker to avoid serializing
memory-mapped HF Dataset objects across the network and to keep
driver memory usage near zero.
"""
from feast.infra.ray_shared_utils import RemoteDatasetProxy

@ray.remote
def _remote(hf_dataset, read_kwargs):
def _remote(ds_name, ds_split, extra_kwargs):
import ray
from datasets import load_dataset

return ray.data.from_huggingface(hf_dataset, **read_kwargs)
hf_dataset = load_dataset(ds_name, split=ds_split, **extra_kwargs)
return ray.data.from_huggingface(hf_dataset)

opts = self._get_task_options()
remote_fn = _remote.options(**opts) if opts else _remote
return RemoteDatasetProxy(remote_fn.remote(dataset, kwargs))
return RemoteDatasetProxy(remote_fn.remote(dataset_name, split, kwargs))

def from_pandas(self, df: Any) -> Any:
"""Create dataset from pandas DataFrame - dispatched via @ray.remote."""
Expand Down Expand Up @@ -764,12 +793,13 @@ def _initialize_kuberay(config: Any, enable_logging: bool = False) -> None:
_ray_wrapper = CodeFlareRayWrapper(
cluster_name=kuberay_config["cluster_name"],
namespace=kuberay_config["namespace"],
auth_token=kuberay_config["auth_token"],
auth_server=kuberay_config["auth_server"],
auth_token=kuberay_config.get("auth_token", ""),
auth_server=kuberay_config.get("auth_server", ""),
skip_tls=kuberay_config.get("skip_tls", False),
enable_logging=enable_logging,
num_gpus=getattr(config, "num_gpus", 0) or 0,
worker_task_options=getattr(config, "worker_task_options", None),
runtime_env=kuberay_config.get("runtime_env"),
)

logger.info("KubeRay cluster connection established via CodeFlare SDK")
Expand Down
Loading
Loading