什么是期望最大化技术的直观解释? [关闭]

问题描述 投票:98回答:8

期望最大化(EM)是一种对数据进行分类的概率方法。如果我错了,请纠正我,如果它不是分类器。

这种EM技术的直观解释是什么?什么是expectation在这里和什么是maximized

machine-learning cluster-analysis data-mining mathematical-optimization expectation-maximization
8个回答
105
投票

注意:这个答案背后的代码可以找到here


假设我们从两个不同的组(红色和蓝色)中采集了一些数据:

enter image description here

在这里,我们可以看到哪个数据点属于红色或蓝色组。这样可以轻松找到表征每个组的参数。例如,红色组的平均值约为3,蓝色组的平均值约为7(如果需要,我们可以找到确切的方法)。

一般而言,这被称为最大似然估计。给定一些数据,我们计算最能解释该数据的参数(或参数)的值。

现在想象一下,我们无法看到从哪个组中采样了哪个值。一切看起来都是紫色的:

enter image description here

在这里,我们知道有两组值,但我们不知道任何特定值属于哪一组。

我们还能估算出最适合这些数据的红色组和蓝色组的平均值吗?

是的,我们经常可以!期望最大化为我们提供了一种方法。算法背后的一般思路是这样的:

  1. 首先估计每个参数可能是什么。
  2. 计算每个参数产生数据点的可能性。
  3. 根据参数产生的可能性,计算每个数据点的权重,指示它是红色还是更蓝。将权重与数据相结合(期望)。
  4. 使用权重调整数据(最大化)计算对参数的更好估计。
  5. 重复步骤2到4,直到参数估计收敛(过程停止产生不同的估计)。

这些步骤需要进一步解释,因此我将逐步解决上述问题。

Example: estimating mean and standard deviation

我将在此示例中使用Python,但如果您不熟悉此语言,则代码应该相当容易理解。

假设我们有两个组,红色和蓝色,其值分布如上图所示。具体来说,每个组都包含从normal distribution中提取的值,其中包含以下参数:

import numpy as np
from scipy import stats

np.random.seed(110) # for reproducible results

# set parameters
red_mean = 3
red_std = 0.8

blue_mean = 7
blue_std = 2

# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)

both_colours = np.sort(np.concatenate((red, blue))) # for later use...

这是再次显示这些红色和蓝色组的图像(以免您不必向上滚动):

enter image description here

当我们可以看到每个点的颜色(即它属于哪个组)时,很容易估计每个组的平均值和标准偏差。我们只将红色和蓝色值传递给NumPy中的内置函数。例如:

>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195

但是如果我们看不到这些点的颜色呢?也就是说,不是红色或蓝色,每个点都是紫色的。

为了尝试恢复红色和蓝色组的均值和标准差参数,我们可以使用期望最大化。

我们的第一步(上面的步骤1)是猜测每个组的平均值和标准偏差的参数值。我们不必聪明地猜测;我们可以挑选任何我们喜欢的数字:

# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9

# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7

这些参数估计产生如下所示的钟形曲线:

enter image description here

这些都是糟糕的估计。例如,两种方式(垂直虚线)看起来远离任何类型的“中间”,用于合理的点组。我们希望改进这些估算。

下一步(步骤2)是计算每个数据点出现在当前参数猜测下的可能性:

likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

在这里,我们简单地将每个数据点放入probability density function进行正态分布,使用我们当前的红色和蓝色均值和标准差的猜测。这告诉我们,例如,根据我们目前的猜测,1.761处的数据点比蓝色(0.00003)更可能是红色(0.189)。

对于每个数据点,我们可以将这两个似然值转换为权重(步骤3),以便它们总和为1,如下所示:

likelihood_total = likelihood_of_red + likelihood_of_blue

red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total

根据我们当前的估计值和我们新计算的权重,我们现在可以计算红色和蓝色组的平均值和标准偏差的新估计值(步骤4)。

