Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
846a539
style: Fix linting split.py
eddiebergman Jan 8, 2024
053b053
typing: Fix mypy errors split.py
eddiebergman Jan 8, 2024
48d9471
typing: data_feature
eddiebergman Jan 8, 2024
7dbc9b6
typing: trace
eddiebergman Jan 8, 2024
2712c71
more linting fixes
LennartPurucker Jan 8, 2024
e3e432e
Merge branch 'fix_linter_lennart' of https://github.com/openml/openml…
LennartPurucker Jan 8, 2024
69f033e
typing: finish up trace
eddiebergman Jan 8, 2024
798cb8e
typing: config.py
eddiebergman Jan 8, 2024
5fbb36a
typing: More fixes on config.py
eddiebergman Jan 8, 2024
c88f8f4
typing: setup.py
eddiebergman Jan 8, 2024
f911c30
finalize runs linting
LennartPurucker Jan 8, 2024
92d9b26
Merge branch 'fix_linter_lennart' of https://github.com/openml/openml…
LennartPurucker Jan 8, 2024
38bcd5e
typing: evaluation.py
eddiebergman Jan 8, 2024
869f9c4
typing: setup
eddiebergman Jan 8, 2024
abc6117
ruff fixes across different files and mypy fixes for run files
LennartPurucker Jan 8, 2024
54aca64
Merge branch 'fix_linter_lennart' of https://github.com/openml/openml…
LennartPurucker Jan 8, 2024
f6c2ae5
typing: _api_calls
eddiebergman Jan 8, 2024
960afa1
adjust setup files' linting and minor ruff changes
LennartPurucker Jan 8, 2024
bea95cc
Merge branch 'fix_linter_lennart' of https://github.com/openml/openml…
LennartPurucker Jan 8, 2024
5ea4287
typing: utils
eddiebergman Jan 8, 2024
cffd7ed
late night push
LennartPurucker Jan 8, 2024
6d3ae4a
Merge branch 'fix_linter_lennart' of https://github.com/openml/openml…
LennartPurucker Jan 8, 2024
bef753e
typing: utils.py
eddiebergman Jan 8, 2024
1df08b5
typing: tip tap tippity
eddiebergman Jan 9, 2024
d4f79f8
typing: mypy 78, ruff ~200
eddiebergman Jan 9, 2024
cecc746
refactor output format name and minor linting stuff
LennartPurucker Jan 9, 2024
3804220
other: midway merge
eddiebergman Jan 9, 2024
57db7f0
merge
eddiebergman Jan 9, 2024
c9c96b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2024
bb0cdd1
typing: I'm running out of good messages
eddiebergman Jan 9, 2024
e38fdd1
Merge branch 'fix_linter_lennart' of github.com:openml/openml-python …
eddiebergman Jan 9, 2024
dcc60f5
typing: datasets
eddiebergman Jan 9, 2024
a19bc26
linting for flows and some ruff changes
LennartPurucker Jan 9, 2024
93b83eb
Merge branch 'fix_linter_lennart' of https://github.com/openml/openml…
LennartPurucker Jan 9, 2024
9174f20
no more mypy errors
LennartPurucker Jan 9, 2024
a87109a
ruff runs and setups
LennartPurucker Jan 9, 2024
10a2f5e
typing: Finish off mypy and ruff errors
eddiebergman Jan 9, 2024
66e3c97
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2024
66a3ab1
style: File wide ignores of PLR0913
eddiebergman Jan 9, 2024
290578c
Merge branch 'fix_linter_lennart' of github.com:openml/openml-python …
eddiebergman Jan 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 44 additions & 38 deletions openml/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import warnings
from io import StringIO
from pathlib import Path
from typing import Dict, Union, cast
from typing_extensions import Literal
from typing import Any, cast
from typing_extensions import Literal, TypedDict
from urllib.parse import urlparse

logger = logging.getLogger(__name__)
Expand All @@ -21,7 +21,16 @@
file_handler: logging.handlers.RotatingFileHandler | None = None


def _create_log_handlers(create_file_handler: bool = True) -> None: # noqa: FBT
class _Config(TypedDict):
apikey: str
server: str
cachedir: Path
avoid_duplicate_runs: bool
retry_policy: Literal["human", "robot"]
connection_n_retries: int | None


