Skip to content

Commit c4e8ec3

Browse files
committed
added truncate method and removed {0:.4f} to truncate floats to 4 decimal places
1 parent f6cf450 commit c4e8ec3

4 files changed

Lines changed: 37 additions & 24 deletions

File tree

probability.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -549,16 +549,15 @@ def forward(HMM, fv, ev):
549549
scalar_vector_product(fv[1], HMM.transition_model[1]))
550550
sensor_dist = HMM.sensor_dist(ev)
551551

552-
return([float("{0:.4f}".format(i)) for i in normalize(element_wise_product(sensor_dist, prediction))])
552+
return(normalize(element_wise_product(sensor_dist, prediction)))
553553

554554

555555
def backward(HMM, b, ev):
556556
sensor_dist = HMM.sensor_dist(ev)
557557
prediction = element_wise_product(sensor_dist, b)
558558

559-
return([float("{0:.4f}".format(i)) for i in normalize(vector_add(
560-
scalar_vector_product(prediction[0], HMM.transition_model[0]),
561-
scalar_vector_product(prediction[1], HMM.transition_model[1])))])
559+
return(normalize(vector_add(scalar_vector_product(prediction[0], HMM.transition_model[0]),
560+
scalar_vector_product(prediction[1], HMM.transition_model[1]))))
562561

563562

564563
def forward_backward(HMM, ev, prior):
@@ -594,10 +593,6 @@ def forward_backward(HMM, ev, prior):
594593
bv.append(b)
595594

596595
sv = sv[::-1]
597-
# to have only 4 digits after decimal point
598-
for i in range(len(sv)):
599-
for j in range(len(sv[i])):
600-
sv[i][j] = float("{0:.4f}".format(sv[i][j]))
601596

602597
return(sv)
603598

tests/test_probability.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,12 @@ def test_forward_backward():
110110
umbrellaHMM = HiddenMarkovModel(umbrella_transition, umbrella_sensor)
111111

112112
umbrella_evidence = [T, T, F, T, T]
113-
assert forward_backward(umbrellaHMM, umbrella_evidence, umbrella_prior) == [[0.6469, 0.3531],
114-
[0.8673, 0.1327], [0.8204, 0.1796], [0.3075, 0.6925], [0.8205, 0.1795], [0.8673, 0.1327]]
113+
assert truncate(forward_backward(umbrellaHMM, umbrella_evidence, umbrella_prior)) == [[0.6469, 0.3531],
114+
[0.8673, 0.1327], [0.8204, 0.1796], [0.3075, 0.6925], [0.8204, 0.1796], [0.8673, 0.1327]]
115115

116116
umbrella_evidence = [T, F, T, F, T]
117-
assert forward_backward(umbrellaHMM, umbrella_evidence, umbrella_prior) == [[0.5871, 0.4129],
118-
[0.7177, 0.2823], [0.2325, 0.7675], [0.6072, 0.3928], [0.2324, 0.7676], [0.7177, 0.2823]]
117+
assert truncate(forward_backward(umbrellaHMM, umbrella_evidence, umbrella_prior)) == [[0.5871, 0.4129],
118+
[0.7177, 0.2823], [0.2324, 0.7676], [0.6072, 0.3928], [0.2324, 0.7676], [0.7177, 0.2823]]
119119

120120
def test_fixed_lag_smoothing():
121121
umbrella_evidence = [T, F, T, F, T]
@@ -126,16 +126,16 @@ def test_fixed_lag_smoothing():
126126
umbrellaHMM = HiddenMarkovModel(umbrella_transition, umbrella_sensor)
127127

128128
d = 2
129-
assert fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t) == [0.1111, 0.8889]
129+
assert truncate(fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t)) == [0.1111, 0.8889]
130130
d = 5
131-
assert fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t) is None
131+
assert truncate(fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t)) is None
132132

133133
umbrella_evidence = [T, T, F, T, T]
134134
# t = 4
135135
e_t = T
136136

137137
d = 1
138-
assert fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t) == [0.9939, 0.0061]
138+
assert truncate(fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t)) == [0.9939, 0.0061]
139139

140140

141141
if __name__ == '__main__':

tests/test_utils.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,14 +118,22 @@ def test_scalar_vector_product():
118118
assert scalar_vector_product(2, [1, 2, 3]) == [2, 4, 6]
119119