我们使用所有数据点计算均值和标准差,但使用不同的权重:一次为红色权重,一次为蓝色权重。

直觉的关键点是数据点上颜色的权重越大,数据点就越会影响该颜色参数的下一个估计值。这具有在正确方向上“拉”参数的效果。

def estimate_mean(data, weight):
    """
    For each data point, multiply the point by the probability it
    was drawn from the colour's distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among our data points.
    """
    return np.sum(data * weight) / np.sum(weight)

def estimate_std(data, weight, mean):
    """
    For each data point, multiply the point's squared difference
    from a mean value by the probability it was drawn from
    that distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among the values for the difference of
    each data point from the mean.

    This is the estimate of the variance, take the positive square
    root to find the standard deviation.
    """
    variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
    return np.sqrt(variance)

# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)

# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)

我们对参数有了新的估计。为了再次改进它们,我们可以跳回到步骤2并重复该过程。我们这样做直到估计收敛,或者在执行了一些迭代之后(步骤5)。

对于我们的数据,此过程的前五次迭代看起来像这样(最近的迭代具有更强的外观):

enter image description here

我们看到均值已经收敛于某些值,曲线的形状(由标准偏差控制)也变得更加稳定。

如果我们继续进行20次迭代,我们最终会得到以下结果:

enter image description here

EM过程已收敛到以下值,结果非常接近实际值(我们可以看到颜色 - 没有隐藏变量):

          | EM guess | Actual |  Delta
----------+----------+--------+-------
Red mean  |    2.910 |  2.802 |  0.108
Red std   |    0.854 |  0.871 | -0.017
Blue mean |    6.838 |  6.932 | -0.094
Blue std  |    2.227 |  2.195 |  0.032

在上面的代码中,您可能已经注意到,使用先前迭代对均值的估计来计算标准偏差的新估计。最终,如果我们首先计算均值的新值并不重要,因为我们只是在一些中心点周围找到值的(加权)方差。我们仍然会看到参数的估计收敛。


35
投票

EM是一种算法,用于在模型中的某些变量未被观察时(即,当您有潜在变量时)最大化似然函数。

您可能会公平地问,如果我们只是尝试最大化函数,为什么我们不使用现有的机器来最大化函数。好吧,如果你试图通过取导数并将它们设置为零来最大化这个,你会发现在许多情况下,一阶条件没有解决方案。有一个鸡和蛋的问题,要解决您的模型参数,您需要知道未观察到的数据的分布;但是,未观察到的数据的分布是模型参数的函数。

E-M试图通过迭代猜测未观察到的数据的分布来解决这个问题,然后通过最大化实际似然函数的下限来估计模型参数,并重复直到收敛:

EM算法

从猜测模型参数的值开始

E步骤:对于每个具有缺失值的数据点,使用模型方程式来解决缺失数据的分布,给出您当前对模型参数的猜测并给出观察到的数据(请注意,您正在为每个缺失的分布求解)价值,而非预期价值)。现在我们有了每个缺失值的分布,我们可以计算似然函数相对于未观察到的变量的期望。如果我们对模型参数的猜测是正确的,那么这个预期的可能性将是我们观察到的数据的实际可能性;如果参数不正确,它只是一个下限。

M步骤:既然我们已经得到了一个没有未观察到的变量的预期似然函数,那么就像在完全观察到的情况下一样最大化函数,以获得模型参数的新估计。

重复直到收敛。


26
投票

以下是了解期望最大化算法的简单方法:

1-阅读Do和Batzoglou的这个EM tutorial paper

2-你的头脑中可能有问号,看看这个数学堆栈交换page的解释。

3-看看我在Python中编写的代码,该代码解释了第1项的EM教程文章中的示例:

警告:代码可能是凌乱/次优的,因为我不是Python开发人员。但它完成了这项工作。

import numpy as np
import math

