Skip to content

Commit 1d707e6

Browse files
authored
Feat/progress (openml#1335)
* Add progress bar to downloading minio files * Do not redownload cached files There is now a way to force a cache clear, so always redownloading is not useful anymore. * Set typed values on dictionary to avoid TypeError from Config * Add regression test for parsing booleans
1 parent b4d038f commit 1d707e6

9 files changed

Lines changed: 125 additions & 15 deletions

File tree

doc/progress.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ Changelog
99
next
1010
~~~~~~
1111

12+
* ADD #1335: Improve MinIO support.
13+
* Add progress bar for downloading MinIO files. Enable it with setting `show_progress` to true on either `openml.config` or the configuration file.
14+
* When using `download_all_files`, files are only downloaded if they do not yet exist in the cache.
1215
* MAINT #1340: Add Numpy 2.0 support. Update tests to work with scikit-learn <= 1.5.
1316
* ADD #1342: Add HTTP header to requests to indicate they are from openml-python.
1417

examples/20_basic/simple_datasets_tutorial.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,15 @@
5050
X, y, categorical_indicator, attribute_names = dataset.get_data(
5151
dataset_format="dataframe", target=dataset.default_target_attribute
5252
)
53+
54+
############################################################################
55+
# Tip: you can get a progress bar for dataset downloads, simply set it in
56+
# the configuration. Either in code or in the configuration file
57+
# (see also the introduction tutorial)
58+
59+
openml.config.show_progress = True
60+
61+
5362
############################################################################
5463
# Visualize the dataset
5564
# =====================

openml/_api_calls.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# License: BSD 3-Clause
22
from __future__ import annotations
33

4+
import contextlib
45
import hashlib
56
import logging
67
import math
@@ -26,6 +27,7 @@
2627
OpenMLServerException,
2728
OpenMLServerNoResult,
2829
)
30+
from .utils import ProgressBar
2931

3032
_HEADERS = {"user-agent": f"openml-python/{__version__}"}
3133