120120
def test_scalar_matrix_product():
121-
assert scalar_matrix_product(-5, [[1, 2], [3, 4], [0, 6]]) == [[-5, -10], [-15, -20], [0, -30]]
122-
assert scalar_matrix_product(0.2, [[1, 2], [2, 3]]) == [[0.2, 0.4], [0.4, 0.6]]
121+
assert truncate(scalar_matrix_product(-5, [[1, 2], [3, 4], [0, 6]])) == [[-5, -10], [-15, -20], [0, -30]]
122+
assert truncate(scalar_matrix_product(0.2, [[1, 2], [2, 3]])) == [[0.2, 0.4], [0.4, 0.6]]
123123

124124

125125
def test_inverse_matrix():
126-
assert inverse_matrix([[1, 0], [0, 1]]) == [[1, 0], [0, 1]]
127-
assert inverse_matrix([[2, 1], [4, 3]]) == [[1.5, -0.5], [-2.0, 1.0]]
128-
assert inverse_matrix([[4, 7], [2, 6]]) == [[0.6, -0.7], [-0.2, 0.4]]
126+
assert truncate(inverse_matrix([[1, 0], [0, 1]])) == [[1, 0], [0, 1]]
127+
assert truncate(inverse_matrix([[2, 1], [4, 3]])) == [[1.5, -0.5], [-2.0, 1.0]]
128+
assert truncate(inverse_matrix([[4, 7], [2, 6]])) == [[0.6, -0.7], [-0.2, 0.4]]
129+
130+
def test_truncate():
131+
assert truncate(5.3330000300330) == 5.3330
132+
assert truncate(10.234566) == 10.2346
133+
assert truncate([1.234566, 0.555555, 6.010101]) == [1.2346, 0.5556, 6.0101]
134+
assert truncate([[1.234566, 0.555555, 6.010101],
135+
[10.505050, 12.121212, 6.030303]]) == [[1.2346, 0.5556, 6.0101],
136+
[10.5051, 12.1212, 6.0303]]
129137

130138
def test_num_or_str():
131139
assert num_or_str('42') == 42

utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def _mat_mult(X_M, Y_M):
197197
for Y in Y_M:
198198
result = _mat_mult(result, Y)
199199

200-
return([[float("{0:.4f}".format(i)) for i in row] for row in result])
200+
return(result)
201201

202202
def vector_to_diagonal(v):
203203
"""Converts a vector to a diagonal matrix with vector elements
@@ -218,7 +218,7 @@ def scalar_vector_product(X, Y):
218218
return [X*y for y in Y]
219219

220220
def scalar_matrix_product(X, Y):
221-
return([[float("{0:.4f}".format(i)) for i in scalar_vector_product(X, y)] for y in Y])
221+
return([scalar_vector_product(X, y) for y in Y])
222222

223223
def inverse_matrix(X):
224224
"""Inverse a given square matrix of size 2x2"""
@@ -228,7 +228,7 @@ def inverse_matrix(X):
228228
assert det != 0
229229
inv_mat = scalar_matrix_product(1.0/det, [[X[1][1], -X[0][1]], [-X[1][0], X[0][0]]])
230230

231-
return([[float("{0:.4f}".format(i)) for i in row] for row in inv_mat])
231+
return(inv_mat)
232232

233233

234234
def probability(p):
@@ -253,6 +253,16 @@ def weighted_sampler(seq, weights):
253253

254254
return lambda: seq[bisect.bisect(totals, random.uniform(0, totals[-1]))]
255255

256+
def truncate(x, n = 4):
257+
"""Truncates floats, vectors, matrices to n decimal values"""
258+
if isinstance(x, float):
259+
return(float("{0:.{1}f}".format(x, n)))
260+
elif isinstance(x, list) and not isinstance(x[0], list):
261+
return([float("{0:.{1}f}".format(i, n)) for i in x])
262+
elif isinstance(x, list) and isinstance(x[0], list):
263+
return([[float("{0:.{1}f}".format(i, n)) for i in row] for row in x])
264+
else:
265+
return x
256266

257267
def num_or_str(x):
258268
"""The argument is a string; convert to a number if
@@ -270,7 +280,7 @@ def num_or_str(x):
270280
def normalize(numbers):
271281
"""Multiply each number by a constant such that the sum is 1.0"""
272282
total = float(sum(numbers))
273-
return([float("{0:.4f}".format(n / total)) for n in numbers])
283+
return([(n / total) for n in numbers])
274284

275285

276286
def clip(x, lowest, highest):

0 commit comments

Comments
 (0)