diff --git a/tests/test_learning.py b/tests/test_learning.py index 4f618f7c1..e341e699f 100644 --- a/tests/test_learning.py +++ b/tests/test_learning.py @@ -1,7 +1,114 @@ +import pytest + from learning import parse_csv, weighted_mode, weighted_replicate, DataSet, \ PluralityLearner, NaiveBayesLearner, NearestNeighborLearner, \ - NeuralNetLearner, PerceptronLearner, DecisionTreeLearner + NeuralNetLearner, PerceptronLearner, DecisionTreeLearner, \ + mean_error, rms_error, ms_error, manhattan_distance, \ + mean_boolean_error, hamming_distance from utils import DataFile +from math import sqrt + + +def test_ms_error(): + predictions = [1, 1, 1] + targets = [1, 1, 1] + + assert ms_error(predictions, targets) == 0 + + predictions = [1, 1, 1] + targets = [2, 2, 2] + + assert ms_error(predictions, targets) == 1 + + predictions = [1, 1, 1, 1] + targets = [1, 3, 3, 1] + + assert ms_error(predictions, targets) == 2 + + +def test_mean_error(): + predictions = [1, 1, 1] + targets = [1, 1, 1] + + assert mean_error(predictions, targets) == 0 + + predictions = [1, 1, 1] + targets = [2, 2, 2] + + assert mean_error(predictions, targets) == 1 + + predictions = [1, 1, 2, 1] + targets = [1, 3, 1, 1] + + assert mean_error(predictions, targets) == 0.75 + + +def test_rms_error(): + predictions = [1, 1, 1] + targets = [1, 1, 1] + + assert rms_error(predictions, targets) == 0 + + predictions = [1, 1, 1] + targets = [2, 2, 2] + + assert rms_error(predictions, targets) == 1 + + predictions = [1, 1, 2, 1] + targets = [1, 2, 1, 1] + + assert rms_error(predictions, targets) == pytest.approx(sqrt(0.5)) + + +def test_manhattan_distance(): + predictions = [1, 1, 1] + targets = [1, 1, 1] + + assert manhattan_distance(predictions, targets) == 0 + + predictions = [1, 1, 1] + targets = [2, 2, 2] + + assert manhattan_distance(predictions, targets) == 3 + + predictions = [1, 1, 2, 3] + targets = [1, 3, 1, 1] + + assert manhattan_distance(predictions, targets) == 5 + + +def test_mean_boolean_error(): + predictions = [1, 1, 1] + targets = [1, 1, 1] + + assert mean_boolean_error(predictions, targets) == 0 + + predictions = [1, 1, 1] + targets = [2, 2, 2] + + assert mean_boolean_error(predictions, targets) == 1 + + predictions = [1, 1, 2, 3] + targets = [1, 3, 1, 2] + + assert mean_boolean_error(predictions, targets) == 0.75 + + +def test_hamming_distance(): + predictions = [1, 1, 1] + targets = [1, 1, 1] + + assert hamming_distance(predictions, targets) == 0 + + predictions = [1, 1, 1] + targets = [2, 2, 2] + + assert hamming_distance(predictions, targets) == 3 + + predictions = [1, 1, 2, 3] + targets = [1, 3, 1, 1] + + assert hamming_distance(predictions, targets) == 3 def test_exclude():