@@ -161,12 +163,12 @@ def _download_minio_file(
161163
proxy_client = ProxyManager(proxy) if proxy else None
162164

163165
client = minio.Minio(endpoint=parsed_url.netloc, secure=False, http_client=proxy_client)
164-
165166
try:
166167
client.fget_object(
167168
bucket_name=bucket,
168169
object_name=object_name,
169170
file_path=str(destination),
171+
progress=ProgressBar() if config.show_progress else None,
170172
request_headers=_HEADERS,
171173
)
172174
if destination.is_file() and destination.suffix == ".zip":
@@ -206,11 +208,12 @@ def _download_minio_bucket(source: str, destination: str | Path) -> None:
206208
if file_object.object_name is None:
207209
raise ValueError("Object name is None.")
208210

209-
_download_minio_file(
210-
source=source.rsplit("/", 1)[0] + "/" + file_object.object_name.rsplit("/", 1)[1],
211-
destination=Path(destination, file_object.object_name.rsplit("/", 1)[1]),
212-
exists_ok=True,
213-
)
211+
with contextlib.suppress(FileExistsError): # Simply use cached version instead
212+
_download_minio_file(
213+
source=source.rsplit("/", 1)[0] + "/" + file_object.object_name.rsplit("/", 1)[1],
214+
destination=Path(destination, file_object.object_name.rsplit("/", 1)[1]),
215+
exists_ok=False,
216+
)
214217

215218

216219
def _download_text_file(

openml/config.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class _Config(TypedDict):
2828
avoid_duplicate_runs: bool
2929
retry_policy: Literal["human", "robot"]
3030
connection_n_retries: int
31+
show_progress: bool
3132

3233

3334
def _create_log_handlers(create_file_handler: bool = True) -> None: # noqa: FBT001, FBT002
@@ -111,6 +112,7 @@ def set_file_log_level(file_output_level: int) -> None:
111112
"avoid_duplicate_runs": True,
112113
"retry_policy": "human",
113114
"connection_n_retries": 5,
115+
"show_progress": False,
114116
}
115117

116118
# Default values are actually added here in the _setup() function which is
@@ -131,6 +133,7 @@ def get_server_base_url() -> str:
131133

132134

133135
apikey: str = _defaults["apikey"]
136+
show_progress: bool = _defaults["show_progress"]
134137
# The current cache directory (without the server name)
135138
_root_cache_directory = Path(_defaults["cachedir"])
136139
avoid_duplicate_runs = _defaults["avoid_duplicate_runs"]
@@ -238,6 +241,7 @@ def _setup(config: _Config | None = None) -> None:
238241
global server # noqa: PLW0603
239242
global _root_cache_directory # noqa: PLW0603
240243
global avoid_duplicate_runs # noqa: PLW0603
244+
global show_progress # noqa: PLW0603
241245

242246
config_file = determine_config_file_path()
243247
config_dir = config_file.parent
@@ -255,6 +259,7 @@ def _setup(config: _Config | None = None) -> None:
255259
avoid_duplicate_runs = config["avoid_duplicate_runs"]
256260
apikey = config["apikey"]
257261
server = config["server"]
262+
show_progress = config["show_progress"]
258263
short_cache_dir = Path(config["cachedir"])
259264
n_retries = int(config["connection_n_retries"])
260265

@@ -328,11 +333,11 @@ def _parse_config(config_file: str | Path) -> _Config:
328333
logger.info("Error opening file %s: %s", config_file, e.args[0])
329334
config_file_.seek(0)
330335
config.read_file(config_file_)
331-
if isinstance(config["FAKE_SECTION"]["avoid_duplicate_runs"], str):
332-
config["FAKE_SECTION"]["avoid_duplicate_runs"] = config["FAKE_SECTION"].getboolean(
333-
"avoid_duplicate_runs"
334-
) # type: ignore
335-
return dict(config.items("FAKE_SECTION")) # type: ignore
336+
configuration = dict(config.items("FAKE_SECTION"))
337+
for boolean_field in ["avoid_duplicate_runs", "show_progress"]:
338+
if isinstance(config["FAKE_SECTION"][boolean_field], str):
339+
configuration[boolean_field] = config["FAKE_SECTION"].getboolean(boolean_field) # type: ignore
340+
return configuration # type: ignore
336341

337342

338343
def get_config_as_dict() -> _Config:
@@ -343,6 +348,7 @@ def get_config_as_dict() -> _Config:
343348
"avoid_duplicate_runs": avoid_duplicate_runs,
344349
"connection_n_retries": connection_n_retries,
345350
"retry_policy": retry_policy,
351+
"show_progress": show_progress,
346352
}
347353

348354

