|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import pathlib |
| 16 | +from collections import OrderedDict |
16 | 17 | from concurrent import futures |
17 | 18 | from datetime import datetime |
18 | 19 |
|
@@ -62,7 +63,7 @@ def test_add_remove_features_success(self): |
62 | 63 | assert len(fs.features) == 1 and fs.features[0].name == "my-feature-2" |
63 | 64 |
|
64 | 65 | def test_remove_feature_failure(self): |
65 | | - with pytest.raises(ValueError): |
| 66 | + with pytest.raises(KeyError): |
66 | 67 | fs = FeatureSet("my-feature-set") |
67 | 68 | fs.drop(name="my-feature-1") |
68 | 69 |
|
@@ -287,6 +288,98 @@ def make_tfx_schema_domain_info_inline(schema): |
287 | 288 | feature.int_domain.MergeFrom(domain_ref_to_int_domain[domain_ref]) |
288 | 289 |
|
289 | 290 |
|
| 291 | +def test_feature_set_class_contains_labels(): |
| 292 | + fs = FeatureSet("my-feature-set", labels={"key1": "val1", "key2": "val2"}) |
| 293 | + assert "key1" in fs.labels.keys() and fs.labels["key1"] == "val1" |
| 294 | + assert "key2" in fs.labels.keys() and fs.labels["key2"] == "val2" |
| 295 | + |
| 296 | + |
| 297 | +def test_feature_class_contains_labels(): |
| 298 | + fs = FeatureSet("my-feature-set", labels={"key1": "val1", "key2": "val2"}) |
| 299 | + fs.add( |
| 300 | + Feature( |
| 301 | + name="my-feature-1", |
| 302 | + dtype=ValueType.INT64, |
| 303 | + labels={"feature_key1": "feature_val1"}, |
| 304 | + ) |
| 305 | + ) |
| 306 | + assert "feature_key1" in fs.features[0].labels.keys() |
| 307 | + assert fs.features[0].labels["feature_key1"] == "feature_val1" |
| 308 | + |
| 309 | + |
| 310 | +def test_feature_set_without_labels_empty_dict(): |
| 311 | + fs = FeatureSet("my-feature-set") |
| 312 | + assert fs.labels == OrderedDict() |
| 313 | + assert len(fs.labels) == 0 |
| 314 | + |
| 315 | + |
| 316 | +def test_feature_without_labels_empty_dict(): |
| 317 | + f = Feature("my feature", dtype=ValueType.INT64) |
| 318 | + assert f.labels == OrderedDict() |
| 319 | + assert len(f.labels) == 0 |
| 320 | + |
| 321 | + |
| 322 | +def test_set_label_feature_set(): |
| 323 | + fs = FeatureSet("my-feature-set") |
| 324 | + fs.set_label("k1", "v1") |
| 325 | + assert fs.labels["k1"] == "v1" |
| 326 | + |
| 327 | + |
| 328 | +def test_set_labels_overwrites_existing(): |
| 329 | + fs = FeatureSet("my-feature-set") |
| 330 | + fs.set_label("k1", "v1") |
| 331 | + fs.set_label("k1", "v2") |
| 332 | + assert fs.labels["k1"] == "v2" |
| 333 | + |
| 334 | + |
| 335 | +def test_remove_labels_empty_failure(): |
| 336 | + fs = FeatureSet("my-feature-set") |
| 337 | + with pytest.raises(KeyError): |
| 338 | + fs.remove_label("key1") |
| 339 | + |
| 340 | + |
| 341 | +def test_remove_labels_invalid_key_failure(): |
| 342 | + fs = FeatureSet("my-feature-set") |
| 343 | + fs.set_label("k1", "v1") |
| 344 | + with pytest.raises(KeyError): |
| 345 | + fs.remove_label("key1") |
| 346 | + |
| 347 | + |
| 348 | +def test_unequal_feature_based_on_labels(): |
| 349 | + f1 = Feature(name="feature-1", dtype=ValueType.INT64, labels={"k1": "v1"}) |
| 350 | + f2 = Feature(name="feature-1", dtype=ValueType.INT64, labels={"k1": "v1"}) |
| 351 | + assert f1 == f2 |
| 352 | + f3 = Feature(name="feature-1", dtype=ValueType.INT64) |
| 353 | + assert f1 != f3 |
| 354 | + f4 = Feature(name="feature-1", dtype=ValueType.INT64, labels={"k1": "notv1"}) |
| 355 | + assert f1 != f4 |
| 356 | + |
| 357 | + |
| 358 | +def test_unequal_feature_set_based_on_labels(): |
| 359 | + fs1 = FeatureSet("my-feature-set") |
| 360 | + fs2 = FeatureSet("my-feature-set") |
| 361 | + assert fs1 == fs2 |
| 362 | + fs1.set_label("k1", "v1") |
| 363 | + fs2.set_label("k1", "v1") |
| 364 | + assert fs1 == fs2 |
| 365 | + fs2.set_label("k1", "unequal") |
| 366 | + assert not fs1 == fs2 |
| 367 | + |
| 368 | + |
| 369 | +def test_unequal_feature_set_other_has_no_labels(): |
| 370 | + fs1 = FeatureSet("my-feature-set") |
| 371 | + fs2 = FeatureSet("my-feature-set") |
| 372 | + assert fs1 == fs2 |
| 373 | + fs1.set_label("k1", "v1") |
| 374 | + assert not fs1 == fs2 |
| 375 | + |
| 376 | + |
| 377 | +def test_unequal_feature_other_has_no_labels(): |
| 378 | + f1 = Feature(name="feature-1", dtype=ValueType.INT64, labels={"k1": "v1"}) |
| 379 | + f2 = Feature(name="feature-1", dtype=ValueType.INT64) |
| 380 | + assert f1 != f2 |
| 381 | + |
| 382 | + |
290 | 383 | class TestFeatureSetRef: |
291 | 384 | def test_from_feature_set(self): |
292 | 385 | feature_set = FeatureSet("test", "test") |
|
0 commit comments