diff --git a/csp.ipynb b/csp.ipynb index 1de9e1312..be3882387 100644 --- a/csp.ipynb +++ b/csp.ipynb @@ -52,7 +52,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "psource(CSP)" @@ -105,7 +107,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "psource(different_values_constraint)" @@ -139,7 +143,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "psource(MapColoringCSP)" @@ -178,9 +184,114 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "def queen_constraint(A, a, B, b):\n",
+ " """Constraint is satisfied (true) if A, B are really the same variable,\n",
+ " or if they are not in the same row, down diagonal, or up diagonal."""\n",
+ " return A == B or (a != b and A + a != B + b and A - a != B - b)\n",
+ "class NQueensCSP(CSP):\n",
+ " """Make a CSP for the nQueens problem for search with min_conflicts.\n",
+ " Suitable for large n, it uses only data structures of size O(n).\n",
+ " Think of placing queens one per column, from left to right.\n",
+ " That means position (x, y) represents (var, val) in the CSP.\n",
+ " The main structures are three arrays to count queens that could conflict:\n",
+ " rows[i] Number of queens in the ith row (i.e val == i)\n",
+ " downs[i] Number of queens in the \\ diagonal\n",
+ " such that their (x, y) coordinates sum to i\n",
+ " ups[i] Number of queens in the / diagonal\n",
+ " such that their (x, y) coordinates have x-y+n-1 = i\n",
+ " We increment/decrement these counts each time a queen is placed/moved from\n",
+ " a row/diagonal. So moving is O(1), as is nconflicts. But choosing\n",
+ " a variable, and a best value for the variable, are each O(n).\n",
+ " If you want, you can keep track of conflicted variables, then variable\n",
+ " selection will also be O(1).\n",
+ " >>> len(backtracking_search(NQueensCSP(8)))\n",
+ " 8\n",
+ " """\n",
+ "\n",
+ " def __init__(self, n):\n",
+ " """Initialize data structures for n Queens."""\n",
+ " CSP.__init__(self, list(range(n)), UniversalDict(list(range(n))),\n",
+ " UniversalDict(list(range(n))), queen_constraint)\n",
+ "\n",
+ " self.rows = [0]*n\n",
+ " self.ups = [0]*(2*n - 1)\n",
+ " self.downs = [0]*(2*n - 1)\n",
+ "\n",
+ " def nconflicts(self, var, val, assignment):\n",
+ " """The number of conflicts, as recorded with each assignment.\n",
+ " Count conflicts in row and in up, down diagonals. If there\n",
+ " is a queen there, it can't conflict with itself, so subtract 3."""\n",
+ " n = len(self.variables)\n",
+ " c = self.rows[val] + self.downs[var+val] + self.ups[var-val+n-1]\n",
+ " if assignment.get(var, None) == val:\n",
+ " c -= 3\n",
+ " return c\n",
+ "\n",
+ " def assign(self, var, val, assignment):\n",
+ " """Assign var, and keep track of conflicts."""\n",
+ " oldval = assignment.get(var, None)\n",
+ " if val != oldval:\n",
+ " if oldval is not None: # Remove old val if there was one\n",
+ " self.record_conflict(assignment, var, oldval, -1)\n",
+ " self.record_conflict(assignment, var, val, +1)\n",
+ " CSP.assign(self, var, val, assignment)\n",
+ "\n",
+ " def unassign(self, var, assignment):\n",
+ " """Remove var from assignment (if it is there) and track conflicts."""\n",
+ " if var in assignment:\n",
+ " self.record_conflict(assignment, var, assignment[var], -1)\n",
+ " CSP.unassign(self, var, assignment)\n",
+ "\n",
+ " def record_conflict(self, assignment, var, val, delta):\n",
+ " """Record conflicts caused by addition or deletion of a Queen."""\n",
+ " n = len(self.variables)\n",
+ " self.rows[val] += delta\n",
+ " self.downs[var + val] += delta\n",
+ " self.ups[var - val + n - 1] += delta\n",
+ "\n",
+ " def display(self, assignment):\n",
+ " """Print the queens and the nconflicts values (for debugging)."""\n",
+ " n = len(self.variables)\n",
+ " for val in range(n):\n",
+ " for var in range(n):\n",
+ " if assignment.get(var, '') == val:\n",
+ " ch = 'Q'\n",
+ " elif (var + val) % 2 == 0:\n",
+ " ch = '.'\n",
+ " else:\n",
+ " ch = '-'\n",
+ " print(ch, end=' ')\n",
+ " print(' ', end=' ')\n",
+ " for var in range(n):\n",
+ " if assignment.get(var, '') == val:\n",
+ " ch = '*'\n",
+ " else:\n",
+ " ch = ' '\n",
+ " print(str(self.nconflicts(var, val, assignment)) + ch, end=' ')\n",
+ " print()\n",
+ "def min_conflicts(csp, max_steps=100000):\n",
+ " """Solve a CSP by stochastic hillclimbing on the number of conflicts."""\n",
+ " # Generate a complete assignment for all variables (probably with conflicts)\n",
+ " csp.current = current = {}\n",
+ " for var in csp.variables:\n",
+ " val = min_conflicts_value(csp, var, current)\n",
+ " csp.assign(var, val, current)\n",
+ " # Now repeatedly choose a random conflicted variable and change it\n",
+ " for i in range(max_steps):\n",
+ " conflicted = csp.conflicted_vars(current)\n",
+ " if not conflicted:\n",
+ " return current\n",
+ " var = random.choice(conflicted)\n",
+ " val = min_conflicts_value(csp, var, current)\n",
+ " csp.assign(var, val, current)\n",
+ " return None\n",
+ "