Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
test: Add safe context manager for environ vriable
  • Loading branch information
eddiebergman committed Oct 14, 2024
commit 572afa3dc9aa41e3dd73b888cd87336586732c8f
2 changes: 1 addition & 1 deletion openml/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _resolve_default_cache_dir() -> Path:
return Path("~", ".cache", "openml")

# This is the proper XDG_CACHE_HOME directory, but
# we unfortunatly had a problem where we used XDG_CACHE_HOME/org,
# we unfortunately had a problem where we used XDG_CACHE_HOME/org,
# we check heuristically if this old directory still exists and issue
# a warning if it does. There's too much data to move to do this for the user.

Expand Down
52 changes: 33 additions & 19 deletions tests/test_openml/test_config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# License: BSD 3-Clause
from __future__ import annotations

from contextlib import contextmanager
import os
import tempfile
import unittest.mock
from copy import copy
from typing import Any, Iterator
from pathlib import Path

import pytest
Expand All @@ -13,6 +15,24 @@
import openml.testing


@contextmanager
def safe_environ_patcher(key: str, value: Any) -> Iterator[None]:
"""Context manager to temporarily set an environment variable.

Safe to errors happening in the yielded to function.
"""
_prev = os.environ.get(key)
os.environ[key] = value
try:
yield
except Exception as e:
raise e
finally:
os.environ.pop(key)
if _prev is not None:
os.environ[key] = _prev


class TestConfig(openml.testing.TestBase):
@unittest.mock.patch("openml.config.openml_logger.warning")
@unittest.mock.patch("openml.config._create_log_handlers")
Expand All @@ -32,26 +52,19 @@ def test_non_writable_home(self, log_handler_mock, warnings_mock):
def test_XDG_directories_do_not_exist(self):
with tempfile.TemporaryDirectory(dir=self.workdir) as td:
# Save previous state
_prev = os.environ.get("XDG_CONFIG_HOME")

os.environ["XDG_CONFIG_HOME"] = str(Path(td) / "fake_xdg_cache_home")
path = Path(td) / "fake_xdg_cache_home"
with safe_environ_patcher("XDG_CONFIG_HOME", str(path)):
expected_config_dir = path / "openml"
expected_determined_config_file_path = expected_config_dir / "config"

expected_config_dir = Path(td) / "fake_xdg_cache_home" / "openml"
expected_determined_config_file_path = expected_config_dir / "config"
# Ensure that it correctly determines the path to the config file
determined_config_file_path = openml.config.determine_config_file_path()
assert determined_config_file_path == expected_determined_config_file_path

# Ensure that it correctly determines the path to the config file
determined_config_file_path = openml.config.determine_config_file_path()
assert determined_config_file_path == expected_determined_config_file_path

# Ensure that setup will create the config folder as the configuration
# will be written to that location.
openml.config._setup()
assert expected_config_dir.exists()

# Reset it
os.environ.pop("XDG_CONFIG_HOME")
if _prev is not None:
os.environ["XDG_CONFIG_HOME"] = _prev
# Ensure that setup will create the config folder as the configuration
# will be written to that location.
openml.config._setup()
assert expected_config_dir.exists()

def test_get_config_as_dict(self):
"""Checks if the current configuration is returned accurately as a dict."""
Expand Down Expand Up @@ -164,7 +177,8 @@ def test_configuration_loads_booleans(tmp_path):

def test_openml_cache_dir_env_var(tmp_path: Path) -> None:
expected_path = tmp_path / "test-cache"
with unittest.mock.patch.dict(os.environ, {"OPENML_CACHE_DIR": str(expected_path)}):

with safe_environ_patcher("OPENML_CACHE_DIR", str(expected_path)):
openml.config._setup()
assert openml.config._root_cache_directory == expected_path
assert openml.config.get_cache_directory() == str(expected_path / "org" / "openml" / "www")