Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions openml/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,12 +456,11 @@ def _parse_data_from_arff(
col.append(
self._unpack_categories(X[column_name], categories_names[column_name])
)
elif attribute_dtype[column_name] in ('floating',
'integer'):
elif attribute_dtype[column_name] in ("floating", "integer"):
X_col = X[column_name]
if X_col.min() >= 0 and X_col.max() <= 255:
try:
X_col_uint = X_col.astype('uint8')
X_col_uint = X_col.astype("uint8")
if (X_col == X_col_uint).all():
col.append(X_col_uint)
continue
Expand Down
13 changes: 13 additions & 0 deletions openml/extensions/extension_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,19 @@ def obtain_parameter_values(
- ``oml:component`` : int: flow id to which the parameter belongs
"""

@abstractmethod
def check_if_model_fitted(self, model: Any) -> bool:
    """Returns True/False denoting if the model has already been fitted/trained.

    Concrete extensions decide what "fitted" means for their framework
    (e.g. the sklearn extension delegates to ``check_is_fitted``).

    Parameters
    ----------
    model : Any
        The model object to inspect. Its concrete type is determined by
        the extension implementing this interface.

    Returns
    -------
    bool
        True if the model has already been fitted/trained, False otherwise.
    """

################################################################################################
# Abstract methods for hyperparameter optimization

Expand Down
22 changes: 22 additions & 0 deletions openml/extensions/sklearn/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1537,6 +1537,28 @@ def _seed_current_object(current_value):
model.set_params(**random_states)
return model

def check_if_model_fitted(self, model: Any) -> bool:
    """Returns True/False denoting if the model has already been fitted/trained.

    Delegates to scikit-learn's ``check_is_fitted`` utility, which raises
    ``NotFittedError`` for estimators that have not been trained yet.

    Parameters
    ----------
    model : Any
        The scikit-learn estimator to inspect.

    Returns
    -------
    bool
        True if the estimator is fitted, False otherwise.
    """
    try:
        # Imported lazily, mirroring the rest of this extension module.
        from sklearn.exceptions import NotFittedError
        from sklearn.utils.validation import check_is_fitted

        # check_is_fitted raises NotFittedError for an untrained estimator.
        check_is_fitted(model)
    except NotFittedError:
        # The estimator has not been fitted yet.
        return False
    return True

def _run_model_on_fold(
self,
model: Any,
Expand Down
6 changes: 6 additions & 0 deletions openml/runs/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,12 @@ def run_flow_on_task(
run_environment = flow.extension.get_version_information()
tags = ["openml-python", run_environment[1]]

if flow.extension.check_if_model_fitted(flow.model):
warnings.warn(
"The model is already fitted!"
" This might cause inconsistency in comparison of results."
)

# execute the run
res = _run_task_get_arffcontent(
flow=flow,
Expand Down
8 changes: 4 additions & 4 deletions tests/test_datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ def test_get_data_no_str_data_for_nparrays(self):

def _check_expected_type(self, dtype, is_cat, col):
    """Assert that *dtype* matches what get_data should have produced for *col*.

    Categorical columns load as 'category'; NaN-free columns whose values
    round-trip through uint8 load as 'uint8'; everything else is 'float64'.
    """
    if is_cat:
        expected_type = "category"
    else:
        fits_uint8 = not col.isna().any() and (col.astype("uint8") == col).all()
        expected_type = "uint8" if fits_uint8 else "float64"

    self.assertEqual(dtype.name, expected_type)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_datasets/test_dataset_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,9 @@ def test_get_dataset_by_name(self):
def test_get_dataset_uint8_dtype(self):
    """Integer-valued columns that fit in [0, 255] should load as uint8."""
    anneal = openml.datasets.get_dataset(1)
    self.assertEqual(type(anneal), OpenMLDataset)
    self.assertEqual(anneal.name, "anneal")
    frame, _, _, _ = anneal.get_data()
    # 'carbon' is integer-valued and small, so it is downcast to uint8.
    self.assertEqual(frame["carbon"].dtype, "uint8")

def test_get_dataset(self):
# This is the only non-lazy load to ensure default behaviour works.
Expand Down