#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* #### 

def get_mn_log_likelihood(obs,probs):
    """ Return the (log)likelihood of obs, given the probs"""
    # Multinomial Distribution Log PMF
    # ln (pdf)      =             multinomial coeff            *   product of probabilities
    # ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]     

    multinomial_coeff_denom= 0
    prod_probs = 0
    for x in range(0,len(obs)): # loop through state counts in each observation
        multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
        prod_probs = prod_probs + obs[x]*math.log(probs[x])

    multinomial_coeff = math.log(math.factorial(sum(obs))) -  multinomial_coeff_denom
    likelihood = multinomial_coeff + prod_probs
    return likelihood

# 1st:  Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd:  Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd:  Coin A, {HTHHHHHTHH}, 8H,2T
# 4th:  Coin B, {HTHTTTHHTT}, 4H,6T
# 5th:  Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45

# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = zip(head_counts,tail_counts)

# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50

# E-M begins!
delta = 0.001  
j = 0 # iteration counter
improvement = float('inf')
while (improvement>delta):
    expectation_A = np.zeros((5,2), dtype=float) 
    expectation_B = np.zeros((5,2), dtype=float)
    for i in range(0,len(experiments)):
        e = experiments[i] # i'th experiment
        ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
        ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B

        weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A 
        weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B                            

        expectation_A[i] = np.dot(weightA, e) 
        expectation_B[i] = np.dot(weightB, e)

    pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A)); 
    pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B)); 

    improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
    j = j+1

16
投票

从技术上讲,术语“EM”有点不明确,但我假设您参考高斯混合建模聚类分析技术,这是一般EM原理的一个实例。

实际上,EM聚类分析不是分类器。我知道有些人认为聚类是“无监督分类”,但实际上聚类分析是完全不同的。

关键差异,以及人们对聚类分析总是存在的误解是:在集群分析中,没有“正确的解决方案”。它是一种知识发现方法,它实际上是为了找到新的东西!这使得评估非常棘手。它通常使用已知的分类作为参考进行评估,但这并不总是合适的:您所拥有的分类可能会或可能不会反映数据中的内容。

让我举个例子:您拥有大量客户数据,包括性别数据。将此数据集拆分为“男性”和“女性”的方法在将其与现有类进行比较时是最佳的。以“预测”的方式思考这是好的,对于新用户,您现在可以预测他们的性别。在“知识发现”的思维方式中,这实际上是不好的,因为你想在数据中发现一些新的结构。一种方法,例如将数据分成老年人和孩子,但是对于男/女班级而言,得分会更差。然而,这将是一个很好的聚类结果(如果没有给出年龄)。

现在回到EM。本质上,它假设您的数据由多个多元正态分布组成(请注意,这是一个非常强大的假设,特别是当您修复群集的数量时!)。然后,它通过交替改进模型和模型的对象分配来尝试为此找到局部最优模型。

为了在分类上下文中获得最佳结果,请选择大于类数的聚类数,或者甚至仅将聚类应用于单个类(以确定类中是否存在某些结构!)。

假设您想训练分类器来区分“汽车”,“自行车”和“卡车”。假设数据恰好由3个正态分布组成,几乎没有用处。但是,您可以假设有多种类型的汽车(以及卡车和自行车)。因此,不是为这三个类训练分类器,而是将汽车,卡车和自行车分成10个集群(或者10辆汽车,3辆卡车和3辆自行车,无论如何),然后训练分类器分开这30个班级,然后将类结果合并回原始类。您可能还会发现有一个群集特别难以分类,例如Trikes。他们有点车,有点自行车。或者送货卡车,更像超大型汽车而不是卡车。


2
投票

其他答案很好,我会尝试提供另一个视角,并解决问题的直观部分。

EM (Expectation-Maximization) algorithm是使用duality的一类迭代算法的变体

摘录(强调我的):

