Skip to content

Commit aea2832

Browse files
committed
Updating with fixed unit tests from PR #1000
2 parents 90c8de6 + 50ce90e commit aea2832

File tree

13 files changed

+331
-100
lines changed

13 files changed

+331
-100
lines changed

openml/testing.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import shutil
77
import sys
88
import time
9-
from typing import Dict
9+
from typing import Dict, Union, cast
1010
import unittest
1111
import warnings
12+
import pandas as pd
1213

1314
# Currently, importing oslo raises a lot of warning that it will stop working
1415
# under python3.8; remove this once they disappear
@@ -252,6 +253,58 @@ def _check_fold_timing_evaluations(
252253
self.assertLessEqual(evaluation, max_val)
253254

254255

256+
def check_task_existence(
257+
task_type: TaskType, dataset_id: int, target_name: str, **kwargs
258+
) -> Union[int, None]:
259+
"""Checks if any task exists on the test server that matches the given meta-data.
260+
261+
Parameters
262+
----------
263+
task_type : openml.tasks.TaskType
264+
ID of the task type as detailed `here <https://www.openml.org/search?type=task_type>`_.
265+
- Supervised classification: 1
266+
- Supervised regression: 2
267+
- Learning curve: 3
268+
- Supervised data stream classification: 4
269+
- Clustering: 5
270+
- Machine Learning Challenge: 6
271+
- Survival Analysis: 7
272+
- Subgroup Discovery: 8
273+
dataset_id : int
274+
target_name : str
275+
276+
Returns
277+
-------
278+
int, None
279+
"""
280+
return_val = None
281+
tasks = openml.tasks.list_tasks(task_type=task_type, output_format="dataframe")
282+
if len(tasks) == 0:
283+
return None
284+
tasks = cast(pd.DataFrame, tasks).loc[tasks["did"] == dataset_id]
285+
if len(tasks) == 0:
286+
return None
287+
tasks = tasks.loc[tasks["target_feature"] == target_name]
288+
if len(tasks) == 0:
289+
return None
290+
task_match = []
291+
for task_id in tasks["tid"].to_list():
292+
task_match.append(task_id)
293+
task = openml.tasks.get_task(task_id)
294+
for k, v in kwargs.items():
295+
if getattr(task, k) != v:
296+
# even if one of the meta-data keys mismatches, then task_id is not a match
297+
task_match.pop(-1)
298+
break
299+
# if task_id is retained in the task_match list, it passed all meta key-value matches
300+
if len(task_match) == 1:
301+
return_val = task_id
302+
break
303+
if len(task_match) == 0:
304+
return_val = None
305+
return return_val
306+
307+
255308
try:
256309
from sklearn.impute import SimpleImputer
257310
except ImportError:

openml/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from functools import wraps
1010
import collections
1111

12+
import openml
1213
import openml._api_calls
1314
import openml.exceptions
1415
from . import config

tests/test_datasets/test_dataset_functions.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
DATASETS_CACHE_DIR_NAME,
3737
)
3838
from openml.datasets import fork_dataset, edit_dataset
39+
from openml.tasks import TaskType, create_task
3940

4041

