Source code for mrftools.TreeReweightedBeliefPropagator

"""CountingNumberBeliefPropagator class."""
from random import shuffle

import numpy as np

from .BeliefPropagator import BeliefPropagator, logsumexp


class TreeReweightedBeliefPropagator(BeliefPropagator):
    """Belief propagator whose updates are weighted by per-edge tree (edge-appearance) probabilities.

    Every message, belief, and entropy computation in this class reads ``self.tree_probabilities``, a dict
    keyed by ordered variable pairs ``(a, b)``. Both orientations of each edge are expected to be present.
    """

    def __init__(self, markov_net, tree_probabilities=None):
        """
        Initialize a TRBP object with a Markov net and a dictionary of tree probabilities

        :param markov_net: Markov net to perform inference on.
        :type markov_net: MarkovNet
        :param tree_probabilities: Edge appearance probabilities for spanning forest distribution. If this parameter
                                   is None, the probabilities are taken from ``markov_net.tree_probabilities`` when
                                   that attribute exists. The probabilities should be provided as a dict with a
                                   key-value pair for each edge; the reverse orientation of each edge is filled in
                                   automatically. The dict passed in is copied, never modified. If neither source
                                   supplies probabilities, ``self.tree_probabilities`` is left unset and must be
                                   assigned (e.g. via ``_set_tree_probabilities``) before running inference.
        :type tree_probabilities: dict
        """
        if tree_probabilities is None:
            # Honor the documented fallback to probabilities stored on the Markov net.
            # NOTE(review): the attribute name ``tree_probabilities`` on MarkovNet is assumed -- confirm.
            tree_probabilities = getattr(markov_net, "tree_probabilities", None)

        if tree_probabilities is not None:
            self._set_tree_probabilities(tree_probabilities)

        super(TreeReweightedBeliefPropagator, self).__init__(markov_net)

    def _set_tree_probabilities(self, tree_probabilities):
        """
        Store a symmetrized copy of the edge appearance probabilities.

        For each edge ``(a, b)`` whose reverse ``(b, a)`` is missing, ``(b, a)`` gets the same probability.
        Reverse keys that are already present keep their own value.

        :param tree_probabilities: edge appearance probabilities keyed by (reversible) variable-pair tuples
        :type tree_probabilities: dict
        """
        # Copy first: the old code stored the caller's dict by reference and then inserted reversed keys
        # into it, silently mutating the argument.
        probabilities = dict(tree_probabilities)
        for (edge, prob) in tree_probabilities.items():
            reversed_edge = edge[::-1]
            if reversed_edge not in probabilities:
                probabilities[reversed_edge] = prob
        self.tree_probabilities = probabilities

    def compute_message(self, var, neighbor):
        """
        Compute the log-space message sent from ``var`` to ``neighbor``.

        :param var: sending variable
        :param neighbor: receiving variable
        :return: message over ``neighbor``'s states, shifted so that its maximum entry is 0
        :rtype: ndarray
        """
        pair = (var, neighbor)
        # var_beliefs[var] holds the unary potential plus rho-weighted incoming messages (see compute_beliefs).
        # Subtracting the full message from neighbor leaves the messages from the other neighbors, each weighted
        # by rho, minus (1 - rho) times the message from neighbor (up to the normalizing constant).
        adjusted_message_product = self.var_beliefs[var] - self.messages[(neighbor, var)]

        # partial log-sum-exp operation. The broadcasting implies the potential's rows index neighbor's states
        # and its columns index var's states, so summing over columns marginalizes out var.
        matrix = self.mn.get_potential((neighbor, var)) / self.tree_probabilities[pair] + adjusted_message_product
        # the dot product with ones is slightly faster than calling sum
        message = np.log(np.exp(matrix - matrix.max()).dot(np.ones(matrix.shape[1])))

        # pseudo-normalize message
        message -= np.max(message)

        return message

    def compute_bethe_entropy(self):
        """
        Compute the tree-reweighted entropy approximation from the current beliefs.

        The result is the sum of all unary entropies minus, for every edge, its tree probability times the edge's
        mutual information. With every tree probability equal to 1, this is the Bethe entropy.

        :return: entropy approximation
        :rtype: float
        """
        entropy = 0.0

        unary_entropy = dict()

        for var in self.mn.variables:
            # nan_to_num turns -inf log-beliefs (zero probability) into a large finite value, so 0 * log(0)
            # contributes 0 instead of nan
            unary_entropy[var] = -np.sum(np.exp(self.var_beliefs[var]) * np.nan_to_num(self.var_beliefs[var]))
            entropy += unary_entropy[var]
        for var in self.mn.variables:
            for neighbor in self.mn.neighbors[var]:
                # visit each undirected edge once
                if var < neighbor:
                    pair_entropy = -np.sum(
                        np.exp(self.pair_beliefs[(var, neighbor)]) * np.nan_to_num(
                            self.pair_beliefs[(var, neighbor)]))
                    mutual_information = unary_entropy[var] + unary_entropy[neighbor] - pair_entropy
                    entropy -= self.tree_probabilities[(var, neighbor)] * mutual_information
        return entropy

    def compute_beliefs(self):
        """
        Recompute every variable's normalized log-belief.

        Each belief is the unary potential plus the incoming messages, each weighted by its edge's tree
        probability. It is then normalized with log-sum-exp and stored in ``self.var_beliefs``.
        """
        for var in self.mn.variables:
            belief = self.mn.unary_potentials[var]
            for neighbor in self.mn.get_neighbors(var):
                belief = belief + self.messages[(neighbor, var)] * self.tree_probabilities[(neighbor, var)]
            log_z = logsumexp(belief)
            belief = belief - log_z
            self.var_beliefs[var] = belief

    def compute_pairwise_beliefs(self):
        """
        Recompute the normalized log-belief of every edge.

        The belief for ``(var, neighbor)`` is stored under that key, and its transpose is stored under
        ``(neighbor, var)``.
        """
        for var in self.mn.variables:
            for neighbor in self.mn.get_neighbors(var):
                # compute each undirected edge once, then store both orientations
                if var < neighbor:
                    belief = self.mn.get_potential((var, neighbor)) / self.tree_probabilities[(var, neighbor)]

                    # compute product of all messages to var except from neighbor
                    var_message_product = self.var_beliefs[var] - self.messages[(neighbor, var)]
                    # transpose so var's term is added along the first (var) axis
                    belief = (belief.T + var_message_product).T

                    # compute product of all messages to neighbor except from var
                    neighbor_message_product = self.var_beliefs[neighbor] - self.messages[(var, neighbor)]
                    belief = belief + neighbor_message_product

                    log_z = logsumexp(belief)
                    belief = belief - log_z
                    self.pair_beliefs[(var, neighbor)] = belief
                    self.pair_beliefs[(neighbor, var)] = belief.T