From 65f0ca09f0e97cefb7ac88b9dcc21cb34d807372 Mon Sep 17 00:00:00 2001 From: rfeinman Date: Tue, 18 Dec 2018 15:39:30 -0500 Subject: [PATCH] adding parallel option for max-product inference algorithm --- pystruct/inference/common.py | 10 ++++ pystruct/inference/maxprod.py | 108 ++++++++++++++++++++++++---------- 2 files changed, 88 insertions(+), 30 deletions(-) diff --git a/pystruct/inference/common.py b/pystruct/inference/common.py index a43c2a43..d5f2274b 100644 --- a/pystruct/inference/common.py +++ b/pystruct/inference/common.py @@ -1,3 +1,4 @@ +import multiprocessing as mp import numpy as np @@ -49,3 +50,12 @@ def compute_energy(unary_potentials, pairwise_potentials, edges, labels): for edge, pw in zip(edges, pairwise_potentials): energy += pw[labels[edge[0]], labels[edge[1]]] return energy + + +def parallel(func, x, n_jobs=None): + p = mp.Pool(n_jobs) + y = p.map(func, x) + p.close() + p.join() + + return y \ No newline at end of file diff --git a/pystruct/inference/maxprod.py b/pystruct/inference/maxprod.py index b08900e5..69bc93a4 100644 --- a/pystruct/inference/maxprod.py +++ b/pystruct/inference/maxprod.py @@ -1,7 +1,7 @@ import numpy as np from scipy import sparse -from .common import _validate_params +from .common import _validate_params, parallel from ..utils.graph_functions import is_forest @@ -20,7 +20,8 @@ def is_chain(edges, n_vertices): def inference_max_product(unary_potentials, pairwise_potentials, edges, - max_iter=30, damping=0.5, tol=1e-5, relaxed=None): + max_iter=30, damping=0.5, tol=1e-5, relaxed=None, + n_jobs=1): """Max-product inference. In case the edges specify a tree, dynamic programming is used @@ -58,7 +59,7 @@ def inference_max_product(unary_potentials, pairwise_potentials, edges, y = tree_max_product(unary_potentials, pairwise_potentials, edges) else: y = iterative_max_product(unary_potentials, pairwise_potentials, edges, - max_iter=max_iter, damping=damping) + max_iter=max_iter, damping=damping, n_jobs=n_jobs) return y @@ -126,39 +127,86 @@ def tree_max_product(unary_potentials, pairwise_potentials, edges): def iterative_max_product(unary_potentials, pairwise_potentials, edges, - max_iter=10, damping=.5, tol=1e-5): + max_iter=10, damping=.5, tol=1e-5, n_jobs=1): + assert type(n_jobs) == int and n_jobs > 0 + # global variables + global unary_potentials_g, pairwise_potentials_g, edges_g, damping_g, \ + messages, all_incoming + unary_potentials_g = unary_potentials + pairwise_potentials_g = pairwise_potentials + edges_g = edges + damping_g = damping n_edges = len(edges) n_vertices, n_states = unary_potentials.shape messages = np.zeros((n_edges, 2, n_states)) all_incoming = np.zeros((n_vertices, n_states)) for i in range(max_iter): diff = 0 - for e, (edge, pairwise) in enumerate(zip(edges, pairwise_potentials)): - # update message from edge[0] to edge[1] - update = (all_incoming[edge[0]] + pairwise.T + - unary_potentials[edge[0]] - - messages[e, 1]) - old_message = messages[e, 0].copy() - new_message = np.max(update, axis=1) - new_message -= np.max(new_message) - new_message = damping * old_message + (1 - damping) * new_message - messages[e, 0] = new_message - update = new_message - old_message - all_incoming[edge[1]] += update - diff += np.abs(update).sum() - - # update message from edge[1] to edge[0] - update = (all_incoming[edge[1]] + pairwise + - unary_potentials[edge[1]] - - messages[e, 0]) - old_message = messages[e, 1].copy() - new_message = np.max(update, axis=1) - new_message -= np.max(messages[e, 1]) - new_message = damping * old_message + (1 - damping) * new_message - messages[e, 1] = new_message - update = new_message - old_message - all_incoming[edge[0]] += update - diff += np.abs(update).sum() + if n_jobs == 1: + for e, (edge,pairwise) in enumerate(zip(edges,pairwise_potentials)): + # update message from edge[0] to edge[1] + update = (all_incoming[edge[0]] + pairwise.T + + unary_potentials[edge[0]] + - messages[e, 1]) + old_message = messages[e, 0].copy() + new_message = np.max(update, axis=1) + new_message -= np.max(new_message) + new_message = damping*old_message + (1-damping)*new_message + messages[e, 0] = new_message + update = new_message - old_message + all_incoming[edge[1]] += update + diff += np.abs(update).sum() + + # update message from edge[1] to edge[0] + update = (all_incoming[edge[1]] + pairwise + + unary_potentials[edge[1]] + - messages[e, 0]) + old_message = messages[e, 1].copy() + new_message = np.max(update, axis=1) + new_message -= np.max(messages[e, 1]) + new_message = damping*old_message + (1-damping)*new_message + messages[e, 1] = new_message + update = new_message - old_message + all_incoming[edge[0]] += update + diff += np.abs(update).sum() + else: + results = parallel(max_product_iteration, range(n_edges), n_jobs) + for e, (edge, result) in enumerate(zip(edges, results)): + a, b = edge + message_0, message_1, update_a, update_b, d = result + messages[e, 0] = message_0 + messages[e, 1] = message_1 + all_incoming[a] += update_a + all_incoming[b] += update_b + diff += d if diff < tol: break return np.argmax(all_incoming + unary_potentials, axis=1) + + +def max_product_iteration(e): + a, b = edges_g[e] + pairwise = pairwise_potentials_g[e] + # update message from edge[0] to edge[1] + update = (all_incoming[a] + pairwise.T + unary_potentials_g[a] - + messages[e, 1]) + old_message = messages[e, 0].copy() + new_message = np.max(update, axis=1) + new_message = new_message - np.max(new_message) + new_message = damping_g * old_message + (1 - damping_g) * new_message + message_0 = new_message + update_b = new_message - old_message + + # update message from edge[1] to edge[0] + update = (all_incoming[b] + pairwise + unary_potentials_g[b] - + messages[e, 0]) + old_message = messages[e, 1].copy() + new_message = np.max(update, axis=1) + new_message = new_message - np.max(messages[e, 1]) + new_message = damping_g * old_message + (1 - damping_g) * new_message + message_1 = new_message + update_a = new_message - old_message + + diff = np.abs(update_a).sum() + np.abs(update_b).sum() + + return (message_0, message_1, update_a, update_b, diff)