4142
class TestOpenMLDataset(TestBase):
@@ -1350,7 +1351,7 @@ def test_data_edit_errors(self):
13501351
"original_data_url, default_target_attribute, row_id_attribute, "
13511352
"ignore_attribute or paper_url to edit.",
13521353
edit_dataset,
1353-
data_id=564,
1354+
data_id=64, # blood-transfusion-service-center
13541355
)
13551356
# Check server exception when unknown dataset is provided
13561357
self.assertRaisesRegex(
@@ -1360,15 +1361,32 @@ def test_data_edit_errors(self):
13601361
data_id=999999,
13611362
description="xor operation dataset",
13621363
)
1364+
1365+
# Need to own a dataset to be able to edit meta-data
1366+
# Will be creating a forked version of an existing dataset to allow the unit test user
1367+
# to edit meta-data of a dataset
1368+
did = fork_dataset(1)
1369+
self._wait_for_dataset_being_processed(did)
1370+
TestBase._mark_entity_for_removal("data", did)
1371+
# Need to upload a task attached to this data to test edit failure
1372+
task = create_task(
1373+
task_type=TaskType.SUPERVISED_CLASSIFICATION,
1374+
dataset_id=did,
1375+
target_name="class",
1376+
estimation_procedure_id=1,
1377+
)
1378+
task = task.publish()
1379+
TestBase._mark_entity_for_removal("task", task.task_id)
13631380
# Check server exception when owner/admin edits critical fields of dataset with tasks
13641381
self.assertRaisesRegex(
13651382
OpenMLServerException,
13661383
"Critical features default_target_attribute, row_id_attribute and ignore_attribute "
13671384
"can only be edited for datasets without any tasks.",
13681385
edit_dataset,
1369-
data_id=223,
1386+
data_id=did,
13701387
default_target_attribute="y",
13711388
)
1389+
13721390
# Check server exception when a non-owner or non-admin tries to edit critical fields
13731391
self.assertRaisesRegex(
13741392
OpenMLServerException,

tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,7 +1465,7 @@ def test_openml_param_name_to_sklearn(self):
14651465
)
14661466
model = sklearn.pipeline.Pipeline(steps=[("scaler", scaler), ("boosting", boosting)])
14671467
flow = self.extension.model_to_flow(model)
1468-
task = openml.tasks.get_task(115)
1468+
task = openml.tasks.get_task(115) # diabetes; crossvalidation
14691469
run = openml.runs.run_flow_on_task(flow, task)
14701470
run = run.publish()
14711471
TestBase._mark_entity_for_removal("run", run.run_id)
@@ -1561,7 +1561,7 @@ def setUp(self):
15611561
# Test methods for performing runs with this extension module
15621562

15631563
def test_run_model_on_task(self):
1564-
task = openml.tasks.get_task(1)
1564+
task = openml.tasks.get_task(1) # anneal; crossvalidation
15651565
# using most_frequent imputer since dataset has mixed types and to keep things simple
15661566
pipe = sklearn.pipeline.Pipeline(
15671567
[
@@ -1626,7 +1626,7 @@ def test_seed_model_raises(self):
16261626
self.extension.seed_model(model=clf, seed=42)
16271627

16281628
def test_run_model_on_fold_classification_1_array(self):
1629-
task = openml.tasks.get_task(1)
1629+
task = openml.tasks.get_task(1) # anneal; crossvalidation
16301630

16311631
X, y = task.get_X_and_y()
16321632
train_indices, test_indices = task.get_train_test_split_indices(repeat=0, fold=0, sample=0)
@@ -1689,7 +1689,7 @@ def test_run_model_on_fold_classification_1_array(self):
16891689
def test_run_model_on_fold_classification_1_dataframe(self):
16901690
from sklearn.compose import ColumnTransformer
16911691

1692-
task = openml.tasks.get_task(1)
1692+
task = openml.tasks.get_task(1) # anneal; crossvalidation
16931693

16941694
# diff test_run_model_on_fold_classification_1_array()
16951695
X, y = task.get_X_and_y(dataset_format="dataframe")
@@ -1753,7 +1753,7 @@ def test_run_model_on_fold_classification_1_dataframe(self):
17531753
)
17541754

17551755
def test_run_model_on_fold_classification_2(self):
1756-
task = openml.tasks.get_task(7)
1756+
task = openml.tasks.get_task(7) # kr-vs-kp; crossvalidation
17571757

17581758
X, y = task.get_X_and_y()
17591759
train_indices, test_indices = task.get_train_test_split_indices(repeat=0, fold=0, sample=0)
@@ -1815,7 +1815,11 @@ def predict_proba(*args, **kwargs):
18151815
raise AttributeError("predict_proba is not available when " "probability=False")
18161816

18171817
# task 1 (test server) is important: it is a task with an unused class
1818-
tasks = [1, 3, 115]
1818+
tasks = [
1819+
1, # anneal; crossvalidation
1820+
3, # anneal; crossvalidation
1821+
115, # diabetes; crossvalidation
1822+
]
18191823
flow = unittest.mock.Mock()
18201824
flow.name = "dummy"
18211825

@@ -1969,7 +1973,7 @@ def test__extract_trace_data(self):
19691973
"max_iter": [10, 20, 40, 80],
19701974
}
19711975
num_iters = 10
1972-
task = openml.tasks.get_task(20)
1976+
task = openml.tasks.get_task(20) # balance-scale; crossvalidation
19731977
clf = sklearn.model_selection.RandomizedSearchCV(
19741978
sklearn.neural_network.MLPClassifier(), param_grid, num_iters,
19751979
)
@@ -2080,8 +2084,8 @@ def test_run_on_model_with_empty_steps(self):
20802084
from sklearn.compose import ColumnTransformer
20812085

