Skip to content

Commit 99a62f6

Browse files
authored
Predictions (#1128)
* Add an easy way to retrieve run predictions
* Log the addition of ``predictions`` (#1103)
1 parent 493511a commit 99a62f6

3 files changed

Lines changed: 20 additions & 1 deletion

File tree

doc/progress.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Changelog
1111
* FIX#1030: ``pre-commit`` hooks should no longer issue a warning.
1212
* FIX#1110: Make arguments to ``create_study`` and ``create_suite`` that are defined as optional by the OpenML XSD actually optional.
1313
* MAIN#1088: Do CI for Windows on Github Actions instead of Appveyor.
14-
14+
* ADD#1103: Add a ``predictions`` property to ``OpenMLRun`` for easy access to prediction data.
1515

1616

1717
0.12.2

openml/runs/run.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import arff
1010
import numpy as np
11+
import pandas as pd
1112

1213
import openml
1314
import openml._api_calls
@@ -116,6 +117,23 @@ def __init__(
116117
self.predictions_url = predictions_url
117118
self.description_text = description_text
118119
self.run_details = run_details
120+
self._predictions = None
121+
122+
@property
123+
def predictions(self) -> pd.DataFrame:
124+
""" Return a DataFrame with predictions for this run """
125+
if self._predictions is None:
126+
if self.data_content:
127+
arff_dict = self._generate_arff_dict()
128+
elif self.predictions_url:
129+
arff_text = openml._api_calls._download_text_file(self.predictions_url)
130+
arff_dict = arff.loads(arff_text)
131+
else:
132+
raise RuntimeError("Run has no predictions.")
133+
self._predictions = pd.DataFrame(
134+
arff_dict["data"], columns=[name for name, _ in arff_dict["attributes"]]
135+
)
136+
return self._predictions
119137

120138
@property
121139
def id(self) -> Optional[int]:

tests/test_runs/test_run_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed, create
175175
predictions_prime = run_prime._generate_arff_dict()
176176

177177
self._compare_predictions(predictions, predictions_prime)
178+
pd.testing.assert_frame_equal(run.predictions, run_prime.predictions)
178179

179180
def _perform_run(
180181
self,

0 commit comments

Comments
 (0)