def _create_log_handlers(create_file_handler: bool = True) -> None: # noqa: FBT001, FBT002
"""Creates but does not attach the log handlers."""
global console_handler, file_handler # noqa: PLW0603
if console_handler is not None or file_handler is not None:
Expand Down Expand Up @@ -91,22 +100,22 @@ def set_file_log_level(file_output_level: int) -> None:

# Default values (see also https://github.com/openml/OpenML/wiki/Client-API-Standards)
_user_path = Path("~").expanduser().absolute()
_defaults = {
_defaults: _Config = {
"apikey": "",
"server": "https://www.openml.org/api/v1/xml",
"cachedir": (
os.environ.get("XDG_CACHE_HOME", _user_path / ".cache" / "openml")
Path(os.environ.get("XDG_CACHE_HOME", _user_path / ".cache" / "openml"))
if platform.system() == "Linux"
else _user_path / ".openml"
),
"avoid_duplicate_runs": "True",
"avoid_duplicate_runs": True,
"retry_policy": "human",
"connection_n_retries": "5",
"connection_n_retries": 5,
}

# Default values are actually added here in the _setup() function which is
# called at the end of this module
server = str(_defaults["server"]) # so mypy knows it is a string
server = _defaults["server"]


def get_server_base_url() -> str:
Expand All @@ -124,10 +133,10 @@ def get_server_base_url() -> str:
apikey: str = _defaults["apikey"]
# The current cache directory (without the server name)
_root_cache_directory = Path(_defaults["cachedir"])
avoid_duplicate_runs: bool = _defaults["avoid_duplicate_runs"] == "True"
avoid_duplicate_runs = _defaults["avoid_duplicate_runs"]

retry_policy = _defaults["retry_policy"]
connection_n_retries = int(_defaults["connection_n_retries"])
connection_n_retries = _defaults["connection_n_retries"]


def set_retry_policy(value: Literal["human", "robot"], n_retries: int | None = None) -> None:
Expand Down Expand Up @@ -216,7 +225,7 @@ def determine_config_file_path() -> Path:
return config_dir / "config"


def _setup(config: dict[str, str | int | bool] | None = None) -> None:
def _setup(config: _Config | None = None) -> None:
"""Setup openml package. Called on first import.

Reads the config file and sets up apikey, server, cache appropriately.
Expand Down Expand Up @@ -244,17 +253,13 @@ def _setup(config: dict[str, str | int | bool] | None = None) -> None:
cache_exists = True

if config is None:
config = cast(Dict[str, Union[str, int, bool]], _parse_config(config_file))
config = cast(Dict[str, Union[str, int, bool]], config)

avoid_duplicate_runs = bool(config.get("avoid_duplicate_runs"))

apikey = str(config["apikey"])
server = str(config["server"])
short_cache_dir = Path(config["cachedir"])
config = _parse_config(config_file)

tmp_n_retries = config["connection_n_retries"]
n_retries = int(tmp_n_retries) if tmp_n_retries is not None else None
avoid_duplicate_runs = config.get("avoid_duplicate_runs", False)
apikey = config["apikey"]
server = config["server"]
short_cache_dir = config["cachedir"]
n_retries = config["connection_n_retries"]

set_retry_policy(config["retry_policy"], n_retries)

Expand All @@ -279,31 +284,32 @@ def _setup(config: dict[str, str | int | bool] | None = None) -> None:
)


def set_field_in_config_file(field: str, value: str) -> None:
def set_field_in_config_file(field: str, value: Any) -> None:
"""Overwrites the `field` in the configuration file with the new `value`."""
if field not in _defaults:
raise ValueError(f"Field '{field}' is not valid and must be one of '{_defaults.keys()}'.")

# TODO(eddiebergman): This use of globals has gone too far
globals()[field] = value
config_file = determine_config_file_path()
config = _parse_config(str(config_file))
config = _parse_config(config_file)
with config_file.open("w") as fh:
for f in _defaults:
# We can't blindly set all values based on globals() because when the user
# sets it through config.FIELD it should not be stored to file.
# There doesn't seem to be a way to avoid writing defaults to file with configparser,
# because it is impossible to distinguish from an explicitly set value that matches
# the default value, to one that was set to its default because it was omitted.
value = config.get("FAKE_SECTION", f)
value = config.get("FAKE_SECTION", f) # type: ignore
if f == field:
value = globals()[f]
fh.write(f"{f} = {value}\n")


def _parse_config(config_file: str | Path) -> dict[str, str]:
def _parse_config(config_file: str | Path) -> _Config:
"""Parse the config file, set up defaults."""
config_file = Path(config_file)
config = configparser.RawConfigParser(defaults=_defaults)
config = configparser.RawConfigParser(defaults=_defaults) # type: ignore

# The ConfigParser requires a [SECTION_HEADER], which we do not expect in our config file.
# Cheat the ConfigParser module by adding a fake section header
Expand All @@ -319,18 +325,18 @@ def _parse_config(config_file: str | Path) -> dict[str, str]:
logger.info("Error opening file %s: %s", config_file, e.args[0])
config_file_.seek(0)
config.read_file(config_file_)
return dict(config.items("FAKE_SECTION"))


def get_config_as_dict() -> dict[str, str | int | bool]:
config = {} # type: Dict[str, Union[str, int, bool]]
config["apikey"] = apikey
config["server"] = server
config["cachedir"] = str(_root_cache_directory)
config["avoid_duplicate_runs"] = avoid_duplicate_runs
config["connection_n_retries"] = connection_n_retries
config["retry_policy"] = retry_policy
return config
return dict(config.items("FAKE_SECTION")) # type: ignore


def get_config_as_dict() -> _Config:
return {
"apikey": apikey,
"server": server,
"cachedir": _root_cache_directory,
"avoid_duplicate_runs": avoid_duplicate_runs,
"connection_n_retries": connection_n_retries,
"retry_policy": retry_policy,
}


# NOTE: For backwards compatibility, we keep the `str`
Expand Down
57 changes: 30 additions & 27 deletions openml/runs/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import IO, Any, Iterator
from typing_extensions import Self

import arff
Expand Down Expand Up @@ -63,7 +63,7 @@ class OpenMLTraceIteration:
evaluation: float
selected: bool

setup_string: str | None = None
setup_string: dict[str, str] | None = None
parameters: dict[str, str | int | float] | None = None

def __post_init__(self) -> None:
Expand All @@ -86,22 +86,15 @@ def __post_init__(self) -> None:

def get_parameters(self) -> dict[str, Any]:
"""Get the parameters of this trace iteration."""
result = {}
# parameters have prefix 'parameter_'

if self.setup_string:
for param in self.setup_string:
key = param[len(PREFIX) :]
# TODO(eddiebergman): I have no idea how this is working
# or if it even does.
# Remove the type ignore below
value = self.setup_string[param] # type: ignore
result[key] = json.loads(value)
else:
assert self.parameters is not None
for param, value in self.parameters.items():
result[param[len(PREFIX) :]] = value
return result
return {
param[len(PREFIX) :]: json.loads(value)
for param, value in self.setup_string.items()
}

assert self.parameters is not None
return {param[len(PREFIX) :]: value for param, value in self.parameters.items()}


class OpenMLRunTrace:
Expand Down Expand Up @@ -280,7 +273,7 @@ def trace_to_arff(self) -> dict[str, Any]:
],
)

arff_dict = OrderedDict()
arff_dict: dict[str, Any] = {}
data = []
for trace_iteration in self.trace_iterations.values():
tmp_list = []
Expand Down Expand Up @@ -341,9 +334,9 @@ def _trace_from_arff_struct(
----------
cls : type
The trace object to be created.
attributes : List[Tuple[str, str]]
attributes : list[tuple[str, str]]
Attribute descriptions.
content : List[List[Union[int, float, str]]]
content : list[list[int | float | str]]]
List of instances.
error_message : str
Error message to raise if `setup_string` is in `attributes`.
Expand Down Expand Up @@ -411,7 +404,7 @@ def _trace_from_arff_struct(
return cls(None, trace)

@classmethod
def trace_from_xml(cls, xml: str | Path) -> OpenMLRunTrace:
def trace_from_xml(cls, xml: str | Path | IO) -> OpenMLRunTrace:
"""Generate trace from xml.

Creates a trace file from the xml description.
Expand All @@ -428,6 +421,9 @@ def trace_from_xml(cls, xml: str | Path) -> OpenMLRunTrace:
Object containing the run id and a dict containing the trace
iterations.
"""
if isinstance(xml, Path):
xml = str(xml.absolute())

result_dict = xmltodict.parse(xml, force_list=("oml:trace_iteration",))["oml:trace"]

run_id = result_dict["oml:run_id"]
Expand Down Expand Up @@ -489,20 +485,27 @@ def merge_traces(cls, traces: list[OpenMLRunTrace]) -> OpenMLRunTrace:
If the parameters in the iterations of the traces being merged are not equal.
If a key (repeat, fold, iteration) is encountered twice while merging the traces.
"""
merged_trace = OrderedDict() # type: OrderedDict[Tuple[int, int, int], OpenMLTraceIteration] # E501
merged_trace: dict[tuple[int, int, int], OpenMLTraceIteration] = {}

previous_iteration = None
for trace in traces:
for iteration in trace:
key = (iteration.repeat, iteration.fold, iteration.iteration)

assert iteration.parameters is not None
param_keys = iteration.parameters.keys()

if previous_iteration is not None:
if list(merged_trace[previous_iteration].parameters.keys()) != list(
iteration.parameters.keys(),
):
trace_itr = merged_trace[previous_iteration]

assert trace_itr.parameters is not None
trace_itr_keys = trace_itr.parameters.keys()

if list(param_keys) != list(trace_itr_keys):
raise ValueError(
"Cannot merge traces because the parameters are not equal: "
"{} vs {}".format(
list(merged_trace[previous_iteration].parameters.keys()),
list(trace_itr.parameters.keys()),
list(iteration.parameters.keys()),
),
)
Expand All @@ -517,11 +520,11 @@ def merge_traces(cls, traces: list[OpenMLRunTrace]) -> OpenMLRunTrace:

return cls(None, merged_trace)

def __repr__(self):
def __repr__(self) -> str:
return "[Run id: {}, {} trace iterations]".format(
-1 if self.run_id is None else self.run_id,
len(self.trace_iterations),
)

def __iter__(self):
def __iter__(self) -> Iterator[OpenMLTraceIteration]:
yield from self.trace_iterations.values()
15 changes: 10 additions & 5 deletions openml/setups/setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# License: BSD 3-Clause
from __future__ import annotations

from typing import Any

import openml.config
import openml.flows


class OpenMLSetup:
Expand All @@ -17,19 +20,21 @@ class OpenMLSetup:
The setting of the parameters
"""

def __init__(self, setup_id, flow_id, parameters):
def __init__(self, setup_id: int, flow_id: int, parameters: dict[str, Any]):
if not isinstance(setup_id, int):
raise ValueError("setup id should be int")

if not isinstance(flow_id, int):
raise ValueError("flow id should be int")

if parameters is not None and not isinstance(parameters, dict):
raise ValueError("parameters should be dict")

self.setup_id = setup_id
self.flow_id = flow_id
self.parameters = parameters

def __repr__(self):
def __repr__(self) -> str:
header = "OpenML Setup"
header = "{}\n{}\n".format(header, "=" * len(header))

Expand All @@ -44,7 +49,7 @@ def __repr__(self):
order = ["Setup ID", "Flow ID", "Flow URL", "# of Parameters"]
fields = [(key, fields[key]) for key in order if key in fields]

longest_field_name_length = max(len(name) for name, value in fields)
longest_field_name_length = max(len(name) for name, _ in fields)
field_line_format = f"{{:.<{longest_field_name_length}}}: {{}}"
body = "\n".join(field_line_format.format(name, value) for name, value in fields)
return header + body
Expand Down Expand Up @@ -75,7 +80,7 @@ class OpenMLParameter:
If the parameter was set, the value that it was set to.
"""

def __init__(
def __init__( # noqa: PLR0913
self,
input_id,
flow_id,
Expand Down Expand Up @@ -130,7 +135,7 @@ def __repr__(self):
]
fields = [(key, fields[key]) for key in order if key in fields]

longest_field_name_length = max(len(name) for name, value in fields)
longest_field_name_length = max(len(name) for name, _ in fields)
field_line_format = f"{{:.<{longest_field_name_length}}}: {{}}"
body = "\n".join(field_line_format.format(name, value) for name, value in fields)
return header + body