Source code for mrftools.ConvexBeliefPropagator

"""Convexified Belief Propagation Class"""
import numpy as np

from .MatrixBeliefPropagator import MatrixBeliefPropagator, logsumexp, sparse_dot


class ConvexBeliefPropagator(MatrixBeliefPropagator):
    """
    Class to perform convexified belief propagation based on counting numbers. The class allows for non-Bethe
    counting numbers for the different factors in the MRF. If the factors are all non-negative, then the adjusted
    Bethe free energy is convex, providing better guarantees about the convergence and bounds of the primal and
    dual objective values.
    """

    def __init__(self, markov_net, counting_numbers=None):
        """
        Initialize a convexified belief propagator.

        :param markov_net: MarkovNet object encoding the probability distribution
        :type markov_net: MarkovNet
        :param counting_numbers: a dictionary with an entry for each variable and edge such that the value is a float
                                 representing the counting number to use in computing the convexified Bethe formulas
                                 and corresponding message passing updates. Edges may be keyed in either orientation.
                                 If None (or empty), every variable and edge is given counting number 1.
        :type counting_numbers: dict
        :raises KeyError: if ``counting_numbers`` lacks an entry for some variable or edge of the model
        """
        super(ConvexBeliefPropagator, self).__init__(markov_net)

        if counting_numbers:
            self._set_counting_numbers(counting_numbers)
        else:
            # Defaults are only built when actually needed.
            self._set_counting_numbers(self._default_counting_numbers(markov_net))

    @staticmethod
    def _default_counting_numbers(markov_net):
        """
        Build the default counting numbers: 1 for every variable and every edge.

        Both orientations of each edge are included (``_set_counting_numbers`` accepts either), so variable names do
        not need to support ordering comparisons.

        :param markov_net: MarkovNet whose ``variables`` and ``neighbors`` are enumerated
        :type markov_net: MarkovNet
        :return: dictionary mapping each variable and each (var, neighbor) pair to 1
        :rtype: dict
        """
        counting_numbers = dict()
        for var in markov_net.variables:
            counting_numbers[var] = 1
            for neighbor in markov_net.neighbors[var]:
                counting_numbers[(var, neighbor)] = 1
        return counting_numbers

    def _set_counting_numbers(self, counting_numbers):
        """
        Store the provided counting numbers and set up the associated vectors for the ordered variable representation.

        Sets ``edge_counting_numbers`` (one entry per directed message), ``unary_counting_numbers`` (a column vector
        with one entry per variable) and ``unary_coefficients`` (each variable's counting number plus the counting
        numbers of all its incident edges).

        :param counting_numbers: a dictionary with an entry for each variable and edge with counting number values;
                                 edges may be keyed in either orientation
        :type counting_numbers: dict
        :raises KeyError: if a variable or edge was not assigned a counting number
        :return: None
        """
        num_edges = self.mn.num_edges

        # The same counting number is stored at index i and i + num_edges, i.e. in both halves of the
        # directed-message layout used by the message matrix.
        self.edge_counting_numbers = np.zeros(2 * num_edges)
        for edge, i in self.mn.message_index.items():
            reversed_edge = edge[::-1]
            if edge in counting_numbers:
                value = counting_numbers[edge]
            elif reversed_edge in counting_numbers:
                value = counting_numbers[reversed_edge]
            else:
                raise KeyError('Edge %s was not assigned a counting number.' % repr(edge))
            self.edge_counting_numbers[i] = value
            self.edge_counting_numbers[i + num_edges] = value

        self.unary_counting_numbers = np.zeros((len(self.mn.variables), 1))
        for var, i in self.mn.var_index.items():
            if var not in counting_numbers:
                raise KeyError('Variable %s was not assigned a counting number.' % repr(var))
            self.unary_counting_numbers[i] = counting_numbers[var]

        # Each variable's belief is divided by its own counting number plus those of its incident edges.
        self.unary_coefficients = self.unary_counting_numbers.copy()
        for edge, i in self.mn.message_index.items():
            self.unary_coefficients[self.mn.var_index[edge[0]]] += self.edge_counting_numbers[i]
            self.unary_coefficients[self.mn.var_index[edge[1]]] += self.edge_counting_numbers[i]

    def compute_bethe_entropy(self):
        """
        Compute the counting-number-weighted entropy of the current (log-space) beliefs.

        The value is ``-sum_e c_e * sum(b_e * log b_e) - sum_i c_i * sum(b_i * log b_i)`` over the stored pairwise
        beliefs ``pair_belief_tensor`` and unary beliefs ``belief_mat``. Beliefs are not recomputed here.

        :return: the weighted entropy, or 0 when the model is fully conditioned
        :rtype: float
        """
        if self.fully_conditioned:
            return 0

        num_edges = self.mn.num_edges
        # nan_to_num turns log(0) = -inf into a large finite number, so 0 * log(0) contributes 0 instead of nan.
        pair_terms = np.nan_to_num(self.pair_belief_tensor) * np.exp(self.pair_belief_tensor)
        unary_terms = np.nan_to_num(self.belief_mat) * np.exp(self.belief_mat)

        entropy = - np.sum(self.edge_counting_numbers[:num_edges] * pair_terms) \
            - np.sum(self.unary_counting_numbers.T * unary_terms)
        return entropy

    def update_messages(self):
        """
        Perform one synchronous round of convexified message updates on all directed messages.

        Recomputes the unary beliefs first, then replaces ``message_mat`` with the new (max-normalized) messages.

        :return: total absolute change between the old and new message matrices
        :rtype: float
        """
        self.compute_beliefs()

        num_edges = self.mn.num_edges
        # Swap the two halves so each column holds the message travelling in the opposite direction.
        reversed_messages = np.hstack((self.message_mat[:, num_edges:], self.message_mat[:, :num_edges]))

        adjusted_message_prod = self.mn.edge_pot_tensor - reversed_messages
        adjusted_message_prod /= self.edge_counting_numbers
        adjusted_message_prod += self.belief_mat[:, self.mn.message_from]

        # Reshape explicitly (instead of np.squeeze) so that a model whose variables all have a single state keeps
        # the 2-D (max_states, 2 * num_edges) message layout.
        messages = logsumexp(adjusted_message_prod, 1).reshape(
            (self.mn.max_states, 2 * num_edges)) * self.edge_counting_numbers
        messages = np.nan_to_num(messages - messages.max(0))

        change = np.sum(np.abs(messages - self.message_mat))
        self.message_mat = messages
        return change

    def compute_beliefs(self):
        """
        Compute the normalized unary log-beliefs and store them in ``belief_mat``.

        The unnormalized beliefs (unary potentials + augmentation + incoming messages) are divided by each variable's
        unary coefficient before normalization. Does nothing when the model is fully conditioned.

        :return: None
        """
        if self.fully_conditioned:
            return

        beliefs = self.mn.unary_mat + self.augmented_mat
        beliefs += sparse_dot(self.message_mat, self.mn.message_to_map)
        beliefs /= self.unary_coefficients.T
        # Normalize each variable's column in log space.
        self.belief_mat = beliefs - logsumexp(beliefs, 0)

    def compute_pairwise_beliefs(self):
        """
        Compute the normalized pairwise log-beliefs for every edge and store them in ``pair_belief_tensor``.

        Uses the current ``belief_mat``; does nothing when the model is fully conditioned.

        :return: None
        """
        if self.fully_conditioned:
            return

        num_edges = self.mn.num_edges
        max_states = self.mn.max_states

        reversed_messages = np.hstack((self.message_mat[:, num_edges:], self.message_mat[:, :num_edges]))
        # nan_to_num presumably guards against non-finite values from dividing by the counting numbers.
        adjusted_message_prod = self.belief_mat[:, self.mn.message_from] \
            - np.nan_to_num(reversed_messages / self.edge_counting_numbers)

        to_messages = adjusted_message_prod[:, :num_edges].reshape((max_states, 1, num_edges))
        from_messages = adjusted_message_prod[:, num_edges:].reshape((1, max_states, num_edges))

        beliefs = self.mn.edge_pot_tensor[:, :, num_edges:] / self.edge_counting_numbers[num_edges:] \
            + to_messages + from_messages
        beliefs -= logsumexp(beliefs, (0, 1))
        self.pair_belief_tensor = beliefs