在数学中,一般来说,二元性将概念,定理或数学结构以一对一的方式转换为其他概念,定理或结构,通常(但不总是)通过对合操作:如果是对偶A是B,那么B的对偶是A.这样的回归有时具有固定点,因此A的对偶是A本身

通常,对象A的双B以某种方式与A相关,以保持一些对称性或兼容性。例如AB = const

使用二元性(在先前意义上)的迭代算法的示例是:

  1. Euclidean algorithm for Greatest Common Divisor, and its variants
  2. Gram–Schmidt Vector Basis algorithm and variants
  3. Arithmetic Mean - Geometric Mean Inequality, and its variants
  4. Expectation-Maximization algorithm and its variants(另见here for an information-geometric view
  5. (..其他类似算法..)

以类似的方式,the EM algorithm can also be seen as two dual maximization steps

.. [EM]被视为最大化参数和分布在未观察到的变量上的联合函数.E-步骤使该函数在未观察到的变量上的分布最大化;关于参数的M步骤..

在使用对偶性的迭代算法中,存在均衡(或固定)收敛点的显式(或隐式)假设(对于EM,使用Jensen的不等式证明了这一点)

所以这些算法的大纲是:

  1. 类似E的步骤:找到关于给定y保持不变的最佳解x。
  2. 类似M的步骤(双重):找到关于x(在前一步骤中计算)保持不变的最佳解y。
  3. 终止/收敛步骤:使用x,y的更新值重复步骤1,2,直到收敛(或达到指定的迭代次数)

注意,当这样的算法收敛到(全局)最优时,它已经找到了在两种意义上最好的配置(即在x域/参数和y域/参数中)。然而,该算法可以找到局部最优而不是全局最优。

我会说这是算法大纲的直观描述

对于统计论证和应用,其他答案给出了很好的解释(请参阅本答复中的参考)


2
投票

接受的答案引用了Chuong EM Paper,它在解释EM方面做得不错。还有一个youtube video更详细地解释了这篇论文。

回顾一下,这是一个场景:

1st:  {H,T,T,T,H,H,T,H,T,H} 5 Heads, 5 Tails; Did coin A or B generate me?
2nd:  {H,H,H,H,T,H,H,H,H,H} 9 Heads, 1 Tails
3rd:  {H,T,H,H,H,H,H,T,H,H} 8 Heads, 2 Tails
4th:  {H,T,H,T,T,T,H,H,T,T} 4 Heads, 6 Tails
5th:  {T,H,H,H,T,H,H,H,T,H} 7 Heads, 3 Tails

Two possible coins, A & B are used to generate these distributions.
A & B have an unknown parameter: their bias towards heads.

We don't know the biases, but we can simply start with a guess: A=60% heads, B=50% heads.

在第一个试验的问题的情况下,直觉上我们认为B生成它,因为头部的比例很好地匹配B的偏差......但是这个值只是猜测,所以我们不能确定。

考虑到这一点,我想像这样考虑EM解决方案:

  • 每次翻转试验都会对其最喜欢的硬币进行“投票” 这是基于每枚硬币与其分布的匹配程度 或者,从硬币的角度来看,人们期望看到这个试验相对于另一个硬币(基于对数可能性)。
  • 根据每个试验对每枚硬币的喜好程度,它可以更新该硬币参数的猜测(偏差)。 审判越喜欢硬币,就越能更新硬币的偏见以反映自己的硬币! 基本上,通过在所有试验中组合这些加权更新来更新硬币的偏差,这个过程称为(最大化),这是指在给定一组试验的情况下试图获得每个硬币偏差的最佳猜测。

这可能过于简单化(甚至在某些层面上甚至是根本上的错误),但我希望这在直观的层面上有所帮助!


1
投票

EM用于最大化具有潜在变量Z的模型Q的可能性。

这是一个迭代优化。

theta <- initial guess for hidden parameters
while not converged:
    #e-step
    Q(theta'|theta) = E[log L(theta|Z)]
    #m-step
    theta <- argmax_theta' Q(theta'|theta)

e-步骤:给定Z的当前估计计算预期的对数似然函数

m-step:找到最大化此Q的theta

GMM示例:

e-step:在给定当前gmm参数估计的情况下估计每个数据点的标签分配

m-step:在给定新​​标签分配的情况下最大化新的theta

K-means也是一种EM算法,在K-means上有很多解释动画。


1
投票

使用Do和Batzoglou在Zhubarb的回答中引用的相同文章,我在Java中实现了针对该问题的EM。对他的答案的评论表明,算法陷入局部最优,如果参数thetaA和thetaB相同,这也会在我的实现中发生。

下面是我的代码的标准输出,显示了参数的收敛性。

thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960

下面是我用Java解决问题的Java实现(Do and Batzoglou,2008)。实现的核心部分是运行EM的循环,直到参数收敛为止。

private Parameters _parameters;

public Parameters run()
{
    while (true)
    {
        expectation();

        Parameters estimatedParameters = maximization();

        if (_parameters.converged(estimatedParameters)) {
            break;
        }

        _parameters = estimatedParameters;
    }

    return _parameters;
}

以下是整个代码。

import java.util.*;

/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B,
respectively.
*****************************************************************************/
class Parameters
{
    double _thetaA = 0.0; // Probability of heads for coin A.
    double _thetaB = 0.0; // Probability of heads for coin B.

    double _delta = 0.00001;

    public Parameters(double thetaA, double thetaB)
    {
        _thetaA = thetaA;
        _thetaB = thetaB;
    }

    /*************************************************************************
    Returns true if this parameter is close enough to another parameter
    (typically the estimated parameter coming from the maximization step).
    *************************************************************************/
    public boolean converged(Parameters other)
    {
        if (Math.abs(_thetaA - other._thetaA) < _delta &&
            Math.abs(_thetaB - other._thetaB) < _delta)
        {
            return true;
        }

        return false;
    }

    public double getThetaA()
    {
        return _thetaA;
    }

    public double getThetaB()
    {
        return _thetaB;
    }

    public String toString()
    {
        return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
    }

}


/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
experimental observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
    double _numHeads = 0;
    double _numTails = 0;

    public Observation(String s)
    {
        for (int i = 0; i < s.length(); i++)
        {
            char c = s.charAt(i);

            if (c == 'H')
            {
                _numHeads++;
            }
            else if (c == 'T')
            {
                _numTails++;
            }
            else
            {
                throw new RuntimeException("Unknown character: " + c);
            }
        }
    }

    public Observation(double numHeads, double numTails)
    {
        _numHeads = numHeads;
        _numTails = numTails;
    }

    public double getNumHeads()
    {
        return _numHeads;
    }

    public double getNumTails()
    {
        return _numTails;
    }

    public String toString()
    {
        return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
    }

}

