Skip to content

Commit b4d038f

Browse files
authored
Lazy arff (openml#1346)
* Prefer parquet over arff, do not load arff if not needed
* Only download arff if needed
* Test arff file is not set when downloading parquet from prod
1 parent fa7e9db commit b4d038f

3 files changed

Lines changed: 33 additions & 19 deletions

File tree

openml/datasets/dataset.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -345,9 +345,10 @@ def _download_data(self) -> None:
345345
# import required here to avoid circular import.
346346
from .functions import _get_dataset_arff, _get_dataset_parquet
347347

348-
self.data_file = str(_get_dataset_arff(self))
349348
if self._parquet_url is not None:
350349
self.parquet_file = str(_get_dataset_parquet(self))
350+
if self.parquet_file is None:
351+
self.data_file = str(_get_dataset_arff(self))
351352

352353
def _get_arff(self, format: str) -> dict: # noqa: A002
353354
"""Read ARFF file and return decoded arff.
@@ -535,18 +536,7 @@ def _cache_compressed_file_from_file(
535536
feather_attribute_file,
536537
) = self._compressed_cache_file_paths(data_file)
537538

538-
if data_file.suffix == ".arff":
539-
data, categorical, attribute_names = self._parse_data_from_arff(data_file)
540-
elif data_file.suffix == ".pq":
541-
try:
542-
data = pd.read_parquet(data_file)
543-
except Exception as e: # noqa: BLE001
544-
raise Exception(f"File: {data_file}") from e
545-
546-
categorical = [data[c].dtype.name == "category" for c in data.columns]
547-
attribute_names = list(data.columns)
548-
else:
549-
raise ValueError(f"Unknown file type for file '{data_file}'.")
539+
attribute_names, categorical, data = self._parse_data_from_file(data_file)
550540

551541
# Feather format does not work for sparse datasets, so we use pickle for sparse datasets
552542
if scipy.sparse.issparse(data):
@@ -572,6 +562,24 @@ def _cache_compressed_file_from_file(
572562

573563
return data, categorical, attribute_names
574564

565+
def _parse_data_from_file(self, data_file: Path) -> tuple[list[str], list[bool], pd.DataFrame]:
566+
if data_file.suffix == ".arff":
567+
data, categorical, attribute_names = self._parse_data_from_arff(data_file)
568+
elif data_file.suffix == ".pq":
569+
attribute_names, categorical, data = self._parse_data_from_pq(data_file)
570+
else:
571+
raise ValueError(f"Unknown file type for file '{data_file}'.")
572+
return attribute_names, categorical, data
573+
574+
def _parse_data_from_pq(self, data_file: Path) -> tuple[list[str], list[bool], pd.DataFrame]:
575+
try:
576+
data = pd.read_parquet(data_file)
577+
except Exception as e: # noqa: BLE001
578+
raise Exception(f"File: {data_file}") from e
579+
categorical = [data[c].dtype.name == "category" for c in data.columns]
580+
attribute_names = list(data.columns)
581+
return attribute_names, categorical, data
582+
575583
def _load_data(self) -> tuple[pd.DataFrame | scipy.sparse.csr_matrix, list[bool], list[str]]: # noqa: PLR0912, C901
576584
"""Load data from compressed format or arff. Download data if not present on disk."""
577585
need_to_create_pickle = self.cache_format == "pickle" and self.data_pickle_file is None
@@ -636,8 +644,10 @@ def _load_data(self) -> tuple[pd.DataFrame | scipy.sparse.csr_matrix, list[bool]
636644
"Please manually delete the cache file if you want OpenML-Python "
637645
"to attempt to reconstruct it.",
638646
)
639-
assert self.data_file is not None
640-
data, categorical, attribute_names = self._parse_data_from_arff(Path(self.data_file))
647+
file_to_load = self.data_file if self.parquet_file is None else self.parquet_file
648+
assert file_to_load is not None
649+
attr, cat, df = self._parse_data_from_file(Path(file_to_load))
650+
return df, cat, attr
641651

642652
data_up_to_date = isinstance(data, pd.DataFrame) or scipy.sparse.issparse(data)
643653
if self.cache_format == "pickle" and not data_up_to_date:

openml/datasets/functions.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def get_datasets(
450450

451451

452452
@openml.utils.thread_safe_if_oslo_installed
453-
def get_dataset( # noqa: C901, PLR0912
453+
def get_dataset( # noqa: C901, PLR0912, PLR0915
454454
dataset_id: int | str,
455455
download_data: bool | None = None, # Optional for deprecation warning; later again only bool
456456
version: int | None = None,
@@ -589,7 +589,6 @@ def get_dataset( # noqa: C901, PLR0912
589589
if download_qualities:
590590
qualities_file = _get_dataset_qualities_file(did_cache_dir, dataset_id)
591591

592-
arff_file = _get_dataset_arff(description) if download_data else None
593592
if "oml:parquet_url" in description and download_data:
594593
try:
595594
parquet_file = _get_dataset_parquet(
@@ -598,10 +597,14 @@ def get_dataset( # noqa: C901, PLR0912
598597
)
599598
except urllib3.exceptions.MaxRetryError:
600599
parquet_file = None
601-
if parquet_file is None and arff_file:
602-
logger.warning("Failed to download parquet, fallback on ARFF.")
603600
else:
604601
parquet_file = None
602+
603+
arff_file = None
604+
if parquet_file is None and download_data:
605+
logger.warning("Failed to download parquet, fallback on ARFF.")
606+
arff_file = _get_dataset_arff(description)
607+
605608
remove_dataset_cache = False
606609
except OpenMLServerException as e:
607610
# if there was an exception

tests/test_datasets/test_dataset_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1574,6 +1574,7 @@ def test_get_dataset_parquet(self):
15741574
assert dataset._parquet_url is not None
15751575
assert dataset.parquet_file is not None
15761576
assert os.path.isfile(dataset.parquet_file)
1577+
assert dataset.data_file is None # is alias for arff path
15771578

15781579
@pytest.mark.production()
15791580
def test_list_datasets_with_high_size_parameter(self):

0 commit comments

Comments (0)