Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Update based on Pieter's feedback
  • Loading branch information
mfeurer committed Jun 13, 2023
commit 8e2de470e3fe4e87e31be3495dc23c8c2a1d2b04
34 changes: 17 additions & 17 deletions openml/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _create_log_handlers(create_file_handler=True):

if create_file_handler:
one_mb = 2**20
log_path = os.path.join(_cache_directory, "openml_python.log")
log_path = os.path.join(_root_cache_directory, "openml_python.log")
file_handler = logging.handlers.RotatingFileHandler(
log_path, maxBytes=one_mb, backupCount=1, delay=True
)
Expand Down Expand Up @@ -125,7 +125,7 @@ def get_server_base_url() -> str:

apikey = _defaults["apikey"]
# The root cache directory (without the server name)
_cache_directory = str(_defaults["cachedir"]) # so mypy knows it is a string
_root_cache_directory = str(_defaults["cachedir"]) # so mypy knows it is a string
avoid_duplicate_runs = True if _defaults["avoid_duplicate_runs"] == "True" else False

retry_policy = _defaults["retry_policy"]
Expand Down Expand Up @@ -226,7 +226,7 @@ def _setup(config=None):
"""
global apikey
global server
global _cache_directory
global _root_cache_directory
global avoid_duplicate_runs

config_file = determine_config_file_path()
Expand Down Expand Up @@ -266,15 +266,15 @@ def _get(config, key):

set_retry_policy(_get(config, "retry_policy"), n_retries)

_cache_directory = os.path.expanduser(short_cache_dir)
_root_cache_directory = os.path.expanduser(short_cache_dir)
# create the root cache directory if it does not exist yet
if not os.path.exists(_cache_directory):
if not os.path.exists(_root_cache_directory):
try:
os.makedirs(_cache_directory, exist_ok=True)
os.makedirs(_root_cache_directory, exist_ok=True)
except PermissionError:
openml_logger.warning(
"No permission to create openml cache directory at %s! This can result in "
"OpenML-Python not working properly." % _cache_directory
"OpenML-Python not working properly." % _root_cache_directory
)

if cache_exists:
Expand Down Expand Up @@ -333,7 +333,7 @@ def get_config_as_dict():
config = dict()
config["apikey"] = apikey
config["server"] = server
config["cachedir"] = _cache_directory
config["cachedir"] = _root_cache_directory
config["avoid_duplicate_runs"] = avoid_duplicate_runs
config["connection_n_retries"] = connection_n_retries
config["retry_policy"] = retry_policy
Expand Down Expand Up @@ -362,19 +362,19 @@ def get_cache_directory():
"""
url_suffix = urlparse(server).netloc
reversed_url_suffix = os.sep.join(url_suffix.split(".")[::-1])
_cachedir = os.path.join(_cache_directory, reversed_url_suffix)
_cachedir = os.path.join(_root_cache_directory, reversed_url_suffix)
return _cachedir


def set_root_cache_directory(root_cache_directory):
"""Set module-wide base cache directory.

Sets the base cache directory that defines how the actual cache
directory for the server being used is derived. This is
``root_cache_directory / top-level domain / second-level domain /
hostname``, and by default is set to
``root_cache_directory / org / openml / www`` for the standard
OpenML.org server.
Sets the root cache directory, wherein the cache directories are
created to store content from different OpenML servers. For example,
by default, cached data for the standard OpenML.org server is stored
at ``root_cache_directory / org / openml / www``, and the general
pattern is ``root_cache_directory / top-level domain / second-level
domain / hostname``.

Parameters
----------
Expand All @@ -386,8 +386,8 @@ def set_root_cache_directory(root_cache_directory):
get_cache_directory
"""

global _cache_directory
_cache_directory = root_cache_directory
global _root_cache_directory
_root_cache_directory = root_cache_directory