openml/datasets/functions.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,10 +1262,9 @@ def _get_dataset_parquet(
12621262
if old_file_path.is_file():
12631263
old_file_path.rename(output_file_path)
12641264

1265-
# For this release, we want to be able to force a new download even if the
1266-
# parquet file is already present when ``download_all_files`` is set.
1267-
# For now, it would be the only way for the user to fetch the additional
1268-
# files in the bucket (no function exists on an OpenMLDataset to do this).
1265+
# The call below skips files already on disk, so avoids downloading the parquet file twice.
1266+
# To force the old behavior of always downloading everything, use `force_refresh_cache`
1267+
# of `get_dataset`
12691268
if download_all_files:
12701269
openml._api_calls._download_minio_bucket(source=url, destination=cache_directory)
12711270

openml/utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import numpy as np
1313
import pandas as pd
1414
import xmltodict
15+
from minio.helpers import ProgressType
16+
from tqdm import tqdm
1517

1618
import openml
1719
import openml._api_calls
@@ -471,3 +473,39 @@ def _create_lockfiles_dir() -> Path:
471473
with contextlib.suppress(OSError):
472474
path.mkdir(exist_ok=True, parents=True)
473475
return path
476+
477+
478+
class ProgressBar(ProgressType):
479+
"""Progressbar for MinIO function's `progress` parameter."""
480+
481+
def __init__(self) -> None:
482+
self._object_name = ""
483+
self._progress_bar: tqdm | None = None
484+
485+
def set_meta(self, object_name: str, total_length: int) -> None:
486+
"""Initializes the progress bar.
487+
488+
Parameters
489+
----------
490+
object_name: str
491+
Not used.
492+
493+
total_length: int
494+
File size of the object in bytes.
495+
"""
496+
self._object_name = object_name
497+
self._progress_bar = tqdm(total=total_length, unit_scale=True, unit="B")
498+
499+
def update(self, length: int) -> None:
500+
"""Updates the progress bar.
501+
502+
Parameters
503+
----------
504+
length: int
505+
Number of bytes downloaded since last `update` call.
506+
"""
507+
if not self._progress_bar:
508+
raise RuntimeError("Call `set_meta` before calling `update`.")
509+
self._progress_bar.update(length)
510+
if self._progress_bar.total <= self._progress_bar.n:
511+
self._progress_bar.close()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dependencies = [
1919
"numpy>=1.6.2",
2020
"minio",
2121
"pyarrow",
22+
"tqdm", # For MinIO download progress bars
2223
"packaging",
2324
]
2425
requires-python = ">=3.8"

tests/test_openml/test_api_calls.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from __future__ import annotations
22

33
import unittest.mock
4+
from pathlib import Path
5+
from typing import NamedTuple, Iterable, Iterator
6+
from unittest import mock
47

8+
import minio
59
import pytest
610

711
import openml
812
import openml.testing
13+
from openml._api_calls import _download_minio_bucket
914

1015

1116
class TestConfig(openml.testing.TestBase):
@@ -30,3 +35,39 @@ def test_retry_on_database_error(self, Session_class_mock, _):
3035
openml._api_calls._send_request("get", "/abc", {})
3136

3237
assert Session_class_mock.return_value.__enter__.return_value.get.call_count == 20
38+
39+
class FakeObject(NamedTuple):
40+
object_name: str
41+
42+
class FakeMinio:
43+
def __init__(self, objects: Iterable[FakeObject] | None = None):
44+
self._objects = objects or []
45+
46+
def list_objects(self, *args, **kwargs) -> Iterator[FakeObject]:
47+
yield from self._objects
48+
49+
def fget_object(self, object_name: str, file_path: str, *args, **kwargs) -> None:
50+
if object_name in [obj.object_name for obj in self._objects]:
51+
Path(file_path).write_text("foo")
52+
return
53+
raise FileNotFoundError
54+
55+
56+
@mock.patch.object(minio, "Minio")
57+
def test_download_all_files_observes_cache(mock_minio, tmp_path: Path) -> None:
58+
some_prefix, some_filename = "some/prefix", "dataset.arff"
59+
some_object_path = f"{some_prefix}/{some_filename}"
60+
some_url = f"https://not.real.com/bucket/{some_object_path}"
61+
mock_minio.return_value = FakeMinio(
62+
objects=[
63+
FakeObject(some_object_path),
64+
],
65+
)
66+
67+
_download_minio_bucket(source=some_url, destination=tmp_path)
68+
time_created = (tmp_path / "dataset.arff").stat().st_ctime
69+
70+
_download_minio_bucket(source=some_url, destination=tmp_path)
71+
time_modified = (tmp_path / some_filename).stat().st_mtime
72+
73+
assert time_created == time_modified

tests/test_openml/test_config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,13 @@ def test_configuration_file_not_overwritten_on_load():
133133

134134
assert config_file_content == new_file_content
135135
assert "abcd" == read_config["apikey"]
136+
137+
def test_configuration_loads_booleans(tmp_path):
138+
config_file_content = "avoid_duplicate_runs=true\nshow_progress=false"
139+
with (tmp_path/"config").open("w") as config_file:
140+
config_file.write(config_file_content)
141+
read_config = openml.config._parse_config(tmp_path)
142+
143+
# Explicit test to avoid truthy/falsy modes of other types
144+
assert True == read_config["avoid_duplicate_runs"]
145+
assert False == read_config["show_progress"]

0 commit comments

Comments
 (0)