# Source code for mrftools.GibbsSampler
"""Gibbs sampling class"""
from __future__ import division

import random
from collections import Counter

import numpy as np
import pandas as pd

try:
    from scipy.special import logsumexp
except ImportError:  # older SciPy releases only expose it from scipy.misc
    from scipy.misc import logsumexp
class GibbsSampler(object):
    """Object that can run Gibbs sampling on a MarkovNet.

    Typical use is ``init_states`` followed by ``gibbs_sampling``; the
    collected samples are stored in ``self.samples`` as a list of
    ``{variable: state}`` dictionaries.
    """

    def __init__(self, markov_net):
        """Initialize the Gibbs sampler for markov_net.

        :param markov_net: model to sample from; this class reads its
            ``variables``, ``unary_potentials``, ``neighbors`` and
            ``get_potential`` members
        """
        self.mn = markov_net
        # Current state of every variable, keyed by variable.
        self.states = dict()
        # Normalized unary distribution of every variable, keyed by variable.
        self.unary_weights = dict()
        # One snapshot of ``self.states`` per collected sample.
        self.samples = list()

    @staticmethod
    def generate_state(weight):
        """Draw a state index with probability proportional to its weight.

        :param weight: sequence of non-negative weights, one per state; it
            does not need to be normalized
        :return: index of the sampled state
        :rtype: int
        :raises ValueError: if ``weight`` is empty or its sum is not positive
        """
        if len(weight) == 0:
            raise ValueError("weight must contain at least one state")
        total = sum(weight)
        # ``not total > 0`` also rejects NaN sums, which would otherwise
        # make every comparison below fail silently.
        if not total > 0:
            raise ValueError("weights must have a positive sum")

        remainder = random.uniform(0, 1) * total
        last_positive = None
        for state, state_weight in enumerate(weight):
            remainder -= state_weight
            if remainder < 0:
                return state
            if state_weight > 0:
                last_positive = state
        # Floating-point rounding can leave the remainder at or above zero
        # after the final subtraction; fall back to the last state that has
        # any probability mass instead of returning None.
        return last_positive

    def init_states(self, seed=None):
        """
        Initialize the state of each node from its unary potential alone.

        :param seed: random seed; when given, both the stdlib ``random``
            generator (used for all draws in this class) and NumPy's global
            generator are seeded
        """
        if seed is not None:
            # generate_state draws from the stdlib generator, so it must be
            # seeded for runs to be reproducible. NumPy is still seeded for
            # callers that relied on that side effect.
            random.seed(seed)
            np.random.seed(seed)

        for var in self.mn.variables:
            log_weight = self.mn.unary_potentials[var]
            # Exponentiate relative to the log-sum-exp to get a normalized
            # distribution without overflow.
            self.unary_weights[var] = np.exp(log_weight - logsumexp(log_weight))
            self.states[var] = self.generate_state(self.unary_weights[var])

    def update_states(self):
        """Update the state of each node based on neighbor states.

        Variables are visited in the order of ``self.mn.variables``; each one
        is resampled using the already-updated states of earlier variables.
        If no states exist yet, ``init_states`` is called first.
        """
        if not self.states:
            # Without initial states the neighbor lookups below would fail.
            self.init_states()

        for var in self.mn.variables:
            log_weight = self.mn.unary_potentials[var]
            for neighbor in self.mn.neighbors[var]:
                # Select the column of the pairwise potential that matches
                # the neighbor's current state.
                # NOTE(review): assumes get_potential((var, neighbor)) has
                # var's states along the rows; confirm against MarkovNet.
                pairwise = self.mn.get_potential((var, neighbor))
                log_weight = log_weight + pairwise[:, self.states[neighbor]]
            conditional = np.exp(log_weight - logsumexp(log_weight))
            self.states[var] = self.generate_state(conditional)

    def burn_in(self, iters):
        """
        Run the state update procedure without recording samples.

        :param iters: number of update sweeps to run for mixing
        """
        for _ in range(iters):
            self.update_states()

    def sampling(self, num):
        """
        Run sampling, recording the full state after every update sweep.

        :param num: number of samples to collect
        """
        for _ in range(num):
            self.update_states()
            # Copy so later sweeps do not mutate the stored sample.
            self.samples.append(self.states.copy())

    def gibbs_sampling(self, burn_in, num):
        """
        Run Gibbs sampling

        :param burn_in: number of burn-in samples to discard
        :type burn_in: int
        :param num: number of samples to collect once burn-in phase is done
        :type num: int
        """
        self.burn_in(burn_in)
        self.sampling(num)

    def count_occurrences(self, var):
        """
        Count the number of times in our samples the variable was in each state.

        The result is indexed by state: entry ``k`` is the number of samples
        in which ``var`` was in state ``k``. States that were never sampled
        get a count of zero, and an empty sample list yields all zeros.

        :param var: variable to count the states of
        :type var: object
        :return: count array of state occurrences, one entry per state
        :rtype: arraylike
        """
        # The unary potential has one entry per state of the variable.
        num_states = len(self.mn.unary_potentials[var])
        tally = Counter(sample[var] for sample in self.samples)
        # Counter returns 0 for states that never occurred.
        return np.array([tally[state] for state in range(num_states)], dtype=int)