@@ -262,9 +262,14 @@ def _connection_factory():
262262
263263 return ray .data .read_sql (sql , _connection_factory , ** kwargs )
264264
def from_huggingface(
    self, dataset_name: Any, split: str = "train", **kwargs
) -> Any:
    """Load a HuggingFace dataset and convert it to a Ray Dataset.

    Args:
        dataset_name: Dataset name or path handed to
            ``datasets.load_dataset``. For backward compatibility with the
            previous signature of this method, an already-loaded
            HuggingFace dataset object is also accepted and converted
            directly without calling ``load_dataset``.
        split: Split passed to ``load_dataset``. Ignored when an
            already-loaded dataset object is given.
        **kwargs: Forwarded to ``load_dataset`` when loading by name, or to
            ``ray.data.from_huggingface`` when a dataset object is given
            (the latter matches what this method did before it accepted
            names).

    Returns:
        Whatever ``ray.data.from_huggingface`` returns.

    Raises:
        ImportError: If loading by name is requested but the ``datasets``
            package is not installed.
    """
    import os

    # Legacy call style: the caller already holds a loaded dataset object.
    # Passing it to load_dataset() would fail with a confusing error, so
    # convert it directly as the earlier implementation did.
    if not isinstance(dataset_name, (str, os.PathLike)):
        return ray.data.from_huggingface(dataset_name, **kwargs)

    try:
        from datasets import load_dataset
    except ImportError as exc:
        raise ImportError(
            "The 'datasets' package is required to load a HuggingFace "
            "dataset by name; install it or pass an already-loaded dataset."
        ) from exc

    hf_dataset = load_dataset(dataset_name, split=split, **kwargs)
    return ray.data.from_huggingface(hf_dataset)