/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
    // Current estimated parameters.
    private Parameters _parameters;

    // Observations from the trials. These observations are set once.
    private final List<Observation> _observations;

    // Estimated observations per coin. These observations are the output
    // of the expectation step.
    private List<Observation> _expectedObservationsForCoinA;
    private List<Observation> _expectedObservationsForCoinB;

    private static java.io.PrintStream o = System.out;

    /*************************************************************************
    Principal constructor.
    @param observations The observations from the trial.
    @param parameters The initial guessed parameters.
    *************************************************************************/
    public EM(List<Observation> observations, Parameters parameters)
    {
        _observations = observations;
        _parameters = parameters;
    }

    /*************************************************************************
    Run EM until parameters converge.
    *************************************************************************/
    public Parameters run()
    {

        while (true)
        {
            expectation();

            Parameters estimatedParameters = maximization();

            o.printf("%s\n", estimatedParameters);

            if (_parameters.converged(estimatedParameters)) {
                break;
            }

            _parameters = estimatedParameters;
        }

        return _parameters;

    }

    /*************************************************************************
    Given the observations and current estimated parameters, compute new
    estimated completions (distribution over the classes) and observations.
    *************************************************************************/
    private void expectation()
    {

        _expectedObservationsForCoinA = new ArrayList<Observation>();
        _expectedObservationsForCoinB = new ArrayList<Observation>();

        for (Observation observation : _observations)
        {
            int numHeads = (int)observation.getNumHeads();
            int numTails = (int)observation.getNumTails();

            double probabilityOfObservationForCoinA=
                binomialProbability(10, numHeads, _parameters.getThetaA());

            double probabilityOfObservationForCoinB=
                binomialProbability(10, numHeads, _parameters.getThetaB());

            double normalizer = probabilityOfObservationForCoinA +
                                probabilityOfObservationForCoinB;

            // Compute the completions for coin A and B (i.e. the probability
            // distribution of the two classes, summed to 1.0).

            double completionCoinA = probabilityOfObservationForCoinA /
                                     normalizer;
            double completionCoinB = probabilityOfObservationForCoinB /
                                     normalizer;

            // Compute new expected observations for the two coins.

            Observation expectedObservationForCoinA =
                new Observation(numHeads * completionCoinA,
                                numTails * completionCoinA);

            Observation expectedObservationForCoinB =
                new Observation(numHeads * completionCoinB,
                                numTails * completionCoinB);

            _expectedObservationsForCoinA.add(expectedObservationForCoinA);
            _expectedObservationsForCoinB.add(expectedObservationForCoinB);
        }
    }

    /*************************************************************************
    Given new estimated observations, compute new estimated parameters.
    *************************************************************************/
    private Parameters maximization()
    {

        double sumCoinAHeads = 0.0;
        double sumCoinATails = 0.0;
        double sumCoinBHeads = 0.0;
        double sumCoinBTails = 0.0;

        for (Observation observation : _expectedObservationsForCoinA)
        {
            sumCoinAHeads += observation.getNumHeads();
            sumCoinATails += observation.getNumTails();
        }

        for (Observation observation : _expectedObservationsForCoinB)
        {
            sumCoinBHeads += observation.getNumHeads();
            sumCoinBTails += observation.getNumTails();
        }

        return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
                              sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));

        //o.printf("parameters: %s\n", _parameters);

    }

    /*************************************************************************
    Since the coin-toss experiment posed in this article is a Bernoulli trial,
    use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
    *************************************************************************/
    private static double binomialProbability(int n, int k, double p)
    {
        double q = 1.0 - p;
        return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
    }

    private static long nChooseK(int n, int k)
    {
        long numerator = 1;

        for (int i = 0; i < k; i++)
        {
            numerator = numerator * n;
            n--;
        }

        long denominator = factorial(k);

        return (long)(numerator / denominator);
    }

    private static long factorial(int n)
    {
        long result = 1;
        for (; n >0; n--)
        {
            result = result * n;
        }

        return result;
    }

    /*************************************************************************
    Entry point into the program.
    *************************************************************************/
    public static void main(String argv[])
    {
        // Create the observations and initial parameter guess
        // from the (Do and Batzoglou, 2008) article.

        List<Observation> observations = new ArrayList<Observation>();
        observations.add(new Observation("HTTTHHTHTH"));
        observations.add(new Observation("HHHHTHHHHH"));
        observations.add(new Observation("HTHHHHHTHH"));
        observations.add(new Observation("HTHTTTHHTT"));
        observations.add(new Observation("THHHTHHHTH"));

        Parameters initialParameters = new Parameters(0.6, 0.5);

        EM em = new EM(observations, initialParameters);

        Parameters finalParameters = em.run();

        o.printf("Final result:\n%s\n", finalParameters);
    }
}
© www.soinside.com 2019 - 2024. All rights reserved.