混合ガウス分布のEMアルゴリズム、C++実装例(2次元のみ)

f:id:shindannin:20140301205110p:plain

先日のTopCoderマラソンマッチ 「Octave Classifier(カテゴリー分けする機械学習系の問題)」で、使わずじまいに終わったC++のコードをアップします。

混合ガウス分布のEMアルゴリズムについて、全く知らなかったので、

を参考にしました。ありがとうございます。基本的に、Nakatani Shuyoさんのコードの移植ですので、比較すると分かりやすいかもしれません。ただ数学関数は自前になってます。2次元限定ですので、ご了承ください。

Rやpython等で実装している方が既にいるので、C++での需要はないかもしれませんが、

  • 型が分かりやすい
  • ライブラリがなくても、単独で動く。

というメリットはあるので、もしかしたら少しは参考になるかもしれません。

実際にideone.comで実行した結果はhttps://ideone.com/LXH8viです。入力データセットは、Rに付属しているfaithfulというものです(Nakatani Shuyoさんのページと同じです)。3クラスで実行しました。ideoneの出力データをそのままRでplotしたものが、最初の画像になります。どうみても2クラスに分けるためのデータセットですが、テストということで。

#include <iostream>
#include <vector>
#include <cassert>
#include <cmath>
#include <numeric>
#include <stdio.h>
#include <stdlib.h>

using namespace std;
#define SZ(a) ((int)((a).size()))

// 注:データの次元D=2のときしか対応していません!
typedef vector <double> Vec;
typedef vector <Vec> Mtx;

class EMGaussianMixtures
{
private:
	int N; // データ個数
	int D; // データの次元
	int K; // クラスの種類数
	Mtx	m_xx;				// x:正規化したデータ m_xx[N][D]
	Mtx	m_mu;				// μ:平均             m_mu[K][D]
	Vec	m_mix;				// π:混合率           m_mix[K]
	vector < Mtx >	m_sig;	// Σ:共分散           m_sig[K][D][D]
	Mtx	m_gamma_nk;			// γ:負担率			m_gamma_nk[N][K]

public:

	// 初期化
	// データセット		dataFrame[N][D]
	// クラスの種類数	numClasses -> K
	void initialize(const Mtx& dataFrame, int numClasses)
	{
		N = SZ(dataFrame);
		D = SZ(dataFrame[0]);
		K = numClasses;
		assert(N>=1);
		assert(D==2);
		assert(K>=1);

		// データの正規化
		m_xx = dataFrame;
		normalize(m_xx);

		// 平均、共分散、混合率の初期値(正規乱数)
		m_mu	= Mtx(K, Vec(D));

		for(int row=0;row<K;++row)
		{
			for(int col=0;col<D;++col)
			{
				m_mu[row][col]= getNormRand(); // たぶん、必須ではないけど、元のコードに合わせました。
			}
		}

		m_mix = Vec(K, 1.0/K);

		m_sig = vector < Mtx >(K, Mtx(D, Vec(D)));
		for (int k = 0; k < K; ++k)
		{
			for (int rc=0; rc<D; ++rc)
			{
				m_sig[k][rc][rc]=1.0;
			}
		}
	}

	// E stepとM stepを1回ずつ実行
	void run()
	{
		m_gamma_nk = calcEstep(m_xx, m_mu, m_mix, m_sig);

		Mtx	new_mu;
		Vec new_mix;
		vector < Mtx > new_sig;

		calcMstep(new_mu, new_mix, new_sig, m_xx, m_gamma_nk);

		m_mu = new_mu;
		m_mix = new_mix;
		m_sig = new_sig;
	}

	void printGammaNK() const
	{
		for(int n=0;n<N;n++)
		{
			for(int k=0;k<K;k++)
			{
				printf("%.10f",m_gamma_nk[n][k]);
				if(k<K-1)
				{
					printf(", ");
				}
			}
			printf("\n");
		}
	}

private:
	// 正規化。平均0、不偏標準偏差1にする。
	void normalize( Mtx& xx ) const
	{
		for (int col=0;col<D;++col)
		{
			double mean = 0.0;	// 平均
			for (int row=0;row<N;++row)
			{
				mean += xx[row][col];
			}
			mean /= N;

			double sd	= 0.0;	// 不偏標準偏差
			for (int row=0;row<N;++row)
			{
				const double diff = xx[row][col]-mean;
				sd += diff * diff;
			}
			sd	/= (N-1);	// 標準偏差じゃなくて不偏標準偏差なので、NじゃなくてN-1で割るのに注意。
			sd	= sqrt(sd);

			// 平均を引いた後に、不偏標準偏差で割る。
			for (int row=0;row<N;++row)
			{
				xx[row][col] = (xx[row][col]-mean)/sd;
			}
		}
	}

	// 行列と行列の積
	inline Mtx mul(const Mtx& a, const Mtx& b) const
	{
		vector <vector <double> > c(SZ(a), vector<double>(SZ(b[0])));
		for (int i = 0; i < SZ(a); i++)
		{
			for (int k = 0; k < SZ(b); k++)
			{
				for (int j = 0; j < SZ(b[0]); j++)
				{
					c[i][j] += a[i][k]*b[k][j];
				}
			}
		}
		return c;
	}

	// スカラーと行列の積
	inline Mtx mulScalar(const Mtx& a, const double scalar) const
	{
		Mtx ret(a);
		for (int i = 0; i < SZ(ret); i++)
		{
			for (int k = 0; k < SZ(ret[0]); k++)
			{
				ret[i][k] *= scalar;
			}
		}
		return ret;
	}