start_using_configuration_for_example = (
Expand Down
2 changes: 1 addition & 1 deletion openml/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def setUp(self, n_levels: int = 1):
self.production_server = "https://openml.org/api/v1/xml"
openml.config.server = TestBase.test_server
openml.config.avoid_duplicate_runs = False
openml.config._cache_directory = self.workdir
openml.config.set_root_cache_directory(self.workdir)

# Increase the number of retries to avoid spurious server failures
self.retry_policy = openml.config.retry_policy
Expand Down
6 changes: 3 additions & 3 deletions tests/test_datasets/test_dataset_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def test__get_dataset_description(self):
self.assertTrue(os.path.exists(description_xml_path))

def test__getarff_path_dataset_arff(self):
openml.config._cache_directory = self.static_cache_dir
openml.config.set_root_cache_directory(self.static_cache_dir)
description = _get_dataset_description(self.workdir, 2)
arff_path = _get_dataset_arff(description, cache_directory=self.workdir)
self.assertIsInstance(arff_path, str)
Expand Down Expand Up @@ -494,7 +494,7 @@ def test__get_dataset_parquet_not_cached(self):

@mock.patch("openml._api_calls._download_minio_file")
def test__get_dataset_parquet_is_cached(self, patch):
openml.config._cache_directory = self.static_cache_dir
openml.config.set_root_cache_directory(self.static_cache_dir)
patch.side_effect = RuntimeError(
"_download_minio_file should not be called when loading from cache"
)
Expand Down Expand Up @@ -594,7 +594,7 @@ def test_publish_dataset(self):
self.assertIsInstance(dataset.dataset_id, int)

def test__retrieve_class_labels(self):
openml.config._cache_directory = self.static_cache_dir
openml.config.set_root_cache_directory(self.static_cache_dir)
labels = openml.datasets.get_dataset(2, download_data=False).retrieve_class_labels()
self.assertEqual(labels, ["1", "2", "3", "4", "5", "U"])
labels = openml.datasets.get_dataset(2, download_data=False).retrieve_class_labels(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_runs/test_run_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,11 +1569,11 @@ def test_run_on_dataset_with_missing_labels_array(self):
self.assertEqual(len(row), 12)

def test_get_cached_run(self):
openml.config._cache_directory = self.static_cache_dir
openml.config.set_root_cache_directory(self.static_cache_dir)
openml.runs.functions._get_cached_run(1)

def test_get_uncached_run(self):
openml.config._cache_directory = self.static_cache_dir
openml.config.set_root_cache_directory(self.static_cache_dir)
with self.assertRaises(openml.exceptions.OpenMLCacheException):
openml.runs.functions._get_cached_run(10)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_setups/test_setup_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,10 @@ def test_setuplist_offset(self):
self.assertEqual(len(all), size * 2)

def test_get_cached_setup(self):
openml.config._cache_directory = self.static_cache_dir
openml.config.set_root_cache_directory(self.static_cache_dir)
openml.setups.functions._get_cached_setup(1)

def test_get_uncached_setup(self):
openml.config._cache_directory = self.static_cache_dir
openml.config.set_root_cache_directory(self.static_cache_dir)
with self.assertRaises(openml.exceptions.OpenMLCacheException):
openml.setups.functions._get_cached_setup(10)
10 changes: 5 additions & 5 deletions tests/test_tasks/test_task_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ def tearDown(self):
super(TestTask, self).tearDown()

def test__get_cached_tasks(self):
openml.config._cache_directory = self.static_cache_dir
openml.config.set_root_cache_directory(self.static_cache_dir)
tasks = openml.tasks.functions._get_cached_tasks()
self.assertIsInstance(tasks, dict)
self.assertEqual(len(tasks), 3)
self.assertIsInstance(list(tasks.values())[0], OpenMLTask)

def test__get_cached_task(self):
openml.config._cache_directory = self.static_cache_dir
openml.config.set_root_cache_directory(self.static_cache_dir)
task = openml.tasks.functions._get_cached_task(1)
self.assertIsInstance(task, OpenMLTask)

def test__get_cached_task_not_cached(self):
openml.config._cache_directory = self.static_cache_dir
openml.config.set_root_cache_directory(self.static_cache_dir)
self.assertRaisesRegex(
OpenMLCacheException,
"Task file for tid 2 not cached",
Expand Down Expand Up @@ -129,7 +129,7 @@ def test_list_tasks_per_type_paginate(self):
self._check_task(tasks[tid])

def test__get_task(self):
openml.config._cache_directory = self.static_cache_dir
openml.config.set_root_cache_directory(self.static_cache_dir)
openml.tasks.get_task(1882)

@unittest.skip(
Expand Down Expand Up @@ -224,7 +224,7 @@ def assert_and_raise(*args, **kwargs):
self.assertFalse(os.path.exists(os.path.join(os.getcwd(), "tasks", "1", "tasks.xml")))

def test_get_task_with_cache(self):
openml.config._cache_directory = self.static_cache_dir
openml.config.set_root_cache_directory(self.static_cache_dir)
task = openml.tasks.get_task(1)
self.assertIsInstance(task, OpenMLTask)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_tasks/test_task_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_tagging(self):
self.assertEqual(len(task_list), 0)

def test_get_train_and_test_split_indices(self):
openml.config._cache_directory = self.static_cache_dir
openml.config.set_root_cache_directory(self.static_cache_dir)
task = openml.tasks.get_task(1882)
train_indices, test_indices = task.get_train_test_split_indices(0, 0)
self.assertEqual(16, train_indices[0])
Expand Down