Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 108 additions & 1 deletion tests/test_learning.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,114 @@
import pytest

from learning import parse_csv, weighted_mode, weighted_replicate, DataSet, \
PluralityLearner, NaiveBayesLearner, NearestNeighborLearner, \
NeuralNetLearner, PerceptronLearner, DecisionTreeLearner
NeuralNetLearner, PerceptronLearner, DecisionTreeLearner, \
mean_error, rms_error, ms_error, manhattan_distance, \
mean_boolean_error, hamming_distance
from utils import DataFile
from math import sqrt


def test_ms_error():
predictions = [1, 1, 1]
targets = [1, 1, 1]

assert ms_error(predictions, targets) == 0

predictions = [1, 1, 1]
targets = [2, 2, 2]

assert ms_error(predictions, targets) == 1

predictions = [1, 1, 1, 1]
targets = [1, 3, 3, 1]

assert ms_error(predictions, targets) == 2


def test_mean_error():
predictions = [1, 1, 1]
targets = [1, 1, 1]

assert mean_error(predictions, targets) == 0

predictions = [1, 1, 1]
targets = [2, 2, 2]

assert mean_error(predictions, targets) == 1

predictions = [1, 1, 2, 1]
targets = [1, 3, 1, 1]

assert mean_error(predictions, targets) == 0.75


def test_rms_error():
predictions = [1, 1, 1]
targets = [1, 1, 1]

assert rms_error(predictions, targets) == 0

predictions = [1, 1, 1]
targets = [2, 2, 2]

assert rms_error(predictions, targets) == 1

predictions = [1, 1, 2, 1]
targets = [1, 2, 1, 1]

assert rms_error(predictions, targets) == pytest.approx(sqrt(0.5))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytest will have to be removed.



def test_manhattan_distance():
predictions = [1, 1, 1]
targets = [1, 1, 1]

assert manhattan_distance(predictions, targets) == 0

predictions = [1, 1, 1]
targets = [2, 2, 2]

assert manhattan_distance(predictions, targets) == 3

predictions = [1, 1, 2, 3]
targets = [1, 3, 1, 1]

assert manhattan_distance(predictions, targets) == 5


def test_mean_boolean_error():
predictions = [1, 1, 1]
targets = [1, 1, 1]

assert mean_boolean_error(predictions, targets) == 0

predictions = [1, 1, 1]
targets = [2, 2, 2]

assert mean_boolean_error(predictions, targets) == 1

predictions = [1, 1, 2, 3]
targets = [1, 3, 1, 2]

assert mean_boolean_error(predictions, targets) == 0.75


def test_hamming_distance():
predictions = [1, 1, 1]
targets = [1, 1, 1]

assert hamming_distance(predictions, targets) == 0

predictions = [1, 1, 1]
targets = [2, 2, 2]

assert hamming_distance(predictions, targets) == 3

predictions = [1, 1, 2, 3]
targets = [1, 3, 1, 1]

assert hamming_distance(predictions, targets) == 3


def test_exclude():
Expand Down