Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pystruct/inference/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import multiprocessing as mp
import numpy as np


Expand Down Expand Up @@ -49,3 +50,12 @@ def compute_energy(unary_potentials, pairwise_potentials, edges, labels):
for edge, pw in zip(edges, pairwise_potentials):
energy += pw[labels[edge[0]], labels[edge[1]]]
return energy


def parallel(func, x, n_jobs=None):
p = mp.Pool(n_jobs)
y = p.map(func, x)
p.close()
p.join()

return y
108 changes: 78 additions & 30 deletions pystruct/inference/maxprod.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from scipy import sparse

from .common import _validate_params
from .common import _validate_params, parallel
from ..utils.graph_functions import is_forest


Expand All @@ -20,7 +20,8 @@ def is_chain(edges, n_vertices):


def inference_max_product(unary_potentials, pairwise_potentials, edges,
max_iter=30, damping=0.5, tol=1e-5, relaxed=None):
max_iter=30, damping=0.5, tol=1e-5, relaxed=None,
n_jobs=1):
"""Max-product inference.

In case the edges specify a tree, dynamic programming is used
Expand Down Expand Up @@ -58,7 +59,7 @@ def inference_max_product(unary_potentials, pairwise_potentials, edges,
y = tree_max_product(unary_potentials, pairwise_potentials, edges)
else:
y = iterative_max_product(unary_potentials, pairwise_potentials, edges,
max_iter=max_iter, damping=damping)
max_iter=max_iter, damping=damping, n_jobs=n_jobs)
return y


Expand Down Expand Up @@ -126,39 +127,86 @@ def tree_max_product(unary_potentials, pairwise_potentials, edges):


def iterative_max_product(unary_potentials, pairwise_potentials, edges,
max_iter=10, damping=.5, tol=1e-5):
max_iter=10, damping=.5, tol=1e-5, n_jobs=1):
assert type(n_jobs) == int and n_jobs > 0
# global variables
global unary_potentials_g, pairwise_potentials_g, edges_g, damping_g, \
messages, all_incoming
unary_potentials_g = unary_potentials
pairwise_potentials_g = pairwise_potentials
edges_g = edges
damping_g = damping
n_edges = len(edges)
n_vertices, n_states = unary_potentials.shape
messages = np.zeros((n_edges, 2, n_states))
all_incoming = np.zeros((n_vertices, n_states))
for i in range(max_iter):
diff = 0
for e, (edge, pairwise) in enumerate(zip(edges, pairwise_potentials)):
# update message from edge[0] to edge[1]
update = (all_incoming[edge[0]] + pairwise.T +
unary_potentials[edge[0]]
- messages[e, 1])
old_message = messages[e, 0].copy()
new_message = np.max(update, axis=1)
new_message -= np.max(new_message)
new_message = damping * old_message + (1 - damping) * new_message
messages[e, 0] = new_message
update = new_message - old_message
all_incoming[edge[1]] += update
diff += np.abs(update).sum()

# update message from edge[1] to edge[0]
update = (all_incoming[edge[1]] + pairwise +
unary_potentials[edge[1]]
- messages[e, 0])
old_message = messages[e, 1].copy()
new_message = np.max(update, axis=1)
new_message -= np.max(messages[e, 1])
new_message = damping * old_message + (1 - damping) * new_message
messages[e, 1] = new_message
update = new_message - old_message
all_incoming[edge[0]] += update
diff += np.abs(update).sum()
if n_jobs == 1:
for e, (edge,pairwise) in enumerate(zip(edges,pairwise_potentials)):
# update message from edge[0] to edge[1]
update = (all_incoming[edge[0]] + pairwise.T +
unary_potentials[edge[0]]
- messages[e, 1])
old_message = messages[e, 0].copy()
new_message = np.max(update, axis=1)
new_message -= np.max(new_message)
new_message = damping*old_message + (1-damping)*new_message
messages[e, 0] = new_message
update = new_message - old_message
all_incoming[edge[1]] += update
diff += np.abs(update).sum()

# update message from edge[1] to edge[0]
update = (all_incoming[edge[1]] + pairwise +
unary_potentials[edge[1]]
- messages[e, 0])
old_message = messages[e, 1].copy()
new_message = np.max(update, axis=1)
new_message -= np.max(messages[e, 1])
new_message = damping*old_message + (1-damping)*new_message
messages[e, 1] = new_message
update = new_message - old_message
all_incoming[edge[0]] += update
diff += np.abs(update).sum()
else:
results = parallel(max_product_iteration, range(n_edges), n_jobs)
for e, (edge, result) in enumerate(zip(edges, results)):
a, b = edge
message_0, message_1, update_a, update_b, d = result
messages[e, 0] = message_0
messages[e, 1] = message_1
all_incoming[a] += update_a
all_incoming[b] += update_b
diff += d
if diff < tol:
break
return np.argmax(all_incoming + unary_potentials, axis=1)


def max_product_iteration(e):
a, b = edges_g[e]
pairwise = pairwise_potentials_g[e]
# update message from edge[0] to edge[1]
update = (all_incoming[a] + pairwise.T + unary_potentials_g[a] -
messages[e, 1])
old_message = messages[e, 0].copy()
new_message = np.max(update, axis=1)
new_message = new_message - np.max(new_message)
new_message = damping_g * old_message + (1 - damping_g) * new_message
message_0 = new_message
update_b = new_message - old_message

# update message from edge[1] to edge[0]
update = (all_incoming[b] + pairwise + unary_potentials_g[b] -
messages[e, 0])
old_message = messages[e, 1].copy()
new_message = np.max(update, axis=1)
new_message = new_message - np.max(messages[e, 1])
new_message = damping_g * old_message + (1 - damping_g) * new_message
message_1 = new_message
update_a = new_message - old_message

diff = np.abs(update_a).sum() + np.abs(update_b).sum()

return (message_0, message_1, update_a, update_b, diff)