	// スカラーとベクトルの積
	inline Vec mulScalar(const Vec& a, const double scalar) const
	{
		Vec ret(a);
		for (int i = 0; i < SZ(ret); i++)
		{
			ret[i] *= scalar;
		}
		return ret;
	}

	// 行列の転置
	inline Mtx transpose( const Mtx& vs ) const
	{
		const int H = SZ(vs);
		const int W = SZ(vs[0]);

		Mtx ret(W, Vec(H) );
		for (int y = 0; y < W; y++)
		{
			for (int x = 0; x < H; x++)
			{
				ret[y][x] = vs[x][y];
			}
		}

		return ret;
	}

	// 2*2の行列式を求める
	inline double det(const Mtx& m) const
	{
		return m[0][0]*m[1][1]-m[0][1]*m[1][0];
	}

	// 2*2の逆行列を求める
	inline Mtx solve(const Mtx& m ) const
	{
		vector < vector <double> > ret(m);
		swap(ret[0][0],ret[1][1]);
		ret[0][1] = -ret[0][1];
		ret[1][0] = -ret[1][0];
		ret = mulScalar(ret,1.0/abs(det(m)));
		return ret;
	}

	// 多次元正規分布密度関数を求める
	double getDMNorm(const Vec& x, const Vec& mu, const Mtx& sig) const
	{
		Vec x_mu(x);	// x - mu
		for (int i = 0; i < SZ(x_mu); i++)
		{
			x_mu[i] -= mu[i];
		}

		const Mtx inv = solve(sig);

		// D=2決め打ちなので、mul(mul(transpose(x_mu),solve(sig)),x_mu)の部分を展開し、Cとした。
		const double C = x_mu[0]*(x_mu[0]*inv[0][0]+x_mu[1]*inv[1][0]) + x_mu[1]*(x_mu[0]*inv[0][1]+x_mu[1]*inv[1][1]);

		double ret = 1/(sqrt(pow(2.0*M_PI,D) * det(sig))) * exp( -0.5 * C );

		return ret;
	}

	// Eステップ。返り値は gamma_nk[N][K]
	Mtx calcEstep( const Mtx& xx, const Mtx& mu, const Vec& mix, const vector < Mtx >& sig) const
	{
		Mtx ret;

		for(int row=0;row<N;++row)
		{
			Vec numer(K);
			for(int k=0;k<K;k++)
			{
				numer[k]=mix[k]*getDMNorm(xx[row], mu[k], sig[k]);
			}
			const double sum_numer = accumulate(numer.begin(),numer.end(),0.0);

			for (int k=0; k < K; k++)
			{
				numer[k]/=sum_numer;
			}

			ret.push_back(numer);
		}

		return ret;
	}

	// Mステップ。次のステップのmu,mix,sigを求める。
	void calcMstep(	Mtx& new_mu, Vec& new_mix, vector < Mtx >& new_sig, const Mtx& xx, const Mtx& gamma_nk) const	
	{
		new_mu.clear();
		new_mix.clear();
		new_sig.clear();

		Vec N_k(K);
		for(int n=0;n<N;n++)
		{
			for (int k = 0; k < K; k++)
			{
				N_k[k] += gamma_nk[n][k];
			}
		}

		new_mix = N_k;
		for (int k = 0; k < K; k++)
		{
			new_mix[k] /= N;
		}

		new_mu = mul(transpose(gamma_nk),xx);
		for (int k = 0; k < K; k++)
		{
			new_mu[k] = mulScalar(new_mu[k], 1.0/N_k[k]);
		}

		new_sig = vector < Mtx >(K, Mtx(D, Vec(D)));
		for(int k=0;k<K;++k)
		{
			for(int n=0;n<N;++n)
			{
				Vec x_newmu(xx[n]);	// x - new_mu
				for (int i = 0; i < SZ(x_newmu); i++)
				{
					x_newmu[i] -= new_mu[k][i];
				}

				// D=2きめうち
				new_sig[k][0][0] += gamma_nk[n][k] * x_newmu[0] * x_newmu[0];
				new_sig[k][0][1] += gamma_nk[n][k] * x_newmu[0] * x_newmu[1];
				new_sig[k][1][0] += gamma_nk[n][k] * x_newmu[1] * x_newmu[0];
				new_sig[k][1][1] += gamma_nk[n][k] * x_newmu[1] * x_newmu[1];
			}

			new_sig[k] = mulScalar(new_sig[k], 1.0/ N_k[k]);
		}
	}

	// おおよその標準正規乱数(平均0、分散1)
	double getNormRand() const
	{
		double ret = 0.0;
		for( int i = 0; i < 12;i++ ){
			ret += (double)rand()/RAND_MAX;
		}
		return ret-6.0;
	}
};

int main()
{
	EMGaussianMixtures	*em = new EMGaussianMixtures();

	// 入力
	Mtx dataFrame;
	Vec pos(2);
	while (scanf("%lf %lf",&pos[0],&pos[1])!=EOF)
	{
		dataFrame.push_back(pos);
	}

	// 初期化
	const int numLoops	 = 20;
	const int numClasses =  3;
	em->initialize(dataFrame, numClasses);

	// 実行
	for (int i = 0; i < numLoops; i++)
	{
		em->run();
	}

	// 出力
	em->printGammaNK();

	delete em;

	return 0;
}