268273
269274 def from_pandas (self , df : Any ) -> Any :
270275 """Create dataset from pandas DataFrame using standard Ray."""
@@ -282,12 +287,13 @@ def __init__(
282287 self ,
283288 cluster_name : str ,
284289 namespace : str ,
285- auth_token : str ,
286- auth_server : str ,
290+ auth_token : str = "" ,
291+ auth_server : str = "" ,
287292 skip_tls : bool = False ,
288293 enable_logging : bool = False ,
289294 num_gpus : float = 0 ,
290295 worker_task_options : Optional [Dict [str , Any ]] = None ,
296+ runtime_env : Optional [Dict [str , Any ]] = None ,
291297 ):
292298 """Initialize CodeFlare Ray wrapper with cluster connection parameters."""
293299 self .cluster_name = cluster_name
@@ -298,14 +304,20 @@ def __init__(
298304 self .enable_logging = enable_logging
299305 self .num_gpus = num_gpus
300306 self .worker_task_options = worker_task_options or {}
307+ self .extra_runtime_env = runtime_env or {}
301308 self .cluster = None
302309
303310 # Authenticate and setup Ray connection
304311 self ._authenticate_codeflare ()
305312 self ._setup_ray_connection ()
306313
307314 def _authenticate_codeflare (self ):
308- """Authenticate with CodeFlare SDK."""
315+ """Authenticate with CodeFlare SDK. Skipped for in-cluster pods with no explicit token."""
316+ if not self .auth_token or not self .auth_server :
317+ logger .info (
318+ "No auth_token/auth_server provided; assuming in-cluster auth via service account"
319+ )
320+ return
309321 try :
310322 from codeflare_sdk import TokenAuthentication
311323
@@ -339,6 +351,18 @@ def _setup_ray_connection(self):
339351 "pip" : ["feast" ],
340352 "env_vars" : {"RAY_DISABLE_IMPORT_WARNING" : "1" },
341353 }
354+ if self .extra_runtime_env :
355+ extra_pip = self .extra_runtime_env .get ("pip" , [])
356+ if extra_pip :
357+ runtime_env ["pip" ] = list (
358+ dict .fromkeys (runtime_env ["pip" ] + extra_pip )
359+ )
360+ extra_env_vars = self .extra_runtime_env .get ("env_vars" , {})
361+ if extra_env_vars :
362+ runtime_env ["env_vars" ].update (extra_env_vars )
363+ for k , v in self .extra_runtime_env .items ():
364+ if k not in ("pip" , "env_vars" ):
365+ runtime_env [k ] = v
342366
343367 ray .shutdown ()
344368
@@ -527,23 +551,28 @@ def _factory():
527551 remote_fn = _remote .options (** opts ) if opts else _remote
528552 return RemoteDatasetProxy (remote_fn .remote (sql , connection_url , kwargs ))
529553
def from_huggingface(
    self, dataset_name: Any, split: str = "train", **kwargs
) -> Any:
    """Build a Ray Dataset from a HuggingFace dataset inside a @ray.remote task.

    When ``dataset_name`` is a name or path, ``datasets.load_dataset`` is
    called inside the remote task, so only the name, split and keyword
    arguments are passed to it rather than a loaded dataset object.

    Args:
        dataset_name: Dataset name or path handed to
            ``datasets.load_dataset`` inside the remote task. For backward
            compatibility with the previous signature, an already-loaded
            HuggingFace dataset object is also accepted; it is then passed
            to the remote task as an argument, as the earlier
            implementation did.
        split: Split passed to ``load_dataset``. Ignored when an
            already-loaded dataset object is given.
        **kwargs: Forwarded to ``load_dataset`` when loading by name, or to
            ``ray.data.from_huggingface`` when a dataset object is given.

    Returns:
        A ``RemoteDatasetProxy`` wrapping the object reference returned by
        the remote task.
    """
    import os

    from feast.infra.ray_shared_utils import RemoteDatasetProxy

    # Decide on the caller side which call style is in use, so the remote
    # task only imports `datasets` when it actually has to load by name.
    preloaded = not isinstance(dataset_name, (str, os.PathLike))
    if preloaded:
        load_kwargs, convert_kwargs = {}, kwargs
    else:
        load_kwargs, convert_kwargs = kwargs, {}

    @ray.remote
    def _remote(source, ds_split, load_kw, convert_kw, already_loaded):
        import ray

        if not already_loaded:
            try:
                from datasets import load_dataset
            except ImportError as exc:
                raise ImportError(
                    "The 'datasets' package is required on the Ray workers "
                    "to load a HuggingFace dataset by name; add it to the "
                    "worker environment."
                ) from exc

            source = load_dataset(source, split=ds_split, **load_kw)
        return ray.data.from_huggingface(source, **convert_kw)

    opts = self._get_task_options()
    remote_fn = _remote.options(**opts) if opts else _remote
    return RemoteDatasetProxy(
        remote_fn.remote(
            dataset_name, split, load_kwargs, convert_kwargs, preloaded
        )
    )
547576
548577 def from_pandas (self , df : Any ) -> Any :
549578 """Create dataset from pandas DataFrame - dispatched via @ray.remote."""
@@ -764,12 +793,13 @@ def _initialize_kuberay(config: Any, enable_logging: bool = False) -> None:
764793 _ray_wrapper = CodeFlareRayWrapper (
765794 cluster_name = kuberay_config ["cluster_name" ],
766795 namespace = kuberay_config ["namespace" ],
767- auth_token = kuberay_config [ "auth_token" ] ,
768- auth_server = kuberay_config [ "auth_server" ] ,
796+ auth_token = kuberay_config . get ( "auth_token" , "" ) ,
797+ auth_server = kuberay_config . get ( "auth_server" , "" ) ,
769798 skip_tls = kuberay_config .get ("skip_tls" , False ),
770799 enable_logging = enable_logging ,
771800 num_gpus = getattr (config , "num_gpus" , 0 ) or 0 ,
772801 worker_task_options = getattr (config , "worker_task_options" , None ),
802+ runtime_env = kuberay_config .get ("runtime_env" ),
773803 )
774804
775805 logger .info ("KubeRay cluster connection established via CodeFlare SDK" )
0 commit comments