Source code for mrftools.MaxProductBeliefPropagator
"""Class to run max-product belief propagation."""
import numpy as np
from .MatrixBeliefPropagator import MatrixBeliefPropagator, sparse_dot, logsumexp
class MaxProductBeliefPropagator(MatrixBeliefPropagator):
    """
    Class to run inference of the most likely state via max-product belief propagation.

    Messages are computed with a max over the sending variable's states, and beliefs
    are reported as hard indicators in log space: 0 for maximizing state(s) and -inf
    for every other state.
    """

    def __init__(self, markov_net):
        """
        Initialize a max-product belief propagator.

        :param markov_net: MarkovNet object encoding the probability distribution
        :type markov_net: MarkovNet
        """
        super(MaxProductBeliefPropagator, self).__init__(markov_net)

    def _incoming_scores(self):
        """Return a new array: unary + augmented potentials plus the sum of incoming messages."""
        scores = self.mn.unary_mat + self.augmented_mat
        scores += sparse_dot(self.message_mat, self.mn.message_to_map)
        return scores

    def _reversed_messages(self):
        """Return message_mat with its first and last ``num_edges`` columns swapped."""
        # NOTE(review): assumes message_mat is (max_states, 2 * num_edges) with column
        # m and m +/- num_edges being the two directions of one edge -- confirm.
        num_edges = self.mn.num_edges
        return np.hstack((self.message_mat[:, num_edges:],
                          self.message_mat[:, :num_edges]))

    def compute_beliefs(self):
        """
        Compute hard unary beliefs from the current messages.

        Each state's score is its unary potential plus the augmented potential plus
        the sum of its incoming messages. In every column, the highest-scoring row
        (the first one on ties, as ``argmax`` does) is set to log-belief 0 and every
        other row to -inf in ``self.belief_mat``. Nothing is recomputed when
        ``self.fully_conditioned`` is set.
        """
        if self.fully_conditioned:
            return
        max_marginals = self._incoming_scores()
        best_states = max_marginals.argmax(0)
        self.belief_mat = np.full(max_marginals.shape, -np.inf)
        self.belief_mat[best_states, np.arange(max_marginals.shape[1])] = 0

    def compute_pairwise_beliefs(self):
        """
        Compute hard pairwise beliefs for every edge.

        Reads ``self.belief_mat`` (which :meth:`compute_beliefs` sets), subtracts the
        swapped-half message (presumably the reverse-direction message on the same
        edge), adds the edge potentials, and stores in ``self.pair_belief_tensor`` a
        tensor that is 0 for every state pair attaining the per-edge maximum (all
        ties included) and -inf elsewhere. Nothing is recomputed when
        ``self.fully_conditioned`` is set.
        """
        if self.fully_conditioned:
            return
        num_edges = self.mn.num_edges
        max_states = self.mn.max_states

        # Log-space "division" of each sender's belief by the message coming back to it.
        adjusted_message_prod = self.belief_mat[:, self.mn.message_from] - self._reversed_messages()

        # Broadcast the two directions against each other: axis 0 / axis 1 of the pair table.
        to_messages = adjusted_message_prod[:, :num_edges].reshape((max_states, 1, num_edges))
        from_messages = adjusted_message_prod[:, num_edges:].reshape((1, max_states, num_edges))

        max_marginals = self.mn.edge_pot_tensor[:, :, num_edges:] + to_messages + from_messages
        self.pair_belief_tensor = np.where(max_marginals == max_marginals.max((0, 1)), 0, -np.inf)

    def update_messages(self):
        """
        Perform one max-product update of all messages simultaneously.

        Every new message is computed from the old ``self.message_mat``. The result is
        re-centered so each column's maximum is 0 and passed through ``np.nan_to_num``
        before replacing ``self.message_mat``.

        :return: total absolute change between the old and new messages
        :rtype: float
        """
        belief_mat = self._incoming_scores()
        # Per-column log normalization (assuming logsumexp reduces along axis 0). The
        # shift is constant per message and cancels in the re-centering below, so it
        # only keeps values in a well-behaved numeric range.
        belief_mat -= logsumexp(belief_mat, 0)

        adjusted_message_prod = self.mn.edge_pot_tensor - self._reversed_messages()
        adjusted_message_prod += belief_mat[:, self.mn.message_from]

        # .max(1) already returns a 2-D (states x messages) array. np.squeeze is not
        # applied: with max_states == 1 it would drop the state axis and corrupt the
        # shape of message_mat for the next iteration.
        messages = adjusted_message_prod.max(1)
        # nan_to_num maps NaN (e.g. -inf - -inf for an all -inf column) to 0 and
        # +/-inf to large finite values.
        messages = np.nan_to_num(messages - messages.max(0))

        change = np.sum(np.abs(messages - self.message_mat))
        self.message_mat = messages
        return change