{
 "metadata": {
  "name": "testing ad3 cython bindings"
 },
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "heading",
     "level": 2,
     "metadata": {},
     "source": [
      "A simple grid MRF with Potts potentials in AD3"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import itertools\n",
      "\n",
      "# Imported explicitly so the notebook runs on a fresh kernel without --pylab.\n",
      "import matplotlib.pyplot as plt\n",
      "import numpy as np\n",
      "\n",
      "import ad3"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 1
    },
    {
     "cell_type": "raw",
     "metadata": {},
     "source": [
      "Set parameters of the model"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "grid_size = 20   # the grid has grid_size x grid_size variables\n",
      "num_states = 5   # number of states each variable can take\n",
      "random_seed = 0  # fixed seed so the random potentials are reproducible"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "raw",
     "metadata": {},
     "source": [
      "Create a factor graph"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "factor_graph = ad3.PFactorGraph()"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 3
    },
    {
     "cell_type": "raw",
     "metadata": {},
     "source": [
      "Create variables lying on a grid with random potentials."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Seed NumPy's global RNG so every run draws the same unary potentials.\n",
      "np.random.seed(random_seed)\n",
      "random_grid = np.random.uniform(size=(grid_size, grid_size, num_states))\n",
      "\n",
      "# multi_variables[i][j] holds the multi-valued variable for grid cell (i, j).\n",
      "multi_variables = []\n",
      "for i in range(grid_size):\n",
      "    multi_variables.append([])\n",
      "    for j in range(grid_size):\n",
      "        new_variable = factor_graph.create_multi_variable(num_states)\n",
      "        # Unary log-potential of each state comes from random_grid[i, j, state].\n",
      "        for state in range(num_states):\n",
      "            new_variable.set_log_potential(state, random_grid[i, j, state])\n",
      "        multi_variables[i].append(new_variable)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 4
    },
    {
     "cell_type": "raw",
     "metadata": {},
     "source": [
      "Create Potts potentials for edges."
] }, { "cell_type": "code", "collapsed": false, "input": [ "alpha = .5\n", "potts_matrix = alpha * np.eye(num_states)\n", "potts_potentials = potts_matrix.ravel().tolist()" ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 5 }, { "cell_type": "raw", "metadata": {}, "source": [ "Create factors for edges from potts potentials" ] }, { "cell_type": "code", "collapsed": false, "input": [ "for i, j in itertools.product(xrange(grid_size), repeat=2):\n", " if (j > 0):\n", " #horizontal edge\n", " edge_variables = [multi_variables[i][j - 1], multi_variables[i][j]]\n", " factor_graph.create_factor_dense(edge_variables, potts_potentials)\n", " \n", " if (i > 0):\n", " #horizontal edge\n", " edge_variables = [multi_variables[i - 1][j], multi_variables[i][j]]\n", " factor_graph.create_factor_dense(edge_variables, potts_potentials)\n", " " ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 6 }, { "cell_type": "raw", "metadata": {}, "source": [ "Set model parameters and compute the map using AD3" ] }, { "cell_type": "code", "collapsed": false, "input": [ "factor_graph.set_eta_ad3(.1)\n", "factor_graph.adapt_eta_ad3(True)\n", "factor_graph.set_max_iterations_ad3(1000)\n", "value, marginals, edge_marginals = factor_graph.solve_lp_map_ad3()" ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 7 }, { "cell_type": "raw", "metadata": {}, "source": [ "Visualize resulting MAP" ] }, { "cell_type": "code", "collapsed": false, "input": [ "res = np.array(marginals).reshape(20, 20, 5)\n", "plt.matshow(np.argmax(random_grid, axis=-1), vmin=0, vmax=4)\n", "plt.matshow(np.argmax(res, axis=-1), vmin=0, vmax=4)" ], "language": "python", "metadata": {}, "outputs": [ { "output_type": "pyout", "prompt_number": 13, "text": [ "" ] }, { "output_type": "display_data", "png": 
"iVBORw0KGgoAAAANSUhEUgAAAPwAAAD5CAYAAAADZljUAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADnpJREFUeJzt3U9sVPUaxvHnXKSstKkSapP+I4bYmlTSmNLGpNh0YRqM\nbdnJwkWHhdY2Gqsurin2NDa3cacpDXRB3cHCBVES00YXA/eSSFkUF6ZGogKJQUNdSAkkV8y5i0KR\nS5medzq/Gej7/SSTwHA85zdzzuNMh2deoiRJEgFw4R+lXgCA4iHwgCMEHnCEwAOOEHjAEQIPOFK0\nwJ86dUqNjY3asWOHJiYminXYoqmvr9ezzz6r5uZm7dq1q9TLWZdMJqPKyko1NTWt3Le0tKSenh7V\n1taqt7dX165dK+EK12e1xxfHsaqrq9Xc3Kzm5mbNzMyUcIXhFC3wb731lqampvT1119rcnJSi4uL\nxTp0UURRpGw2q/n5ec3NzZV6OevS19d3zwV/6NAh1dbW6vz586qurtbhw4dLtLr1W+3xRVGkoaEh\nzc/Pa35+Xl1dXSVaXVhFCfwff/whSdq9e7fq6ur04osv6syZM8U4dFFtlA5Te3u7Kioq7rpvbm5O\n+/fv15YtW5TJZB7q87fa45M2zvnLpSiBP3v2rBoaGlZ+/8wzz+ibb74pxqGLJooidXZ2qre3V198\n8UWpl1Nwfz+HDQ0ND/27mNVMTEyora1NH330kZaWlkq9nCD40K5ATp8+rW+//Vbj4+MaGhrSr7/+\nWuolFdRGf/Xr7+/Xzz//rNnZWf3444+ampoq9ZKCKErgW1pa9P3336/8/rvvvlNbW1sxDl00VVVV\nkqTGxkZ1d3frxIkTJV5RYbW0tGhhYUGStLCwoJaWlhKvqLC2bdumKIpUXl6ugYEBHT9+vNRLCqIo\ngS8vL5e0/En9hQsX9NVXX6m1tbUYhy6K69evr7wFvHLlimZnZzfchz6tra2anp7WjRs3ND09veH+\nh3358mVJ0s2bN3X06FHt2bOnxCsKJCmSbDabNDQ0JE899VTyySefFOuwRfHTTz8lO3fuTHbu3Jl0\ndnYmR44cKfWS1uWVV15JqqqqkrKysqS6ujqZnp5Orl69mnR3dyc1NTVJT09PsrS0VOpl5u3249u8\neXNSXV2dHDlyJHn11VeTpqam5Lnnnkvefvvt5Pfffy/1MoOIkmSD/3AGYAUf2gGOEHjAkUdC7DSK\nohC7BZBCrp/SgwRekhSvctBsLHXE99w9Eqf/H8RfyT9Nyxgb/Zdp+7SSrXevOf5Siu/zwe6BgfRr\nHovKbOs4OJp622hwJO/95np8FtGi8SOjOF7/QVdheXzred7WYnk+0uRkraPn/ZZ+o38ZBtiI8g78\nRv8yDLAR5RX4vL8MU9+Rz+EeCh07Sr2CsHh8G0NeP8Pf78swL7300p2NsvGdX9d33LltUBv9guHx\nPZgu3LqlFe5Du1U+nANQWPW3bredXGP7vN7Se/gyDLAR5RX4jf5lGGCjyvst/ccff6zXXntNf/75\np958801t3bq1kOsCEEDegX/hhRdWvh8N4OEQsGkXp97U0p6zNtGUfhmmxl+k9O2rkcH07avkYOpN\nJUnxYPpth5P/pt72gAznxNBm/P+G4prrsFwbplal5ZzY2nMPMr48AzhC4AFHCDzgCIEHHCHwgCME\nHnCEwAOOEHjAEQIPOELgAUeC/EMUURStPsTyPqx1y1BMAxYtwxUDDWKUbM+dpYZrERvqwJaBnpKt\nLmupRo8ars+QLOcv1fUZRzmn1vIKDzhC4AFHCDzgCIEHHCHwgCMEHnCEwAOOEHjAEQIPOELgAUcI\nPOBIsC79sGG88KZoPPW2lt62ZOvHm7rYhjHVljHH0WD6/Vr3bWHtvKf14WT6c21lWXOojr71+rRI\n812IUYkuPYBlBB5whMADjhB4wBECDzhC4AFHCDzgCIEHHCHwg
CMEHnAkWLU2MVQMLfVX60hry2jm\nULXIYOOvJY0ofbU21DhpS13WOio7VIXZIuT1aZHmnIxF41RrASwj8IAjBB5whMADjhB4wBECDzhC\n4AFHCDzgCIEHHCHwgCPBqrWW2auWyqd1qqulqhqqFmlecyCWGq6l0jqc/Def5RScZfqxxV+GCcyW\nabhS4a+5aJCptQBueSTf/7C+vl6PPfaYNm3apM2bN2tubq6Q6wIQQN6Bj6JI2WxWjz/+eCHXAyCg\ndb2lD/DjP4CA1vUK39nZqe3btyuTyai7u/uuP8/+7df1t24ACit7fvmWVt6BP336tKqqqrSwsKCX\nX35Zu3bt0pNPPrny5x357hhAah07lm+3jX6Ze/u839JXVVVJkhobG9Xd3a0TJ07kuysARZJX4K9f\nv66lpSVJ0pUrVzQ7O6uurq6CLgxA4eX1lv63337T3r17JUlPPPGE3nnnHdXU1BR0YQAKL6/Ab9++\nXefOnSv0WgAElveHdmux1BGjqCz9jo1TXS3VRVMF1rCOEaVfg+V5k2x10lB12THD+bPUeyXb8zEa\np/9r4pE4TI16eOR9238wmX7TVJOEB3NfD1RrAUcIPOAIgQccIfCAIwQecITAA44QeMARAg84QuAB\nRwg84EiwqbWJYRKtRap64d+EmmRqmrS7GG4ykKk6HGgdljrph5O282E536FqxsnB9HVg64Riy77T\nHZ+ptQBuIfCAIwQecITAA44QeMARAg84QuABRwg84AiBBxwh8IAjBB5wJNiY6lAd6E1rjOFdD8uY\nYw2m77APJ+m75paRz5J0wDDGeVjp12Hrpad/3jbFtvM3tviv1NsOG07fsNKP4Y5G0+/Ycq4lyTIt\nO931mXuHvMIDjhB4wBECDzhC4AFHCDzgCIEHHCHwgCMEHnCEwAOOEHjAkWDVWou/DPXQoAwTg01j\nqkfT10OTg4aupWzVTMuaLZVdy/NmFsepN7XUdi01assocMv5CLGOtbbgFR5whMADjhB4wBECDzhC\n4AFHCDzgCIEHHCHwgCMEHnCEwAOOREmSGGZ9ptxpFCmxVE8XC76EOwzVTIvhJP3U0w8n01c+40Hb\nOkY1kn7jQM+FZb/JQVsP1/J8mJ4Lg5GA3eFC18rHonHlijSv8IAjOQOfyWRUWVmppqamlfuWlpbU\n09Oj2tpa9fb26tq1a8EXCaAwcga+r69PMzMzd9136NAh1dbW6vz586qurtbhw4eDLhBA4eQMfHt7\nuyoqKu66b25uTvv379eWLVuUyWR05syZoAsEUDjm78OfPXtWDQ0NkqSGhgbNzc2tul385Z1fd+xY\nvgEorIvZi7qYvZR6e3Pg036oH++x7hmAVV1Hneo66lZ+f2r0Pzm3N39K39LSooWFBUnSwsKCWlpa\nrLsAUCLmwLe2tmp6elo3btzQ9PS02traQqwLQAA5A79v3z49//zz+uGHH1RTU6NPP/1U/f39unTp\nkp5++mn98ssvev3114u1VgDrlPNn+GPHjq16/+effx5kMQDCClattZQcLfXCTVH6mqqVpZppqdZa\nWGq4VsEqzA9ItdbiQTjXVmNpph/HEdVaAMsIPOAIgQccIfCAIwQecITAA44QeMARAg84QuABRwg8\n4Ij5+/BpWeqyljqprbQr26TWQFNdLY/vwIBtimmquqUTlmtuWGHqsubzYbjmhpP31z7+GrvjFR5w\nhMADjhB4wBECDzhC4AFHCDzgCIEHHCHwgCMEHnCEwAOOEHjAkWBd+pDjli2SrVHqbQ8YutiWzvSY\nDP3qKE6/rSQp/faWEdHRYPrvLIwo/X6tY6dH4/SjtUei9Ofa0ru3sFxvkvGai8qsy7kHr/CAIwQe\ncITAA44QeMARAg84QuABRwg84AiBBxwh8IAjBB5wJFi11jJu2VJTHZGtumipiCaGiqhlWnYhKpH3\nZRqtbXh8liUcDLLb5X0bz
rdlhHkyaagZL6av947FtnM9PBBmXPb98AoPOELgAUcIPOAIgQccIfCA\nIwQecITAA44QeMARAg84QuABR4JVay0skz5j684DVU9D1YEtU1olaSQ2VE8N+04OGs6JcRJtKJbp\nuRahJh9LgWvXq+AVHnAkZ+AzmYwqKyvV1NS0cl8cx6qurlZzc7Oam5s1MzMTfJEACiNn4Pv6+u4J\ndBRFGhoa0vz8vObn59XV1RV0gQAKJ2fg29vbVVFRcc/9SWL7ORPAgyGvD+0mJib02Wefae/evXrj\njTf06KOP3rPNqfjfK7+u66hVXUdd/qsEcB8Xbt3SMQe+v79fH3zwga5evar33ntPU1NTevfdd+/Z\nbnfcbt01ALP6W7fbTubc2vwp/bZt2xRFkcrLyzUwMKDjx49bdwGgRMyBv3z5siTp5s2bOnr0qPbs\n2VPwRQEII+db+n379unkyZNaXFxUTU2NRkdHlc1mde7cOZWVlWn37t3q7+8v1loBrFPOwB87duye\n+zKZTLDFAAgrSgL8HVsURUoMk0wtU0EtVVLJNlE1VEXUUpe1Pr5QLGseHnk/9baWSrJkrF0bzt9f\nxgpsWpuicdP2luszzSTosWg851+bU60FHCHwgCMEHnCEwAOOEHjAEQIPOELgAUcIPOAIgQccIfCA\nI8Gm1lpqjpYJqZFGbAsZDDPJ1FKJHB2M029rnFprqZ5Gg8bnLiVbndQ2pTXUpN0DhjV8OJn+8Vmn\n1saG525scP0TbnmFBxwh8IAjBB5whMADjhB4wBECDzhC4AFHCDzgCIEHHCHwgCMEHnAk2JhqWTrv\ncZx6U0t3XLJ1+kcDrdkyxtk65tiy5uRg+u8VhBr5/KA8vlCso84LPcJ8VLn/dWde4QFHCDzgCIEH\nHCHwgCMEHnCEwAOOEHjAEQIPOELgAUcIPODIA1GtHVH6SqSlxilJY9H6R/uuylCtTVOJvM06ptrC\nUvG1jGa2iBbDPT6LUM+FtVprvZ7XMhaNU60FsIzAA44QeMARAg84QuABR4oc+AvFPVwxXciWegVB\nZc+XegVhXcxeLPUSioLAFwqBf6hdzF4q9RKKgrf0gCMEHnAkYNMOQCnkivQjxT4ggNLhLT3gCIEH\nHCHwgCMEHnCEwAOOEHjAkf8BMB2Q8Aa5Sy4AAAAASUVORK5CYII=\n" }, { "output_type": "display_data", "png": 
"iVBORw0KGgoAAAANSUhEUgAAAPwAAAD5CAYAAAADZljUAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACiNJREFUeJzt201oFAcfx/HfPDV6akNsNQ1skhWxboQ0LGWbUEgMOZSQ\nkhdvevCQ9aBRsTRtDz01wgPSW0taNAfXmz30EGqgJNjDKAq6HlYRWTGoiSDxrYe6Yg5NmR5sQ/KY\nZ5Ns5qXZ//cDC8lknPkv8ZudGWYcz/M8ATDhP1EPACA8BA8YQvCAIQQPGELwgCEEDxgSWvAXL15U\nQ0ODduzYoeHh4bB2G5p4PK73339fyWRSH374YdTjrEk6nVZ1dbUaGxvnlxUKBfX29qqurk59fX16\n8eJFhBOuzVLvb2hoSLFYTMlkUslkUuPj4xFOGJzQgv/00081MjKiX3/9VT/88IOePXsW1q5D4TiO\nXNdVLpdTNpuNepw16e/vf+0//MmTJ1VXV6fJyUnFYjGdOnUqounWbqn35ziOBgcHlcvllMvl1NnZ\nGdF0wQol+N9//12S1NbWpvr6en388ce6evVqGLsOVbncw9Ta2qqqqqpFy7LZrA4cOKBNmzYpnU6v\n69/fUu9PKp/fXzGhBH/t2jUlEon573ft2qUrV66EsevQOI6jjo4O9fX16dy5c1GP47uFv8NEIrHu\nj2KWMjw8rJaWFn3zzTcqFApRjxMILtr55PLly7px44ZOnDihwcFBPXr0KOqRfFXun34DAwO6f/++\nJiYmdPfuXY2MjEQ9UiBCCT6VSun27dvz39+6dUstLS1h7Do0NTU1kqSGhgb19PRobGws4on8lUql\nlM/nJUn5fF6pVCriify1detWOY6jyspKHTlyRKOjo1GPFIhQgq+srJT06kr91NSUzp8/r+bm5jB2\nHYqXL1/OHwI+ffpUExMTZXfRp7m5WZlMRrOzs8pkMmX3B3tmZkaSNDc3p7Nnz6qrqyviiQLihcR1\nXS+RSHjbt2/3vvvuu7B2G4p79+55TU1NXlNTk9fR0eGdPn066pHWZO/evV5NTY23ceNGLxaLeZlM\nxnv+/LnX09Pj1dbWer29vV6hUIh6zJL98/4qKiq8WCzmnT592tu/f7/X2NjoffDBB95nn33m/fbb\nb1GPGQjH88r85AzAPC7aAYYQPGDIhiA26jhOEJsFsALFztIDCV6Svl5imSupPagdRsxVeb23oe//\n5/tfpKEyvXAtlc/7c44W/3nJh/Tl/jAMUI5KDr7cH4YBylFJwZf6MEy8lJ2tE/GoBwhY+46oJwhW\nub+/f5R0Dv//Hob55JNP5pe5C9aPL3iVq3jUAwSs3INYr+/PnXz1WqnALtq1B7VhAPPadyz+Y3X8\nl+Lrl3RIb+FhGKAclRR8uT8MA5Srkg/pv/32Wx08eFB//PGHjh07pnfeecfPuQAEoOTgd+/ePf98\nNID1IbCLdljfhpa5Y2vRut8vvw7+HXh4BjCE4AFDCB4whOABQwgeMITgAUMIHjCE4AFDCB4whOAB\nQwgeMITgAUMIHjCE4AFDCB4whOABQwgeMITgAUMIHjCE4AFDCB4whOABQwgeMITgAUMIHjCE4AFD\nCB4whOABQwgeMITgAUMIHjCE4AFDCB4whOABQwgeMITgAUMIHjCE4AFDCB4whOABQwgeMITgAUMI\nHjCE4AFDNkQ9ANa/oaNRTxCsoe+jnsA/fMIDhpT8CR+Px/XWW2/pjTfeUEVFhbLZrJ9zAQhAycE7\njiPXdbV582Y/5wEQoDUd0nue59ccAEKwpk/4jo4Obdu2Tel0Wj09PYt+7i74Ov73C4C/3MlXr5Vy\nvBI/pmdmZlRTU6N8Pq/u7m5dunRJ77777quNOo6+LmWjwL/QerpK7xwtfuRd8iF9TU2NJKmhoUE9\nPT0aGxsrdVMAQlJS8C9fvlShUJAkPX36VBMTE+rs7PR1MAD+K
+kc/vHjx9qzZ48k6e2339bnn3+u\n2tpaXwcD4L+Sgt+2bZuuX7/u9ywAAsattcAyyunWYW6tBQwheMAQggcMIXjAEIIHDCF4wBCCBwwh\neMAQggcMIXjAEIIHDCF4wBCCBwwheMAQggcMIXjAEIIHDCF4wBCCBwwheMAQggcMIXjAEIIHDCF4\nwBCCBwwheMAQggcMIXjAEIIHDCF4wBCCBwwheMAQggcMIXjAEIIHDCF4wBCCBwwheMAQggcMIXjA\nEIIHDCF4wBCCBwwheMCQDVEPAFj2p/eVvxt0ThT9MZ/wgCFFg0+n06qurlZjY+P8skKhoN7eXtXV\n1amvr08vXrwIfEgA/igafH9/v8bHxxctO3nypOrq6jQ5OalYLKZTp04FOiAA/xQNvrW1VVVVVYuW\nZbNZHThwQJs2bVI6ndbVq1cDHRCAf1Z90e7atWtKJBKSpEQioWw2u+R67oKv43+/APhr2p3WtPtg\nxeuvOnjP81a0XvtqNwxg1erb61XfXj///cXjl4quv+qr9KlUSvl8XpKUz+eVSqVWuwkAEVl18M3N\nzcpkMpqdnVUmk1FLS0sQcwEIQNHg9+3bp48++kh37txRbW2tzpw5o4GBAT148EA7d+7Uw4cPdejQ\nobBmBbBGRc/hf/zxxyWX//zzz4EMAyBY3FoLLMP3218jxK21gCEEDxhC8IAhBA8YQvCAIQQPGELw\ngCEEDxhC8IAhBA8YQvCAIQQPGELwgCEEDxhC8IAhBA8YQvCAIQQPGELwgCEEDxhC8IAhBA8YQvCA\nIQQPGELwgCEEDxhC8IAhBA8YQvCAIQQPGELwgCEEDxhC8IAhBA8YQvCAIQQPGELwgCEEDxhC8IAh\nBA8YQvCAIQQPGELwgCEEDxhC8IAhBA8YUjT4dDqt6upqNTY2zi8bGhpSLBZTMplUMpnU+Ph44EMC\n8EfR4Pv7+18L2nEcDQ4OKpfLKZfLqbOzM9ABAfinaPCtra2qqqp6bbnneYENBCA4G0r5R8PDw/rp\np5+0Z88eHT58WG+++eZr67gLvo7//QLgr2l3WtPugxWv73jLfFxPTU2pu7tbN2/elCQ9efJEW7Zs\n0fPnz/Xll1/qvffe0xdffLF4o46jr0sYHvg3+tP7KuoRVuy/zomiR+Crvkq/detWOY6jyspKHTly\nRKOjo2saEEB4Vh38zMyMJGlubk5nz55VV1eX70MBCEbRc/h9+/bpwoULevbsmWpra3X8+HG5rqvr\n169r48aNamtr08DAQFizAlijZc/hS9oo5/AoI+V0Dl/SVXpgvVtPEfuJW2sBQwgeMITgAUMIHjCE\n4AFDCB4whOABQwgeMITgAUMIHjCE4AFDCB4whOABQwgeMITgAUMIHjCE4AFDCB4whOABQwgeMITg\nAUMIHjCE4AFDCB4whOABQwgeMITgAUMIHjCE4AFDCB4whOABQwgeMITgAUMIHjCE4AFDCB4whOAB\nQwgeMITgAUNCDX4qzJ2FbCrqAQI2FfUAAZt2p6MeIRQE75OpqAcI2FTUAwRs2n0Q9Qih4JAeMITg\nAUMcz/M83zfqOH5vEsAKFUt6Q9g7BBAdDukBQwgeMITgAUMIHjCE4AFDCB4w5C+IE8yrMIP2IgAA\nAABJRU5ErkJggg==\n" } ], "prompt_number": 13 }, { "cell_type": "code", "collapsed": false, "input": [], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 8 }, { "cell_type": "code", "collapsed": false, "input": [], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 8 }, { "cell_type": "code", 
"collapsed": false, "input": [], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 8 } ], "metadata": {} } ] }