20822086
# testing 'drop', 'passthrough', None as non-actionable sklearn estimators
2083-
dataset = openml.datasets.get_dataset(128)
2084-
task = openml.tasks.get_task(59)
2087+
dataset = openml.datasets.get_dataset(128) # iris
2088+
task = openml.tasks.get_task(59) # mfeat-pixel; crossvalidation
20852089

20862090
X, y, categorical_ind, feature_names = dataset.get_data(
20872091
target=dataset.default_target_attribute, dataset_format="array"
@@ -2198,7 +2202,7 @@ def test_failed_serialization_of_custom_class(self):
21982202
steps=[("preprocess", ct), ("estimator", sklearn.tree.DecisionTreeClassifier())]
21992203
) # build a sklearn classifier
22002204

2201-
task = openml.tasks.get_task(253) # data with mixed types from test server
2205+
task = openml.tasks.get_task(253) # profb; crossvalidation
22022206
try:
22032207
_ = openml.runs.run_model_on_task(clf, task)
22042208
except AttributeError as e:

tests/test_flows/test_flow_functions.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,11 +345,15 @@ def test_get_flow_id(self):
345345
with patch("openml.utils._list_all", list_all):
346346
clf = sklearn.tree.DecisionTreeClassifier()
347347
flow = openml.extensions.get_extension_by_model(clf).model_to_flow(clf).publish()
348+
TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name))
349+
TestBase.logger.info(
350+
"collected from {}: {}".format(__file__.split("/")[-1], flow.flow_id)
351+
)
348352

349353
self.assertEqual(openml.flows.get_flow_id(model=clf, exact_version=True), flow.flow_id)
350354
flow_ids = openml.flows.get_flow_id(model=clf, exact_version=False)
351355
self.assertIn(flow.flow_id, flow_ids)
352-
self.assertGreater(len(flow_ids), 2)
356+
self.assertGreater(len(flow_ids), 0)
353357

354358
# Check that the output of get_flow_id is identical if only the name is given, no matter
355359
# whether exact_version is set to True or False.
@@ -361,4 +365,3 @@ def test_get_flow_id(self):
361365
)
362366
self.assertEqual(flow_ids_exact_version_True, flow_ids_exact_version_False)
363367
self.assertIn(flow.flow_id, flow_ids_exact_version_True)
364-
self.assertGreater(len(flow_ids_exact_version_True), 2)

tests/test_runs/test_run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def test_to_from_filesystem_vanilla(self):
102102
("classifier", DecisionTreeClassifier(max_depth=1)),
103103
]
104104
)
105-
task = openml.tasks.get_task(119)
105+
task = openml.tasks.get_task(119) # diabetes; crossvalidation
106106
run = openml.runs.run_model_on_task(
107107
model=model,
108108
task=task,
@@ -142,7 +142,7 @@ def test_to_from_filesystem_search(self):
142142
},
143143
)
144144

145-
task = openml.tasks.get_task(119)
145+
task = openml.tasks.get_task(119) # diabetes; crossvalidation
146146
run = openml.runs.run_model_on_task(
147147
model=model, task=task, add_local_measures=False, avoid_duplicate_runs=False,
148148
)
@@ -163,7 +163,7 @@ def test_to_from_filesystem_no_model(self):
163163
model = Pipeline(
164164
[("imputer", SimpleImputer(strategy="mean")), ("classifier", DummyClassifier())]
165165
)
166-
task = openml.tasks.get_task(119)
166+
task = openml.tasks.get_task(119) # diabetes; crossvalidation
167167
run = openml.runs.run_model_on_task(model=model, task=task, add_local_measures=False)
168168

169169
cache_path = os.path.join(self.workdir, "runs", str(random.getrandbits(128)))
@@ -184,7 +184,7 @@ def test_publish_with_local_loaded_flow(self):
184184
model = Pipeline(
185185
[("imputer", SimpleImputer(strategy="mean")), ("classifier", DummyClassifier())]
186186
)
187-
task = openml.tasks.get_task(119)
187+
task = openml.tasks.get_task(119) # diabetes; crossvalidation
188188

189189
# Make sure the flow does not exist on the server yet.
190190
flow = extension.model_to_flow(model)

0 commit comments

Comments
 (0)