@@ -345,9 +345,10 @@ def _download_data(self) -> None:
         # import required here to avoid circular import.
         from .functions import _get_dataset_arff, _get_dataset_parquet

-        self.data_file = str(_get_dataset_arff(self))
         if self._parquet_url is not None:
             self.parquet_file = str(_get_dataset_parquet(self))
+        if self.parquet_file is None:
+            self.data_file = str(_get_dataset_arff(self))

     def _get_arff(self, format: str) -> dict:  # noqa: A002
         """Read ARFF file and return decoded arff.
@@ -535,18 +536,7 @@ def _cache_compressed_file_from_file(
             feather_attribute_file,
         ) = self._compressed_cache_file_paths(data_file)

-        if data_file.suffix == ".arff":
-            data, categorical, attribute_names = self._parse_data_from_arff(data_file)
-        elif data_file.suffix == ".pq":
-            try:
-                data = pd.read_parquet(data_file)
-            except Exception as e:  # noqa: BLE001
-                raise Exception(f"File: {data_file}") from e
-
-            categorical = [data[c].dtype.name == "category" for c in data.columns]
-            attribute_names = list(data.columns)
-        else:
-            raise ValueError(f"Unknown file type for file '{data_file}'.")
+        attribute_names, categorical, data = self._parse_data_from_file(data_file)

         # Feather format does not work for sparse datasets, so we use pickle for sparse datasets
         if scipy.sparse.issparse(data):
@@ -572,6 +562,24 @@ def _cache_compressed_file_from_file(

         return data, categorical, attribute_names

+    def _parse_data_from_file(self, data_file: Path) -> tuple[list[str], list[bool], pd.DataFrame]:
+        if data_file.suffix == ".arff":
+            data, categorical, attribute_names = self._parse_data_from_arff(data_file)
+        elif data_file.suffix == ".pq":
+            attribute_names, categorical, data = self._parse_data_from_pq(data_file)
+        else:
+            raise ValueError(f"Unknown file type for file '{data_file}'.")
+        return attribute_names, categorical, data
+
+    def _parse_data_from_pq(self, data_file: Path) -> tuple[list[str], list[bool], pd.DataFrame]:
+        try:
+            data = pd.read_parquet(data_file)
+        except Exception as e:  # noqa: BLE001
+            raise Exception(f"File: {data_file}") from e
+        categorical = [data[c].dtype.name == "category" for c in data.columns]
+        attribute_names = list(data.columns)
+        return attribute_names, categorical, data
+
     def _load_data(self) -> tuple[pd.DataFrame | scipy.sparse.csr_matrix, list[bool], list[str]]:  # noqa: PLR0912, C901
         """Load data from compressed format or arff. Download data if not present on disk."""
         need_to_create_pickle = self.cache_format == "pickle" and self.data_pickle_file is None
@@ -636,8 +644,10 @@ def _load_data(self) -> tuple[pd.DataFrame | scipy.sparse.csr_matrix, list[bool]
                 "Please manually delete the cache file if you want OpenML-Python "
                 "to attempt to reconstruct it.",
             )
-            assert self.data_file is not None
-            data, categorical, attribute_names = self._parse_data_from_arff(Path(self.data_file))
+            file_to_load = self.data_file if self.parquet_file is None else self.parquet_file
+            assert file_to_load is not None
+            attr, cat, df = self._parse_data_from_file(Path(file_to_load))
+            return df, cat, attr

         data_up_to_date = isinstance(data, pd.DataFrame) or scipy.sparse.issparse(data)
         if self.cache_format == "pickle" and not data_up_to_date:
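
The behavioural core of the commit is the new loading precedence: a parquet file is used whenever one was obtained, and the ARFF file only serves as a fallback. The snippet below is a minimal standalone sketch of that selection logic; the helper name pick_file_to_load is hypothetical and is not part of the openml-python API.

# Hypothetical sketch, not openml-python code: illustrates the parquet-first,
# ARFF-fallback order introduced in _download_data and _load_data above.
from pathlib import Path


def pick_file_to_load(parquet_file: str | None, data_file: str | None) -> Path:
    # Prefer the parquet copy; fall back to the ARFF copy only when no parquet
    # file is available.
    file_to_load = data_file if parquet_file is None else parquet_file
    assert file_to_load is not None, "neither a parquet nor an ARFF file is available"
    return Path(file_to_load)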