混合ガウス分布のEMアルゴリズム、C++実装例(2次元のみ)
先日のTopCoderマラソンマッチ 「Octave Classifier(カテゴリー分けする機械学習系の問題)」で、使わずじまいに終わったC++のコードをアップします。
混合ガウス分布のEMアルゴリズムについて、全く知らなかったので、
- EM アルゴリズム実装(勉強用) - Mi manca qualche giovedi`? Nakatani Shuyoさん
- 混合モデルとEMアルゴリズム 山中高夫さん
- PRML本 第9章
を参考にしました。ありがとうございます。基本的に、Nakatani Shuyoさんのコードの移植ですので、比較すると分かりやすいかもしれません。ただ数学関数は自前になってます。2次元限定ですので、ご了承ください。
Rやpython等で実装している方が既にいるので、C++での需要はないかもしれませんが、
- 型が分かりやすい
- ライブラリがなくても、単独で動く。
というメリットはあるので、もしかしたら少しは参考になるかもしれません。
実際にideone.comで実行した結果はhttps://ideone.com/LXH8viです。入力データセットは、Rに付属しているfaithfulというものです(Nakatani Shuyoさんのページと同じです)。3クラスで実行しました。ideoneの出力データをそのままRでplotしたものが、最初の画像になります。どうみても2クラスに分けるためのデータセットですが、テストということで。
#include <iostream> #include <vector> #include <cassert> #include <cmath> #include <numeric> #include <stdio.h> #include <stdlib.h> using namespace std; #define SZ(a) ((int)((a).size())) // 注:データの次元D=2のときしか対応していません! typedef vector <double> Vec; typedef vector <Vec> Mtx; class EMGaussianMixtures { private: int N; // データ個数 int D; // データの次元 int K; // クラスの種類数 Mtx m_xx; // x:正規化したデータ m_xx[N][D] Mtx m_mu; // μ:平均 m_mu[K][D] Vec m_mix; // π:混合率 m_mix[K] vector < Mtx > m_sig; // Σ:共分散 m_sig[K][D][D] Mtx m_gamma_nk; // γ:負担率 m_gamma_nk[N][K] public: // 初期化 // データセット dataFrame[N][D] // クラスの種類数 numClasses -> K void initialize(const Mtx& dataFrame, int numClasses) { N = SZ(dataFrame); D = SZ(dataFrame[0]); K = numClasses; assert(N>=1); assert(D==2); assert(K>=1); // データの正規化 m_xx = dataFrame; normalize(m_xx); // 平均、共分散、混合率の初期値(正規乱数) m_mu = Mtx(K, Vec(D)); for(int row=0;row<K;++row) { for(int col=0;col<D;++col) { m_mu[row][col]= getNormRand(); // たぶん、必須ではないけど、元のコードに合わせました。 } } m_mix = Vec(K, 1.0/K); m_sig = vector < Mtx >(K, Mtx(D, Vec(D))); for (int k = 0; k < K; ++k) { for (int rc=0; rc<D; ++rc) { m_sig[k][rc][rc]=1.0; } } } // E stepとM stepを1回ずつ実行 void run() { m_gamma_nk = calcEstep(m_xx, m_mu, m_mix, m_sig); Mtx new_mu; Vec new_mix; vector < Mtx > new_sig; calcMstep(new_mu, new_mix, new_sig, m_xx, m_gamma_nk); m_mu = new_mu; m_mix = new_mix; m_sig = new_sig; } void printGammaNK() const { for(int n=0;n<N;n++) { for(int k=0;k<K;k++) { printf("%.10f",m_gamma_nk[n][k]); if(k<K-1) { printf(", "); } } printf("\n"); } } private: // 正規化。平均0、不偏標準偏差1にする。 void normalize( Mtx& xx ) const { for (int col=0;col<D;++col) { double mean = 0.0; // 平均 for (int row=0;row<N;++row) { mean += xx[row][col]; } mean /= N; double sd = 0.0; // 不偏標準偏差 for (int row=0;row<N;++row) { const double diff = xx[row][col]-mean; sd += diff * diff; } sd /= (N-1); // 標準偏差じゃなくて不偏標準偏差なので、NじゃなくてN-1で割るのに注意。 sd = sqrt(sd); // 平均を引いた後に、不偏標準偏差で割る。 for (int row=0;row<N;++row) { xx[row][col] = (xx[row][col]-mean)/sd; } } } // 行列と行列の積 inline Mtx mul(const Mtx& a, const Mtx& b) const { vector <vector <double> > c(SZ(a), vector<double>(SZ(b[0]))); for (int i = 0; i < SZ(a); i++) { for (int k = 0; k < SZ(b); k++) { for (int j = 0; j < SZ(b[0]); j++) { c[i][j] += a[i][k]*b[k][j]; } } } return c; } // スカラーと行列の積 inline Mtx mulScalar(const Mtx& a, const double scalar) const { Mtx ret(a); for (int i = 0; i < SZ(ret); i++) { for (int k = 0; k < SZ(ret[0]); k++) { ret[i][k] *= scalar; } } return ret; } // スカラーとベクトルの積 inline Vec mulScalar(const Vec& a, const double scalar) const { Vec ret(a); for (int i = 0; i < SZ(ret); i++) { ret[i] *= scalar; } return ret; } // 行列の転置 inline Mtx transpose( const Mtx& vs ) const { const int H = SZ(vs); const int W = SZ(vs[0]); Mtx ret(W, Vec(H) ); for (int y = 0; y < W; y++) { for (int x = 0; x < H; x++) { ret[y][x] = vs[x][y]; } } return ret; } // 2*2の行列式を求める inline double det(const Mtx& m) const { return m[0][0]*m[1][1]-m[0][1]*m[1][0]; } // 2*2の逆行列を求める inline Mtx solve(const Mtx& m ) const { vector < vector <double> > ret(m); swap(ret[0][0],ret[1][1]); ret[0][1] = -ret[0][1]; ret[1][0] = -ret[1][0]; ret = mulScalar(ret,1.0/abs(det(m))); return ret; } // 多次元正規分布密度関数を求める double getDMNorm(const Vec& x, const Vec& mu, const Mtx& sig) const { Vec x_mu(x); // x - mu for (int i = 0; i < SZ(x_mu); i++) { x_mu[i] -= mu[i]; } const Mtx inv = solve(sig); // D=2決め打ちなので、mul(mul(transpose(x_mu),solve(sig)),x_mu)の部分を展開し、Cとした。 const double C = x_mu[0]*(x_mu[0]*inv[0][0]+x_mu[1]*inv[1][0]) + x_mu[1]*(x_mu[0]*inv[0][1]+x_mu[1]*inv[1][1]); double ret = 1/(sqrt(pow(2.0*M_PI,D) * det(sig))) * exp( -0.5 * C ); return ret; } // Eステップ。返り値は gamma_nk[N][K] Mtx calcEstep( const Mtx& xx, const Mtx& mu, const Vec& mix, const vector < Mtx >& sig) const { Mtx ret; for(int row=0;row<N;++row) { Vec numer(K); for(int k=0;k<K;k++) { numer[k]=mix[k]*getDMNorm(xx[row], mu[k], sig[k]); } const double sum_numer = accumulate(numer.begin(),numer.end(),0.0); for (int k=0; k < K; k++) { numer[k]/=sum_numer; } ret.push_back(numer); } return ret; } // Mステップ。次のステップのmu,mix,sigを求める。 void calcMstep( Mtx& new_mu, Vec& new_mix, vector < Mtx >& new_sig, const Mtx& xx, const Mtx& gamma_nk) const { new_mu.clear(); new_mix.clear(); new_sig.clear(); Vec N_k(K); for(int n=0;n<N;n++) { for (int k = 0; k < K; k++) { N_k[k] += gamma_nk[n][k]; } } new_mix = N_k; for (int k = 0; k < K; k++) { new_mix[k] /= N; } new_mu = mul(transpose(gamma_nk),xx); for (int k = 0; k < K; k++) { new_mu[k] = mulScalar(new_mu[k], 1.0/N_k[k]); } new_sig = vector < Mtx >(K, Mtx(D, Vec(D))); for(int k=0;k<K;++k) { for(int n=0;n<N;++n) { Vec x_newmu(xx[n]); // x - new_mu for (int i = 0; i < SZ(x_newmu); i++) { x_newmu[i] -= new_mu[k][i]; } // D=2きめうち new_sig[k][0][0] += gamma_nk[n][k] * x_newmu[0] * x_newmu[0]; new_sig[k][0][1] += gamma_nk[n][k] * x_newmu[0] * x_newmu[1]; new_sig[k][1][0] += gamma_nk[n][k] * x_newmu[1] * x_newmu[0]; new_sig[k][1][1] += gamma_nk[n][k] * x_newmu[1] * x_newmu[1]; } new_sig[k] = mulScalar(new_sig[k], 1.0/ N_k[k]); } } // おおよその標準正規乱数(平均0、分散1) double getNormRand() const { double ret = 0.0; for( int i = 0; i < 12;i++ ){ ret += (double)rand()/RAND_MAX; } return ret-6.0; } }; int main() { EMGaussianMixtures *em = new EMGaussianMixtures(); // 入力 Mtx dataFrame; Vec pos(2); while (scanf("%lf %lf",&pos[0],&pos[1])!=EOF) { dataFrame.push_back(pos); } // 初期化 const int numLoops = 20; const int numClasses = 3; em->initialize(dataFrame, numClasses); // 実行 for (int i = 0; i < numLoops; i++) { em->run(); } // 出力 em->printGammaNK(); delete em; return 0; }