Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: context manager to temporarily set config values
  • Loading branch information
eddiebergman committed Oct 17, 2024
commit 14d658411de2b2758dab01e60c9b843a40392da6
15 changes: 14 additions & 1 deletion openml/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import platform
import shutil
import warnings
from contextlib import contextmanager
from io import StringIO
from pathlib import Path
from typing import Any, cast
from typing import Any, Iterator, cast
from typing_extensions import Literal, TypedDict
from urllib.parse import urlparse

Expand Down Expand Up @@ -497,6 +498,18 @@ def set_root_cache_directory(root_cache_directory: str | Path) -> None:
stop_using_configuration_for_example = ConfigurationForExamples.stop_using_configuration_for_example


@contextmanager
def set_context(config: dict[str, Any]) -> Iterator[_Config]:
"""A context manager to temporarily override variables in the configuration."""
existing_config = get_config_as_dict()
merged_config = {**existing_config, **config}

_setup(merged_config) # type: ignore
yield merged_config # type: ignore

_setup(existing_config)


__all__ = [
"get_cache_directory",
"set_root_cache_directory",
Expand Down
7 changes: 5 additions & 2 deletions tests/test_openml/test_api_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

import openml
from openml.config import ConfigurationForExamples
import openml.testing
from openml._api_calls import _download_minio_bucket, API_TOKEN_HELP_LINK

Expand Down Expand Up @@ -118,5 +119,7 @@ def test_authentication_endpoints_requiring_api_key_show_relevant_help_link(
endpoint: str,
method: str,
) -> None:
with pytest.raises(openml.exceptions.OpenMLNotAuthorizedError, match=API_TOKEN_HELP_LINK) as e:
openml._api_calls._perform_api_call(call=endpoint, request_method=method, data=None)
# We need to temporarily disable the API key to test the error message
with openml.config.set_context({"apikey": None}):
with pytest.raises(openml.exceptions.OpenMLNotAuthorizedError, match=API_TOKEN_HELP_LINK):
openml._api_calls._perform_api_call(call=endpoint, request_method=method, data=None)