Skip to content

Commit 12a4e92

Browse files
Particle filtering
1 parent cf1da23 commit 12a4e92

1 file changed

Lines changed: 66 additions & 3 deletions

File tree

probability.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -605,9 +605,72 @@ def fixed_lag_smoothing(e_t, hmm, d):
605605
unimplemented()
606606

607607

608-
def particle_filtering(e, N, dbn):
609-
"""[Fig. 15.17]"""
610-
unimplemented()
608+
def particle_filtering(e, N, HMM):
609+
"""
610+
Particle filtering considering two states variables
611+
N = 10
612+
umbrella_evidence = T
613+
umbrella_prior = [0.5, 0.5]
614+
umbrella_transition = [[0.7, 0.3], [0.3, 0.7]]
615+
umbrella_sensor = [[0.9, 0.2], [0.1, 0.8]]
616+
umbrellaHMM = HiddenMarkovModel(umbrella_transition, umbrella_sensor)
617+
618+
>>> particle_filtering(umbrella_evidence, N, umbrellaHMM)
619+
['A', 'A', 'A', 'B', 'A', 'A', 'B', 'A', 'A', 'A', 'B']
620+
621+
NOTE: Output is an probabilistic answer, therfore can vary
622+
"""
623+
s = []
624+
dist = [0.5, 0.5]
625+
# State Initialization
626+
s = ['A' if probability(dist[0]) else 'B' for i in range(N)]
627+
# Weight Initialization
628+
w = [0 for i in range(N)]
629+
# STEP 1
630+
# Propagate one step using transition model given prior state
631+
dist = vector_add(scalar_vector_product(dist[0], HMM.transition_model[0]),
632+
scalar_vector_product(dist[1], HMM.transition_model[1]))
633+
# Assign state according to probability
634+
s = ['A' if probability(dist[0]) else 'B' for i in range(N)]
635+
w_tot = 0
636+
# Calculate importance weight given evidence e
637+
for i in range(N):
638+
if s[i] == 'A':
639+
# P(U|A)*P(A)
640+
w_i = HMM.sensor_dist(e)[0]*dist[0]
641+
if s[i] == 'B':
642+
# P(U|B)*P(B)
643+
w_i = HMM.sensor_dist(e)[1]*dist[1]
644+
w[i] = w_i
645+
w_tot += w_i
646+
647+
# Normalize all the weights
648+
for i in range(N):
649+
w[i] = w[i]/w_tot
650+
651+
# Limit weights to 4 digits
652+
for i in range(N):
653+
w[i] = float("{0:.4f}".format(w[i]))
654+
655+
# STEP 2
656+
s = weighted_sample_with_replacement(N, s, w)
657+
return s
658+
659+
660+
def weighted_sample_with_replacement(N, s, w):
661+
"""
662+
Performs Weighted sampling over the paricles given weights of each particle.
663+
We keep on picking random states unitll we fill N number states in new distribution
664+
"""
665+
s_wtd = []
666+
cnt = 0
667+
while (cnt <= N):
668+
# Generate a random number from 0 to N-1
669+
i = random.randint(0, N-1)
670+
if (probability(w[i])):
671+
s_wtd.append(s[i])
672+
cnt += 1
673+
return s_wtd
611674

612675
# _________________________________________________________________________
613676
__doc__ += """

0 commit comments

Comments
 (0)