diff --git a/csp.py b/csp.py index d5f96f80b..ee59d4a6b 100644 --- a/csp.py +++ b/csp.py @@ -160,7 +160,7 @@ def conflicted_vars(self, current): def AC3(csp, queue=None, removals=None): """[Figure 6.3]""" if queue is None: - queue = [(Xi, Xk) for Xi in csp.variables for Xk in csp.neighbors[Xi]] + queue = {(Xi, Xk) for Xi in csp.variables for Xk in csp.neighbors[Xi]} csp.support_pruning() while queue: (Xi, Xj) = queue.pop() @@ -169,7 +169,7 @@ def AC3(csp, queue=None, removals=None): return False for Xk in csp.neighbors[Xi]: if Xk != Xj: - queue.append((Xk, Xi)) + queue.add((Xk, Xi)) return True @@ -243,7 +243,7 @@ def forward_checking(csp, var, value, assignment, removals): def mac(csp, var, value, assignment, removals): """Maintain arc consistency.""" - return AC3(csp, [(X, var) for X in csp.neighbors[var]], removals) + return AC3(csp, {(X, var) for X in csp.neighbors[var]}, removals) # The search, proper @@ -374,7 +374,7 @@ def make_arc_consistent(Xj, Xk, csp): # Found a consistent assignment for val1, keep it keep = True break - + if not keep: # Remove val1 csp.prune(Xj, val1, None) diff --git a/tests/test_csp.py b/tests/test_csp.py index 2bc907b6c..77b35c796 100644 --- a/tests/test_csp.py +++ b/tests/test_csp.py @@ -3,7 +3,6 @@ from csp import * import random - random.seed("aima-python") @@ -174,7 +173,7 @@ def test_csp_conflicted_vars(): def test_revise(): neighbors = parse_neighbors('A: B; B: ') domains = {'A': [0], 'B': [4]} - constraints = lambda X, x, Y, y: x % 2 == 0 and (x+y) == 4 + constraints = lambda X, x, Y, y: x % 2 == 0 and (x + y) == 4 csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) csp.support_pruning() @@ -196,24 +195,24 @@ def test_revise(): def test_AC3(): neighbors = parse_neighbors('A: B; B: ') domains = {'A': [0, 1, 2, 3, 4], 'B': [0, 1, 2, 3, 4]} - constraints = lambda X, x, Y, y: x % 2 == 0 and (x+y) == 4 and y % 2 != 0 + constraints = lambda X, x, Y, y: x % 2 == 0 and (x + y) == 4 and y % 2 != 0 removals = [] csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) assert AC3(csp, removals=removals) is False - constraints = lambda X, x, Y, y: (x % 2) == 0 and (x+y) == 4 + constraints = lambda X, x, Y, y: (x % 2) == 0 and (x + y) == 4 removals = [] csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) assert AC3(csp, removals=removals) is True assert (removals == [('A', 1), ('A', 3), ('B', 1), ('B', 3)] or removals == [('B', 1), ('B', 3), ('A', 1), ('A', 3)]) - - domains = {'A': [ 2, 4], 'B': [ 3, 5]} - constraints = lambda X, x, Y, y: int(x) > int (y) - removals=[] + + domains = {'A': [2, 4], 'B': [3, 5]} + constraints = lambda X, x, Y, y: int(x) > int(y) + removals = [] csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) assert AC3(csp, removals=removals) @@ -247,7 +246,7 @@ def test_num_legal_values(): def test_mrv(): neighbors = parse_neighbors('A: B; B: C; C: ') domains = {'A': [0, 1, 2, 3, 4], 'B': [4], 'C': [0, 1, 2, 3, 4]} - constraints = lambda X, x, Y, y: x % 2 == 0 and (x+y) == 4 + constraints = lambda X, x, Y, y: x % 2 == 0 and (x + y) == 4 csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) assignment = {'A': 0} @@ -269,13 +268,13 @@ def test_mrv(): def test_unordered_domain_values(): map_coloring_test = MapColoringCSP(list('123'), 'A: B C; B: C; C: ') assignment = None - assert unordered_domain_values('A', assignment, map_coloring_test) == ['1', '2', '3'] + assert unordered_domain_values('A', assignment, map_coloring_test) == ['1', '2', '3'] def test_lcv(): neighbors = parse_neighbors('A: B; B: C; C: ') domains = {'A': [0, 1, 2, 3, 4], 'B': [0, 1, 2, 3, 4, 5], 'C': [0, 1, 2, 3, 4]} - constraints = lambda X, x, Y, y: x % 2 == 0 and (x+y) == 4 + constraints = lambda X, x, Y, y: x % 2 == 0 and (x + y) == 4 csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) assignment = {'A': 0} @@ -347,7 +346,7 @@ def test_min_conflicts(): assert min_conflicts(france) tests = [(usa, None)] * 3 - assert failure_test(min_conflicts, tests) >= 1/3 + assert failure_test(min_conflicts, tests) >= 1 / 3 australia_impossible = MapColoringCSP(list('RG'), 'SA: WA NT Q NSW V; NT: WA Q; NSW: Q V; T: ') assert min_conflicts(australia_impossible, 1000) is None @@ -419,9 +418,9 @@ def test_parse_neighbours(): def test_topological_sort(): root = 'NT' - Sort, Parents = topological_sort(australia,root) + Sort, Parents = topological_sort(australia, root) - assert Sort == ['NT','SA','Q','NSW','V','WA'] + assert Sort == ['NT', 'SA', 'Q', 'NSW', 'V', 'WA'] assert Parents['NT'] == None assert Parents['SA'] == 'NT' assert Parents['Q'] == 'SA' @@ -432,10 +431,11 @@ def test_topological_sort(): def test_tree_csp_solver(): australia_small = MapColoringCSP(list('RB'), - 'NT: WA Q; NSW: Q V') + 'NT: WA Q; NSW: Q V') tcs = tree_csp_solver(australia_small) assert (tcs['NT'] == 'R' and tcs['WA'] == 'B' and tcs['Q'] == 'B' and tcs['NSW'] == 'R' and tcs['V'] == 'B') or \ (tcs['NT'] == 'B' and tcs['WA'] == 'R' and tcs['Q'] == 'R' and tcs['NSW'] == 'B' and tcs['V'] == 'R') + if __name__ == "__main__": pytest.main()