Source code for mrftools.EM

"""EM learner class."""
from .Learner import Learner
from .opt import *

class EM(Learner):
    """
    Objects that perform expectation maximization for learning with latent variables.
    """

    def __init__(self, inference_type):
        super(EM, self).__init__(inference_type)

    def learn(self, weights, optimizer=ada_grad, callback=None, opt_args=None):
        """
        Fit model parameters by alternating inference of latent variables and learning the best parameters
        to fit all variables. This method implements the variational expectation-maximization concept.

        The loop has no iteration cap: it repeats E- and M-steps until two consecutive weight vectors
        agree within ``rtol=1e-4`` / ``atol=1e-5``. The wall-clock start is stored in ``self.start_time``.

        :param weights: Initial weight vector. Can be used to warm start from a previous solution.
        :param optimizer: gradient-based optimization function, as defined in opt.py
        :param callback: callback function run during each iteration of the optimizer. The function receives the
                         weights as input. Can be useful for diagnostics, live plotting, storing records, etc.
        :param opt_args: optimization arguments. Usually a dictionary of parameter values
        :return: learned weights
        """
        # Infinity can never be "close" to finite weights, so the first round always runs.
        old_weights = np.inf
        new_weights = weights
        self.start_time = time.time()

        while not np.allclose(old_weights, new_weights, rtol=1e-4, atol=1e-5):
            # Snapshot rather than alias: if the optimizer updates its input in place and
            # returns the same array, an aliased reference would make the convergence
            # test compare the vector with itself and stop after a single round.
            old_weights = np.array(new_weights)
            self.e_step(new_weights)
            new_weights = self.m_step(new_weights, optimizer, callback, opt_args)

        return new_weights

    def e_step(self, weights):
        """
        Run the expectation step: compute expectations under the given weights using the
        conditioned belief propagators and store them in ``self.label_expectations``.

        :param weights: current weight vector passed through to ``calculate_expectations``
        :return: None; the result is stored on the instance
        """
        # NOTE(review): the meaning of the trailing ``True`` flag is defined by
        # ``calculate_expectations`` (outside this file) -- confirm against Learner.
        self.label_expectations = self.calculate_expectations(
            weights, self.conditioned_belief_propagators, True)

    def m_step(self, weights, optimizer=ada_grad, callback=None, opt_args=None):
        """
        Run the maximization step: call the optimizer on ``self.objective`` and ``self.gradient``
        starting from the given weights.

        :param weights: starting weight vector for the optimizer
        :param optimizer: gradient-based optimization function, as defined in opt.py
        :param callback: callback function forwarded to the optimizer
        :param opt_args: optimization arguments, forwarded to the optimizer as ``args``
        :return: whatever the optimizer returns (used by ``learn`` as the new weight vector)
        """
        res = optimizer(self.objective, self.gradient, weights, args=opt_args, callback=callback)
        return res