Skip to content

Commit 7f592a4

Browse files
committed
fix: Fixes for ray source
Signed-off-by: ntkathole <nikhilkathole2683@gmail.com>
1 parent 753dee5 commit 7f592a4

22 files changed

Lines changed: 9624 additions & 8663 deletions

pixi.lock

Lines changed: 540 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ rag = [
120120
]
121121
ray = [
122122
'ray>=2.47.0; python_version == "3.10"',
123-
'codeflare-sdk>=0.31.1; python_version > "3.10"'
123+
'codeflare-sdk>=0.31.1; python_version > "3.10"',
124+
"datasets>=3.6.0",
124125
]
125126
redis = [
126127
"redis>=4.2.2,<8",

sdk/python/feast/dbt/parser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def parse(self) -> None:
107107
try:
108108
from dbt_artifacts_parser.parser import parse_manifest
109109

110+
assert self._raw_manifest is not None
110111
self._parsed_manifest = parse_manifest(manifest=self._raw_manifest)
111112
except ImportError:
112113
raise ImportError(

sdk/python/feast/infra/compute_engines/ray/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import logging
6-
from typing import TYPE_CHECKING, Callable, Dict, Union
6+
from typing import TYPE_CHECKING, Any, Callable, Dict, Union
77

88
import numpy as np
99
import pandas as pd
@@ -162,7 +162,10 @@ def _write_batch(batch: pa.Table) -> pa.Table:
162162
# Ray's map_batches requires a positive integer or "default" for batch_size;
163163
# None is not accepted. When no explicit batch size is configured, omit the
164164
# argument entirely so Ray uses its own default partitioning heuristic.
165-
map_batches_kwargs = {"batch_format": "pyarrow", "zero_copy_batch": True}
165+
map_batches_kwargs: dict[str, Any] = {
166+
"batch_format": "pyarrow",
167+
"zero_copy_batch": True,
168+
}
166169
if batch_size is not None:
167170
map_batches_kwargs["batch_size"] = batch_size
168171

sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ def windowed_temporal_join(
872872
)
873873
combined_ds = entity_windowed.union(feature_windowed)
874874
result_ds = combined_ds.map_batches(
875-
self._apply_windowed_point_in_time_logic,
875+
self._apply_windowed_point_in_time_logic, # type: ignore[arg-type]
876876
batch_format="pandas",
877877
fn_kwargs={
878878
"timestamp_field": timestamp_field,

sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray_offline_store_reader.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,14 @@ def load_ray_dataset_from_source(source: Any) -> Any:
8585
return ray_wrapper.read_webdataset(path, **opts)
8686

8787
if reader_type == "huggingface":
88-
from datasets import load_dataset
89-
9088
dataset_name = opts.get("dataset_name") or path
9189
split = opts.get("split", "train")
92-
# trust_remote_code was removed in datasets>=3.0; skip silently if present.
9390
extra = {
9491
k: v
9592
for k, v in opts.items()
9693
if k not in ("dataset_name", "split", "trust_remote_code")
9794
}
98-
hf_dataset = load_dataset(dataset_name, split=split, **extra)
99-
return ray_wrapper.from_huggingface(hf_dataset)
95+
return ray_wrapper.from_huggingface(dataset_name, split=split, **extra)
10096

10197
if reader_type == "mongo":
10298
return ray_wrapper.read_mongo(

sdk/python/feast/infra/ray_initializer.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,14 @@ def _connection_factory():
262262

263263
return ray.data.read_sql(sql, _connection_factory, **kwargs)
264264

265-
def from_huggingface(self, dataset: Any, **kwargs) -> Any:
266-
"""Convert a HuggingFace datasets.Dataset to a Ray Dataset."""
267-
return ray.data.from_huggingface(dataset, **kwargs)
265+
def from_huggingface(
266+
self, dataset_name: str, split: str = "train", **kwargs
267+
) -> Any:
268+
"""Load a HuggingFace dataset and convert to a Ray Dataset."""
269+
from datasets import load_dataset
270+
271+
hf_dataset = load_dataset(dataset_name, split=split, **kwargs)
272+
return ray.data.from_huggingface(hf_dataset)
268273

269274
def from_pandas(self, df: Any) -> Any:
270275
"""Create dataset from pandas DataFrame using standard Ray."""
@@ -282,12 +287,13 @@ def __init__(
282287
self,
283288
cluster_name: str,
284289
namespace: str,
285-
auth_token: str,
286-
auth_server: str,
290+
auth_token: str = "",
291+
auth_server: str = "",
287292
skip_tls: bool = False,
288293
enable_logging: bool = False,
289294
num_gpus: float = 0,
290295
worker_task_options: Optional[Dict[str, Any]] = None,
296+
runtime_env: Optional[Dict[str, Any]] = None,
291297
):
292298
"""Initialize CodeFlare Ray wrapper with cluster connection parameters."""
293299
self.cluster_name = cluster_name
@@ -298,14 +304,20 @@ def __init__(
298304
self.enable_logging = enable_logging
299305
self.num_gpus = num_gpus
300306
self.worker_task_options = worker_task_options or {}
307+
self.extra_runtime_env = runtime_env or {}
301308
self.cluster = None
302309

303310
# Authenticate and setup Ray connection
304311
self._authenticate_codeflare()
305312
self._setup_ray_connection()
306313

307314
def _authenticate_codeflare(self):
308-
"""Authenticate with CodeFlare SDK."""
315+
"""Authenticate with CodeFlare SDK. Skipped for in-cluster pods with no explicit token."""
316+
if not self.auth_token or not self.auth_server:
317+
logger.info(
318+
"No auth_token/auth_server provided; assuming in-cluster auth via service account"
319+
)
320+
return
309321
try:
310322
from codeflare_sdk import TokenAuthentication
311323

@@ -339,6 +351,18 @@ def _setup_ray_connection(self):
339351
"pip": ["feast"],
340352
"env_vars": {"RAY_DISABLE_IMPORT_WARNING": "1"},
341353
}
354+
if self.extra_runtime_env:
355+
extra_pip = self.extra_runtime_env.get("pip", [])
356+
if extra_pip:
357+
runtime_env["pip"] = list(
358+
dict.fromkeys(runtime_env["pip"] + extra_pip)
359+
)
360+
extra_env_vars = self.extra_runtime_env.get("env_vars", {})
361+
if extra_env_vars:
362+
runtime_env["env_vars"].update(extra_env_vars)
363+
for k, v in self.extra_runtime_env.items():
364+
if k not in ("pip", "env_vars"):
365+
runtime_env[k] = v
342366

343367
ray.shutdown()
344368

@@ -527,23 +551,28 @@ def _factory():
527551
remote_fn = _remote.options(**opts) if opts else _remote
528552
return RemoteDatasetProxy(remote_fn.remote(sql, connection_url, kwargs))
529553

530-
def from_huggingface(self, dataset: Any, **kwargs) -> Any:
531-
"""Convert a HuggingFace dataset - dispatched via @ray.remote to cluster workers.
554+
def from_huggingface(
555+
self, dataset_name: str, split: str = "train", **kwargs
556+
) -> Any:
557+
"""Load a HuggingFace dataset on a remote Ray worker and return a Ray Dataset.
532558
533-
Serialises the HuggingFace dataset and runs ray.data.from_huggingface()
534-
on the cluster so the Ray Client driver is not involved in Ray Data ops.
559+
The dataset is loaded directly on the worker to avoid serializing
560+
memory-mapped HF Dataset objects across the network and to keep
561+
driver memory usage near zero.
535562
"""
536563
from feast.infra.ray_shared_utils import RemoteDatasetProxy
537564

538565
@ray.remote
539-
def _remote(hf_dataset, read_kwargs):
566+
def _remote(ds_name, ds_split, extra_kwargs):
540567
import ray
568+
from datasets import load_dataset
541569

542-
return ray.data.from_huggingface(hf_dataset, **read_kwargs)
570+
hf_dataset = load_dataset(ds_name, split=ds_split, **extra_kwargs)
571+
return ray.data.from_huggingface(hf_dataset)
543572

544573
opts = self._get_task_options()
545574
remote_fn = _remote.options(**opts) if opts else _remote
546-
return RemoteDatasetProxy(remote_fn.remote(dataset, kwargs))
575+
return RemoteDatasetProxy(remote_fn.remote(dataset_name, split, kwargs))
547576

548577
def from_pandas(self, df: Any) -> Any:
549578
"""Create dataset from pandas DataFrame - dispatched via @ray.remote."""
@@ -764,12 +793,13 @@ def _initialize_kuberay(config: Any, enable_logging: bool = False) -> None:
764793
_ray_wrapper = CodeFlareRayWrapper(
765794
cluster_name=kuberay_config["cluster_name"],
766795
namespace=kuberay_config["namespace"],
767-
auth_token=kuberay_config["auth_token"],
768-
auth_server=kuberay_config["auth_server"],
796+
auth_token=kuberay_config.get("auth_token", ""),
797+
auth_server=kuberay_config.get("auth_server", ""),
769798
skip_tls=kuberay_config.get("skip_tls", False),
770799
enable_logging=enable_logging,
771800
num_gpus=getattr(config, "num_gpus", 0) or 0,
772801
worker_task_options=getattr(config, "worker_task_options", None),
802+
runtime_env=kuberay_config.get("runtime_env"),
773803
)
774804

775805
logger.info("KubeRay cluster connection established via CodeFlare SDK")

0 commit comments

Comments
 (0)