Skip to content

Commit 62014cd

Browse files
authored
Convert sparse labels to pandas series (#1059)
* Convert sparse labels to pandas series
* Handling sparse labels as Series
* Handling sparse targets when dataset as arrays
* Revamping sparse dataset tests
* Removing redundant unit test
* Cleaning target column formatting
* Minor comment edit
1 parent 10c9dc5 commit 62014cd

2 files changed

Lines changed: 30 additions & 8 deletions

File tree

openml/datasets/dataset.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ def _encode_if_category(column):
628628
)
629629
elif array_format == "dataframe":
630630
if scipy.sparse.issparse(data):
631-
return pd.DataFrame.sparse.from_spmatrix(data, columns=attribute_names)
631+
data = pd.DataFrame.sparse.from_spmatrix(data, columns=attribute_names)
632632
else:
633633
data_type = "sparse-data" if scipy.sparse.issparse(data) else "non-sparse data"
634634
logger.warning(
@@ -732,6 +732,7 @@ def get_data(
732732
else:
733733
target = [target]
734734
targets = np.array([True if column in target else False for column in attribute_names])
735+
target_names = np.array([column for column in attribute_names if column in target])
735736
if np.sum(targets) > 1:
736737
raise NotImplementedError(
737738
"Number of requested targets %d is not implemented." % np.sum(targets)
@@ -752,11 +753,17 @@ def get_data(
752753
attribute_names = [att for att, k in zip(attribute_names, targets) if not k]
753754

754755
x = self._convert_array_format(x, dataset_format, attribute_names)
755-
if scipy.sparse.issparse(y):
756-
y = np.asarray(y.todense()).astype(target_dtype).flatten()
757-
y = y.squeeze()
758-
y = self._convert_array_format(y, dataset_format, attribute_names)
756+
if dataset_format == "array" and scipy.sparse.issparse(y):
757+
# scikit-learn requires dense representation of targets
758+
y = np.asarray(y.todense()).astype(target_dtype)
759+
# dense representation of single column sparse arrays become a 2-d array
760+
# need to flatten it to a 1-d array for _convert_array_format()
761+
y = y.squeeze()
762+
y = self._convert_array_format(y, dataset_format, target_names)
759763
y = y.astype(target_dtype) if dataset_format == "array" else y
764+
if len(y.shape) > 1 and y.shape[1] == 1:
765+
# single column targets should be 1-d for both `array` and `dataframe` formats
766+
y = y.squeeze()
760767
data, targets = x, y
761768

762769
return data, targets, categorical, attribute_names

tests/test_datasets/test_dataset.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def setUp(self):
287287

288288
self.sparse_dataset = openml.datasets.get_dataset(4136, download_data=False)
289289

290-
def test_get_sparse_dataset_with_target(self):
290+
def test_get_sparse_dataset_array_with_target(self):
291291
X, y, _, attribute_names = self.sparse_dataset.get_data(
292292
dataset_format="array", target="class"
293293
)
@@ -303,7 +303,22 @@ def test_get_sparse_dataset_with_target(self):
303303
self.assertEqual(len(attribute_names), 20000)
304304
self.assertNotIn("class", attribute_names)
305305

306-
def test_get_sparse_dataset(self):
306+
def test_get_sparse_dataset_dataframe_with_target(self):
307+
X, y, _, attribute_names = self.sparse_dataset.get_data(
308+
dataset_format="dataframe", target="class"
309+
)
310+
self.assertIsInstance(X, pd.DataFrame)
311+
self.assertIsInstance(X.dtypes[0], pd.SparseDtype)
312+
self.assertEqual(X.shape, (600, 20000))
313+
314+
self.assertIsInstance(y, pd.Series)
315+
self.assertIsInstance(y.dtypes, pd.SparseDtype)
316+
self.assertEqual(y.shape, (600,))
317+
318+
self.assertEqual(len(attribute_names), 20000)
319+
self.assertNotIn("class", attribute_names)
320+
321+
def test_get_sparse_dataset_array(self):
307322
rval, _, categorical, attribute_names = self.sparse_dataset.get_data(dataset_format="array")
308323
self.assertTrue(sparse.issparse(rval))
309324
self.assertEqual(rval.dtype, np.float32)
@@ -315,7 +330,7 @@ def test_get_sparse_dataset(self):
315330
self.assertEqual(len(attribute_names), 20001)
316331
self.assertTrue(all([isinstance(att, str) for att in attribute_names]))
317332

318-
def test_get_sparse_dataframe(self):
333+
def test_get_sparse_dataset_dataframe(self):
319334
rval, *_ = self.sparse_dataset.get_data()
320335
self.assertIsInstance(rval, pd.DataFrame)
321336
np.testing.assert_array_equal(

0 commit comments

Comments
 (0)