Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
code format
  • Loading branch information
sahithyaravi committed Jul 22, 2020
commit e111dd251135fbd16ac5f5cfd7d182ae63345e91
29 changes: 18 additions & 11 deletions openml/datasets/functions.py
Original file line number Diff line number Diff line change
@@ -877,18 +877,24 @@ def edit_dataset(
-------
data_id of the existing edited version or the new version created and published"""
if not isinstance(data_id, int):
raise TypeError(
"`data_id` must be of type `int`, not {}.".format(type(data_id))
)
raise TypeError("`data_id` must be of type `int`, not {}.".format(type(data_id)))

# case 1, changing these fields creates a new version of the dataset with changed field
if any(field is not None for field in [data, attributes, default_target_attribute,
row_id_attribute, ignore_attribute]):
if any(
field is not None
for field in [
data,
attributes,
default_target_attribute,
row_id_attribute,
ignore_attribute,
]
):
logger.warning("Creating a new version of dataset, cannot edit existing version")
dataset = get_dataset(data_id)

decoded_arff = dataset._get_arff(format='arff')
data_old = decoded_arff['data']
decoded_arff = dataset._get_arff(format="arff")
data_old = decoded_arff["data"]
data_new = data if data is not None else data_old
dataset_new = create_dataset(
name=dataset.name,
@@ -898,7 +904,7 @@ def edit_dataset(
collection_date=collection_date or dataset.collection_date,
language=language or dataset.language,
licence=dataset.licence,
attributes=attributes or decoded_arff['attributes'],
attributes=attributes or decoded_arff["attributes"],
data=data_new,
default_target_attribute=default_target_attribute or dataset.default_target_attribute,
ignore_attribute=ignore_attribute or dataset.ignore_attribute,
@@ -907,7 +913,7 @@ def edit_dataset(
original_data_url=original_data_url or dataset.original_data_url,
paper_url=paper_url or dataset.paper_url,
update_comment=dataset.update_comment,
version_label=dataset.version_label
version_label=dataset.version_label,
)
dataset_new.publish()
return dataset_new.dataset_id
@@ -933,7 +939,9 @@ def edit_dataset(
del xml["oml:data_edit_parameters"][k]

file_elements = {"edit_parameters": ("description.xml", xmltodict.unparse(xml))}
result_xml = openml._api_calls._perform_api_call("data/edit", "post", data=form_data, file_elements=file_elements)
result_xml = openml._api_calls._perform_api_call(
"data/edit", "post", data=form_data, file_elements=file_elements
)
result = xmltodict.parse(result_xml)
data_id = result["oml:data_edit"]["oml:id"]
return int(data_id)
@@ -1195,4 +1203,3 @@ def _get_online_dataset_format(dataset_id):
dataset_xml = openml._api_calls._perform_api_call("data/%d" % dataset_id, "get")
# build a dict from the xml and get the format from the dataset description
return xmltodict.parse(dataset_xml)["oml:data_set_description"]["oml:format"].lower()

29 changes: 21 additions & 8 deletions tests/test_datasets/test_dataset_functions.py
Original file line number Diff line number Diff line change
@@ -1341,10 +1341,16 @@ def test_data_edit(self):

# case 1, existing version edit
did = 564
result = edit_dataset(did, description="xor dataset represents XOR operation",
contributor="", collection_date="2019-10-29 17:06:18",
original_data_url="https://www.kaggle.com/ancientaxe/and-or-xor", paper_url="",
citation="kaggle", language="English")
result = edit_dataset(
did,
Comment thread
PGijsbers marked this conversation as resolved.
description="xor dataset represents XOR operation",
contributor="",
collection_date="2019-10-29 17:06:18",
original_data_url="https://www.kaggle.com/ancientaxe/and-or-xor",
paper_url="",
citation="kaggle",
language="English",
)
self.assertEqual(result, did)

# case 2, new version edit
@@ -1355,8 +1361,15 @@ def test_data_edit(self):
("y", "REAL"),
]
desc = "xor dataset represents XOR operation"
result = edit_dataset(564, description=desc,
contributor="", collection_date="2019-10-29 17:06:18", attributes=column_names,
original_data_url="https://www.kaggle.com/ancientaxe/and-or-xor", paper_url="",
citation="kaggle", language="English")
result = edit_dataset(
564,
description=desc,
contributor="",
collection_date="2019-10-29 17:06:18",
attributes=column_names,
original_data_url="https://www.kaggle.com/ancientaxe/and-or-xor",
paper_url="",
citation="kaggle",
language="English",
)
self.assertNotEqual(did, result)