Negative/Positive Thinking

2013-10-05

逐次確率比検定を試す

はじめに

あらかじめ標本サイズを決めるのではなく、十分と判断されるまでダイナミックに判断を繰り返す逐次確率比検定を参考に、
チョコボールの銀のエンジェルの出現確率について判断するとどうなるか試してみる。

逐次確率比検定とは

  • ベイズ統計学の枠組みで、ベイズ更新の機能を通して1つずつ標本抽出していきながら同時に検定にも用いる事ができる
  • 逐次決定過程 : 標本抽出をするたびに判断を行い、結論がでたと認められるタイミングで停止する過程
  • 行動
    • action0 : 結論を保留し、標本抽出を再度行う
    • action1 : 帰無仮説H1を採択
    • action2 : 対立仮説H2を採択

尤度比検定(Likelihood Ratio Test)
  • 「尤度比」を検定統計量として行う統計学的検定の総称
    • 尤度比λ=(Π^n_{i=1}{f(Xi|θ1}) / (Π^n_{i=1}{f(Xi|θ2})
    • 帰無仮説H1 : θ=θ1
    • 対率仮説H2 : θ=θ2
  • 尤度比検定 : 任意のC>0に対して、
    • λ<=Cならば、H1を採択
    • λ>Cならば、H2を採択

逐次確率比検定(Sequential Probability Ratio Test)
  • あるA,B>0に対して、n回目まで抽出した標本での尤度比λnに対し、以下のようにする
    • B<λn<A : action0
    • λn <= B : action1
    • A <= λn : action2
    • B = β / (1 - α)
    • A = (1 - β) / α
      • α:第1種の誤りの確率(有意水準)
      • β:第2種の誤りの確率(危険率,1-β=検出力)

二項分布の場合
  • n回の独立試行のうち、k回が事象Tで、(n-k)回が事象Sで、各試行における事象Tである確率pは一定であるような場合、事象Tは二項分布となる
    • P[X=k;p] = nCk * p^k * (1-p)^{n-k}
  • H1 : p = θ1 (出現確率pがθ1より小さい)
  • H2 : p = θ2 (出現確率pがθ2より大きい)
  • 尤度比λn = {(1-θ2)/(1-θ1)}^n * {(θ2/(1-θ2))/(θ1/(1-θ1))}^k
    • β/(1-α) < λn < (1-β)/αである限り抽出を繰り返す


チョコボールのエンジェル

チョコボールのエンジェルとは
チョコボールのエンジェル出現確率の変動
購入判断
  • 銀のエンジェルがでる確率が5%以上ならばチョコボールを買い続けるけど、1%以下ならば買い続けるのはバカらしいのでやめたい、とする
  • 判断方法として、逐次確率比検定の二項分布の場合に照らし合わせて考えてみる
    • 常にエンジェルが入っている確率は、抽出期間内において、時期に依らず一定であると仮定

チョコボールの標本抽出方法
  • チョコボールの標本抽出には以下があると考えられる
    • 母集団を販売されているチョコボールとして、定期的に1つずつ無作為に選んで購入する(無作為抽出)
    • ある店舗において在庫としておいてある箱を複数個無作為に選んで購入し、一つずつ抽出する
    • 店舗、箱、と段階に分けて無作為抽出を繰り返す
    • など

定義
  • 以下と定義して計算する
    • H1 : p = 0.1
    • H2 : p = 0.5
  • チョコボールの値段は安いので、標本サイズが増えたとしても、厳しめに判定したいとする
    • α = 0.05
    • β = 0.005

コード

シミュレーションしてみる。

#include <iostream>
#include <vector>
#include <cmath>

static const double PI = 3.141592653589793238;
static const double EPS = 1e-9;

//xorshift
// 注意: longではなくint(32bit)にすべき
unsigned long xor128(){
  static unsigned long x=123456789, y=362436069, z=521288629, w=88675123;
  unsigned long t;
  t=(x^(x<<11));
  x=y; y=z; z=w;
  return w=(w^(w>>19))^(t^(t>>8));
}
//[0,1)の一様乱数
// 注意: int_maxぐらいで割った方がよい
double frand(){
  return xor128()%10000000000/static_cast<double>(10000000000); 
}


bool getChocoBall(double p){
  double r = frand();
  if(r < p) return true;
  return false;
}


int main(){
  int n = 0, k = 0;
  double p = 0.005; //実際の出現確率
  
  double theta1 = 0.01; //H1
  double theta2 = 0.05; //H2
  double alpha = 0.05; //有意水準
  double beta = 0.005; //危険度

  while(true){
    n++;

    //購入
    bool res = getChocoBall(p);
    if(res) k++;

    //対数尤度比の計算
    double cal = 0.0;
    cal += n * log( (1.0-theta2)/(1.0-theta1) );
    cal += k * log( (theta2/(1.0-theta2)) / (theta1/(1.0-theta1)) );

    //判定
    double B = log( beta / (1.0-alpha) );
    double A = log( (1.0-beta) / alpha );
    //std::cout << B << "\t" << cal << "\t" << A << std::endl;
    if(cal+EPS < B){
      std::cout << "STOP" << std::endl; //買うのをやめたい
      break;
    }
    else if(A+EPS < cal){
      std::cout << "CONTINUE" << std::endl; //これからも買い続けよう
      break;
    }

    std::cout << n << "(" << k << ") : MEASURING..." << std::endl;
  }
  return 0;

}

結果

1000回ずつ、買い続けるか(CONTINUE)、買うのをやめるか(STOP)を判定したものを比較してみる。



実際の確率STOPになった個数CONTINUEになった個数抽出回数についての標本平均抽出回数についての標本分散
p=0.00110000179.05428.654
p=0.0059982220.8043607.37
p=0.0196634302.27619772.7
p=0.02537463449.92128463
p=0.0363937255.58652570.5
p=0.045995130.00912308.3
p=0.05199987.4345358.39
p=0.100100032.315436.67

参考

2013-04-20

時系列解析メモ

はじめに

時系列解析について、簡単にメモ。

時系列(time series)とは

  • 時間の経過で変動する何かの数値の列
    • 例:気象データ、株価、など
  • 時系列解析は、このデータを統計解析すること

時系列の分類
  • 連続時間・離散時間
    • 時間間隔が連続的か、離散的(1時間おき、とか)か
  • 1変量・多変量
    • 1つの情報だけか、同じ時間の2つ以上の情報が与えられるか
  • 定常・非定常
    • 時間的に変化しない確率モデルの実現値とみなせる(定常)か、そうでないか
      • 弱定常: 分布がl時間(シフト)後と同じになるようなもの
      • 強定常: 分布が時間(シフト)に対して不変になるようなもの
  • ガウス型・非ガウス型
    • 時系列の分布が正規分布に従うか、そうでないか
  • 線形・非線形
    • 線形なモデルの出力として表現できるか、そうでないか

時系列解析の目的

以下のようなことをやりたい。

  • description
    • 図示や時系列の特徴を簡潔に表現
  • modeling
    • 時系列の変動を表現するモデルの構成・解析、パラメータ推定
  • prediction
  • extraction
    • 必要な情報の抽出

非定常性のあるデータの扱い

前処理をするか、直接扱うか。

前処理
  • 変数変換
    • 対数変換(値の増加で変動の大きさも増えるような場合)
      • z_n = log(y_n)
    • Box-Cox変換
      • z_n = λ^{-1} ({y_n}^λ -1) (λ≠0のとき)
      • z_n = log(y_n) (λ=0のとき)
      • (パラメータλの推定は、AICを使うなどして推定)
    • ロジット変換(確率や割合のような(0,1)の範囲を(-∞,∞)の範囲にする)
      • z_n = log(y_n/(1-y_n))
  • 差分
    • 差分(トレンド、傾きを含むような場合、それらを取り除くことができる)
      • z_n = y_n - y_{n-1}
    • 1周期分の差分(月周期・年周期性を持つ場合など)
      • z_n = y_n - y_{n-p}
    • 対数値の差分(volatility分析などで用いられる)
      • z_n = log(y_n) - log(y_{n-1})
  • 前年比・前年同期比
    • 前年比(yがトレンド*ノイズのような場合、それぞれを分離できる。トレンド:長期的傾向変動)
      • z_n = y_n / y_{n-1}
    • 前年同期比
      • z_n = y_n / y_{n-p}
  • 移動平均
    • 移動平均(変動を滑らかにできる)
      • z_n = 1/(2k+1) * Σ_{j=-k}^{k} y_{n+j}
    • 重み付き移動平均
      • z_n = Σ_{j=-k}^{k} ω_j * y_{n+j}
      • Σ_{j=-k}^{k} ω_j=1, ω_j>=1
    • 移動メディアン
      • z_n = median{y_{n-k},...,y_n,...,y_{n+k}}

直接扱う

各非定常性を考慮した時系列モデルを使う。

定常時系列の自己相関関数

  • 自己共分散
    • 時系列の変動の特徴をとらえるために、y_nとy_{n-k}との共分散を考える
    • Cov(y_n, y_{n-k})
  • 自己共分散関数
    • C_k = Cov(y_n, y_{n-k})
    • k: ラグ
  • 自己相関関数
    • R_k = C_k / C_0
      • 定常分布を仮定すると、y_nとy_{n-k}の相関係数において、それぞれの分散がC_0なので

サンプリング自己相関関数

定常時系列{y_1,...,y_n}が与えられたとき、

  • 標本平均 = 1/n * Σy
  • 標本自己共分散関数 = 1/n * Σ(y_n-標本平均)*(y_{n-k}-標本平均)
  • 標本自己相関関数 = 標本自己共分散関数_k / 標本自己共分散関数_0

自己相関をグラフ(コレログラム)にプロットすることで、変動の特徴(周期性やトレンドなど)も見ることができる
(定常分布ならkの増加で急速に0に収束する)

時系列モデル

以下のようなモデルが研究されている。

ARモデル
  • m次の自己回帰モデル
    • 過去の値とノイズだけで表現
    • y_n = Σ_{i=1}^{m} a_i * y_{n-i} + v_n
      • v_n: 白色ノイズ

  • パラメータ推定法
    • Yule-Walker法
      • Levinson's Algorithm
    • 最小二乗法
    • PARCOR(偏自己相関係数,partial autocorrelation coefficient)法

多変量ARモデル
  • ARモデルを多変量時系列に拡張したもの

局所定常ARモデル
  • 非定常な時系列について、時間を小区間に分割しそれぞれの区間内では定常と仮定する

時変係数ARモデル
  • 変動の仕方が時間とともに変化する非定常時系列に対して、係数が時間とともに変化するARモデル
    • y_n = Σ_{j=i}^{m} a_{ni} * y_{n-j} + v_n

MAモデル
ARMAモデル
ARIMAモデル
  • 自己回帰和分移動平均モデル
  • 平均値などが時間によって変動する場合などはよくあるので、時系列の差分を考えARMAモデルを適用したもの
  • 非定常な時系列の一つを扱える

SARIMAモデル
  • 季節性(z_n==z_{n-p}を近似的に満たす)を考えたARIMAモデル

ARCHモデル
  • 分散自己回帰モデル
  • 分散不均一性(変動率が時期によって異なった水準を示す)を示すデータへ対応

GARCHモデル
  • 一般化ARCHモデル

その他の話題

参考

2012-03-17

ノンパラベイズな言語モデルを試す

はじめに

最近「言語モデル」がマイブームなので、最近有名になりつつあるというMCMC法を使ったベイズな言語モデルとして、「階層的Pitman-Yor言語モデル(HPYLM)」を試しにちょっと作ってみた。

とりあえず、文字bigramのHPYLMを試してみる。
毎度のことながら勉強用、実験用なのででかいデータはちょっとまずいと思う。。。

(追記3/17 22:30)
コードで最終的な値に使う所がおかしいのを修正とdとθの収束についてを追加

コード

#include <iostream>
#include <sstream>
#include <fstream>
#include <vector>
#include <deque>
#include <map>
#include <algorithm>
#include <string>
#include <cmath>
#include <climits>

//xorshift
// 注意: longではなくint(32bit)にすべき
unsigned long xor128(){
  static unsigned long x=123456789, y=362436069, z=521288629, w=88675123;
  unsigned long t;
  t=(x^(x<<11));
  x=y; y=z; z=w;
  return w=(w^(w>>19))^(t^(t>>8));
}
//[0,1)の一様乱数
// 注意: int_maxぐらいで割るべき
double frand(){
  return xor128()%ULONG_MAX/static_cast<double>(ULONG_MAX); 
}
//Bernoulli試行(確率pで1、1-pで0を返す)
double bernoulli_rand(double p){
  double r = frand();
  if(r<p) return 1.0;
  return 0.0;
}
//gamma分布に従う乱数
double gamma_rand(double shape, double scale){
  double n, b1, b2, c1, c2;
  if(4.0 < shape) n = 1.0/sqrt(shape);
  else if(0.4 < shape) n = 1.0/shape + (1.0/shape)*(shape-0.4)/3.6;
  else if(0.0 < shape) n = 1.0/shape;
  else return -1;

  b1 = shape - 1.0/n;
  b2 = shape + 1.0/n;

  if(0.4 < shape) c1 = b1 * (log(b1)-1.0) / 2.0;
  else c1 = 0;
  c2 = b2 * (log(b2)-1.0) / 2.0;

  while(true){
    double v1 = frand(), v2 = frand();
    double w1, w2, y, x;
    w1 = c1 + log(v1);
    w2 = c2 + log(v2);
    y = n * (b1*w2-b2*w1);
    if(y < 0) continue;
    x = n * (w2-w1);
    if(log(y) < x) continue;
    return exp(x) * scale;
  }
  return -1;
}
//beta分布に従う乱数
double beta_rand(double a, double b){
  double gamma1 = gamma_rand(a,1.0);
  double gamma2 = gamma_rand(b,1.0);
  return gamma1/(gamma1+gamma2);
}
//離散確率でインデックスを返す
int selectProb(const std::vector< std::pair<int,double> > &p){
  double sum = 0.0;
  for(int i=0; i<(int)p.size(); i++) sum += p[i].second;
  double r = frand()*sum, q = 0;
  for(int i=0; i<(int)p.size()-1; i++){
    q += p[i].second;
    if(r<q) return p[i].first;
  }
  return p[p.size()-1].first;
}


//各ngramのレストラン
class Restaurant {
  bool base_flag; //特殊なレストラン(u="")
  double theta, d; //parameters
public:
  std::deque< std::pair<std::string,int> > tables; //各テーブル(料理と人数)

  Restaurant(bool base_flag=false):base_flag(base_flag){
    d = 0.6; //適当
    theta = 1; //適当
  }

  int c_uwd(const std::string &w){
    int cnt = 0;
    for(int i=0; i<(int)tables.size(); i++){
      if(tables[i].first == w){
	cnt += tables[i].second;
      }
    }
    return cnt;
  }
  double d_u(){ return d; }
  int t_uw(const std::string &w){ 
    int cnt = 0;
    for(int i=0; i<(int)tables.size(); i++){
      if(tables[i].first == w){
	cnt++;
      }
    }
    return cnt;
  }
  double theta_u(){ return theta; }
  int c_udd(){
    int cnt = 0;
    for(int i=0; i<(int)tables.size(); i++) cnt += tables[i].second;
    return cnt;
  }
  int t_ud(){ return tables.size(); }
  
  //AddCustomer
  bool AddCustomer(const std::string &w, const double &pw){
    std::vector< std::pair<int,double> > p; //wを持つ各テーブルのインデクスと人数
    //既存のテーブル
    for(int i=0; i<(int)tables.size(); i++){
      if(tables[i].first == w){
	p.push_back(std::make_pair(i,std::max(0.0,tables[i].second - d_u())));
      }
    }
    //未知のテーブル
    p.push_back(std::make_pair(tables.size(),(theta_u() + d_u() * t_ud()) * pw));
    int k = selectProb(p);
    if(k < (int)tables.size()){ //既知のテーブルに座る
      tables[k].second++;
    }else{ //未知のテーブルに座る
      tables.push_back(std::make_pair(w,1));
      if(base_flag) return false;
      return true;
    }
    return false;
  }
  //RemoveCustomer
  bool RemoveCustomer(const std::string &w){
    std::vector< std::pair<int,double> > p; //wを持つ各テーブルのインデクスと人数
    for(int i=0; i<(int)tables.size(); i++){
      if(tables[i].first == w){ 
	p.push_back(std::make_pair(i,tables[i].second));
      }
    }
    int k = selectProb(p);
    tables[k].second--;
    if(tables[k].second == 0){
      tables.erase(tables.begin() + k);
      if(base_flag) return false;
      return true;
    }
    return false;
  }
  //パラメータの更新
  void UpdateParameters(double d_, double theta_){
    d = d_;
    theta = theta_;
  }
};

//階層Pitman-Yor言語モデル
class HPYLM {
  int n_val; //n-gram modelのnの値
  int num_word_type; //文字種の数
  std::map<std::deque<std::string>,Restaurant> ngrams; //各レストラン
  std::vector<double> a_m, b_m, alpha_m, beta_m, d_m, theta_m; //各ngramで共通のハイパーパラメータ

  //ハイパーパラメータの履歴を保存しておく用
  std::vector< std::vector<double> > history_d_m, history_theta_m;

  //基底測度G_0(w)
  double WordProbBase(const std::string &w){
    return 1.0/num_word_type;
  }
public:
  //ngramのnの値, G0用の文字種の数
  HPYLM(int n, int nwt){
    n_val = n;
    num_word_type = nwt;
    for(int i=0; i<n_val; i++){
      a_m.push_back(1.0);
      b_m.push_back(1.0);
      alpha_m.push_back(1.0);
      beta_m.push_back(1.0);
      d_m.push_back(0.6); //適当
      theta_m.push_back(0.1); //適当

      history_d_m.push_back(std::vector<double>());
      history_theta_m.push_back(std::vector<double>());
    }
    ngrams.insert(std::make_pair(std::deque<std::string>(), Restaurant(true)));
  }
  void show_ngrams(){
    std::map<std::deque<std::string>,Restaurant>::iterator itr;
    for(itr = ngrams.begin(); itr != ngrams.end(); ++itr){
      std::cout << "Restaurant[";
      for(int i=0; i<(int)(itr->first.size()); i++){
	std::cout << itr->first[i];
	if(i!=(int)(itr->first.size()-1)) std::cout << "|";
      }
      std::cout << "]" << "(d=" << itr->second.d_u() << ",theta=" << itr->second.theta_u() << ")" << std::endl;
      for(int i=0; i<(int)(itr->second.tables.size()); i++){
	std::cout << itr->second.tables[i].first << "\t" << itr->second.tables[i].second << std::endl;
      }
    }
  }
  //単語の出現確率
  double WordProbability(std::deque<std::string> u, const std::string &w){
    std::map<std::deque<std::string>,Restaurant>::iterator itr;
    std::pair<std::map<std::deque<std::string>,Restaurant>::iterator,bool> res;
    res = ngrams.insert(std::make_pair(u,Restaurant()));
    itr = res.first;
    double c_uwd = itr->second.c_uwd(w);
    double d_u = itr->second.d_u();
    double t_uw = itr->second.t_uw(w);
    double theta_u = itr->second.theta_u();
    double c_udd = itr->second.c_udd();
    double t_ud = itr->second.t_ud();
    if(u.size() == 0) return ((c_uwd - d_u * t_uw) + (theta_u + d_u * t_ud) * WordProbBase(w))/(theta_u + c_udd);
    u.pop_front();
    return ((c_uwd - d_u * t_uw) + (theta_u + d_u * t_ud) * WordProbability(u,w))/(theta_u + c_udd);
  }
  //単語の追加
  void AddCustomer(std::deque<std::string> u, const std::string &w){
    std::deque<std::string> pu(u);
    pu.pop_front();
    while(ngrams[u].AddCustomer(w,WordProbability(pu,w))){
      u.pop_front();
      if(pu.size()>0) pu.pop_front();
    }
  }
  //単語の削除
  void RemoveCustomer(std::deque<std::string> u, const std::string &w){
    while(ngrams[u].RemoveCustomer(w)){
      u.pop_front();
    }
  }
  //ハイパーパラメータの更新
  void UpdateHyperParams(){
    std::map<std::deque<std::string>,Restaurant>::iterator itr;
    for(itr = ngrams.begin(); itr != ngrams.end(); ++itr){
      int m = itr->first.size();
      int t_ud = itr->second.t_ud();
      if(t_ud >= 2){
	double tmp_am = 0, tmp_alpham = 0;
	for(int i=1; i<=t_ud-1; i++){
	  double yui = bernoulli_rand(theta_m[m]/(theta_m[m]+i*d_m[m]));
	  tmp_am += 1.0 - yui;
	  tmp_alpham += yui;
	}
	a_m[m] += tmp_am;
	alpha_m[m] += tmp_alpham;
	double xu = beta_rand(theta_m[m]+1.0, itr->second.c_udd()-1.0);
	beta_m[m] -= log(xu);
      }
      for(int i=0; i<(int)itr->second.tables.size(); i++){
	if(itr->second.tables[i].second >= 2){
	  double tmp_bm = 0;
	  for(int j=1; j<=(int)itr->second.tables[i].second-1; j++){
	    double zuwkj = bernoulli_rand((j-1.0)/(j-d_m[m]));
	    tmp_bm += 1.0 - zuwkj;
	  }
	  b_m[m] += tmp_bm;
	}
      }
    }
    //dとthetaの更新
    for(int m=0; m<n_val; m++){
      d_m[m] = beta_rand(a_m[m], b_m[m]);
      theta_m[m] = gamma_rand(alpha_m[m], 1.0/beta_m[m]);
    }
    //各レストランに反映
    for(itr = ngrams.begin(); itr != ngrams.end(); ++itr){
      int m = itr->first.size();
      itr->second.UpdateParameters(d_m[m],theta_m[m]);
    }
  }

  //モデルの推定(Blocked Gibbs Sampler)
  void inference(const std::string &filename, int t_iteration, int t_burnin){
    for(int t=0; t<t_iteration; t++){
      if(t%1000==0) std::cout << "." << std::flush;
      std::ifstream ifs(filename.c_str()); //毎回ファイルから読み込むorz
      std::string text;
      //ngram読み込み
      while(std::getline(ifs, text)){
        std::stringstream ss(text);
        std::deque<std::string> u;
        std::string tmp, w;
        for(int i=0; i<n_val-1; i++){
          ss >> tmp;
          u.push_back(tmp);
        }
        ss >> w;

        if(t > 0) RemoveCustomer(u, w);
        AddCustomer(u, w);
      }
      //ハイパーパラメータの更新
      UpdateHyperParams();
      
      if(t >= t_burnin){//パラメータの値を保存しておく
        for(int i=0; i<n_val; i++){
          history_d_m[i].push_back(d_m[i]);
          history_theta_m[i].push_back(theta_m[i]);
        }
      }
    }
    std::cout << std::endl;

    /*//確認用dとthetaの出力
    for(int i=0; i<t_iteration-t_burnin; i++){
      std::cerr << i << "\t";
      for(int m=0; m<n_val; m++){
        std::cerr << history_d_m[m][i] << "\t" << history_theta_m[m][i] << "\t";        
      }
      std::cerr << std::endl;
    }
    */
    
    //burnin期間を除いたパラメータの平均値で最終的なd_mとtheta_mを決めてみる
    std::vector<double> ave_d_m(n_val), ave_theta_m(n_val);
    for(int m=0; m<n_val; m++){
      int num = history_d_m[m].size();
      for(int i=0; i<num; i++){
        ave_d_m[m] += history_d_m[m][i];
        ave_theta_m[m] += history_theta_m[m][i];
      }
      ave_d_m[m] /= num;
      ave_theta_m[m] /= num;
      //平均値を使う
      d_m[m] = ave_d_m[m];
      theta_m[m] = ave_theta_m[m];
    }
    //全てのレストランに反映
    std::map<std::deque<std::string>,Restaurant>::iterator itr;
    for(itr = ngrams.begin(); itr != ngrams.end(); ++itr){
      int m = itr->first.size();
      itr->second.UpdateParameters(d_m[m],theta_m[m]);
    }    
  }
};


int main(){
  int n = 2; //bigram言語モデル
  int n_char = 10; //世界には10文字ぐらいあると仮定
  HPYLM hpylm(n, n_char);
  hpylm.inference("ngram.data", 100000, 20000); //100000回反復(うち最初の20000回はburninとして捨てる)

  //モデルの状態を表示
  std::cout << "======================" << std::endl;
  hpylm.show_ngrams();
  std::cout << "======================" << std::endl;

  //各ngramの確率を求める
  std::string text;
  while(std::getline(std::cin, text)){
    std::stringstream ss(text);
    std::deque<std::string> u;
    std::string tmp, w;
    for(int i=0; i<n-1; i++){
      ss >> tmp;
      u.push_back(tmp);
    }
    ss >> w;
    
    std::cout << "Prob : " << hpylm.WordProbability(u, w) << std::endl;
  }

  return 0;
}

結果

学習データ「ngram.data」
  • 半角スペース区切り
今 日
今 日
今 日
今 日
今 日
今 日
今 日
今 日
今 日
今 日
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
明 日
明 日
明 日
明 日
明 日
実行結果
$ ./a.out
................................................................................
....................
======================
Restaurant[](d=0.553945,theta=0.730763)
月      2
日      1
日      2
日      1
Restaurant[今](d=0.0573542,theta=0.595518)
日      10
月      8
月      12
Restaurant[明](d=0.0573542,theta=0.595518)
日      3
日      1
日      1
======================
今 日
Prob : 0.334784
今 月
Prob : 0.65643
今 年
Prob : 0.00109828
明 日
Prob : 0.916481
明 月
Prob : 0.0354769
明 年
Prob : 0.00600527
あ い
Prob : 0.0437773

P(日|今)は、最尤推定だと0.333333ぐらい
P(月|今)は、最尤推定だと0.666666ぐらい
P(年|今)は、最尤推定だと0.0
P(日|明)は、最尤推定だと1.0
P(月|明)は、最尤推定だと0.0
P(年|明)は、最尤推定だと0.0
P(い|あ)は、最尤推定だと0.0

おぉ、、そこそこそれっぽい結果がでてるようにみえる。。。
学習データ中には「明月」はないけど、「今月」が出てる関係で、他の学習データにないものに比べちょっと確率が高くできてる。
「あい」は未知の語同士だからわからないはずだけど、ちゃんと確率値が与えられてる。

dとθの収束について

モデルが収束してるのかdとθの変化についてプロットしてみる。10万回分。

f:id:jetbead:20120317222834p:image

赤線:d_0
緑線:theta_0
青線:d_1
ピンク線:theta_1

まだバグがあるのかもしれないけど、えらい収束がゆっくりしてるように見える。
(200回程度じゃ全然収束してない、、、ので反復回数を修正)
あと、最終的な値に使うのが生成した値じゃなかったので、コードの方をdとtheta値を保存して平均をとるように変更。
うーん、、難しい。

参考資料

2012-03-03

ノンパラの実験

はじめに

ノンパラメトリックに関する論文やチュートリアルを見ていると「混合正規分布の混合数も推定できちゃうんですよぱねぇ」みたいなこと書いてあったり、その図が載ってたりする。
試してみたいなぁと思ったので、書いて実験してみる。

理解が乏しいのでやり方が間違っている可能性が高く、コーディングがスマートじゃない、あくまで実験用。。。orz

説明

CRP-type samplingでやってみる。
あるいくつかの正規分布に従うデータが与えられるとき、その混合比πとパラメータθを推定する。


P(x|θ) : 正規分布
G0 : パラメータの事前分布(平均の事前分布、分散の事前分布)
θ_k : テーブルkのパラメータ(平均と分散)
α_0 : 集中度パラメータ
π : 混合数と混合比

→G0(のパラメータ)とα_0とデータを与える。


データは、1次元のを自分で用意して使う。以下は、各正規分布が離れているちょっとずるいデータ。。
N(-50,10)に従うデータを200個、N(0,10)に従うデータを300個、N(50,10)に従うデータを500個。


推定は、Neal2000のAlgorithm8を利用してみる。新しいテーブル数Mは1つにする。(というか、他のAlgorithmの積分の計算がよくわからない、、、)
パラメータの更新は、先日(http://d.hatena.ne.jp/jetbead/20120229/1330535260)の更新式を使う。

出力は、各テーブル(正規分布)に所属するのデータをそれぞれ出力のみ。(離れているデータを使うので、平均・最頻を出さないで、最終的な結果だけで、、、)

コード

#include <iostream>
#include <vector>
#include <set>
#include <algorithm>
#include <cmath>
static const double PI = 3.14159265358979323846264338;

//xorshift
// 注意: longではなくint(32bit)にすべき
unsigned long xor128(){
  static unsigned long x=123456789, y=362436069, z=521288629, w=88675123;
  unsigned long t;
  t=(x^(x<<11));
  x=y; y=z; z=w;
  return w=(w^(w>>19))^(t^(t>>8));
}
//[0,1)の一様乱数
// 注意: int_maxぐらいで割るべき
double frand(){
  return xor128()%ULONG_MAX/static_cast<double>(ULONG_MAX); 
}
//ガンマ乱数
double gamma_rand(double shape, double scale){
  double n, b1, b2, c1, c2;
  if(4.0 < shape) n = 1.0/sqrt(shape);
  else if(0.4 < shape) n = 1.0/shape + (1.0/shape)*(shape-0.4)/3.6;
  else if(0.0 < shape) n = 1.0/shape;
  else return -1;

  b1 = shape - 1.0/n;
  b2 = shape + 1.0/n;

  if(0.4 < shape) c1 = b1 * (log(b1)-1.0) / 2.0;
  else c1 = 0;
  c2 = b2 * (log(b2)-1.0) / 2.0;

  while(true){
    double v1 = frand(), v2 = frand();
    double w1, w2, y, x;
    w1 = c1 + log(v1);
    w2 = c2 + log(v2);
    y = n * (b1*w2-b2*w1);
    if(y < 0) continue;
    x = n * (w2-w1);
    if(log(y) < x) continue;
    return exp(x) * scale;
  }
  return -1;
}
//逆ガンマ乱数
double invgamma_rand(double shape, double scale){
  return 1.0/gamma_rand(shape, scale);
}
//正規乱数
double normal_rand(double mu, double sigma2){
  double sigma = sqrt(sigma2);
  double u1 = frand(), u2 = frand();
  double z1 = sqrt(-2*log(u1)) * cos(2*PI*u2);
  //double z2 = sqrt(-2*log(u1)) * sin(2*PI*u2);
  return mu + sigma*z1;
}
//離散分布に従う乱数
int prob_select(const std::vector<double> &p){
  double Z = 0.0, sum = 0.0;
  double r = frand();
  for(int i=0; i<p.size(); i++) Z += p[i];
  for(int i=0; i<p.size()-1; i++){
    sum += p[i]/Z;
    if(r<sum) return i;
  }
  return p.size()-1;
}



//正規分布
struct Gaussian {
  double mu, sigma2;
  double G0_mean, G0_sigma2, G0_shape, G0_scale;
  std::vector<double> _data;
  std::set<int> data_idx;
  bool isExist(int idx){ return data_idx.count(idx); }
  void eraseIdx(int idx){ data_idx.erase(idx); }
  void paramUpdate(){
    double n_0 = G0_shape * 2.0;
    double S_0 = G0_scale * 2.0 / n_0;
    int n = data_idx.size();
    double x_bar = 0.0;
    for(std::set<int>::iterator itr=data_idx.begin(); itr != data_idx.end(); itr++){
      x_bar += _data[*itr];
    }
    x_bar /= n;
    double m_1 = 1.0 + n;
    double n_1 = n_0 + n;
    double mu_1 = (1.0 * G0_mean + n * x_bar) / (1.0 + n);
    double nS_1 = n_0 * S_0;
    for(std::set<int>::iterator itr=data_idx.begin(); itr != data_idx.end(); itr++){
      nS_1 += (_data[*itr]-x_bar)*(_data[*itr]-x_bar);
    }
    nS_1 += (1.0 * n) / (1.0 + n) * (x_bar - mu) * (x_bar - mu);
    sigma2 = invgamma_rand((n_1+1)/2.0, 2.0/(nS_1+m_1*(mu-mu_1)*(mu-mu_1)));
    mu = normal_rand(mu_1, sigma2/m_1);
  }
  Gaussian(std::vector<double> &data, double mu, double sigma2, double G0_mean, double G0_sigma2, double G0_shape, double G0_scale):
    _data(data),mu(mu),sigma2(sigma2),G0_mean(G0_mean),G0_sigma2(G0_sigma2),G0_shape(G0_shape),G0_scale(G0_scale){
    data_idx.clear();
  }
};


//Dirichlet Process Gaussian Mixture Model
class DPGMM {

  double alpha_0;
  double G0_mean, G0_sigma2, G0_shape, G0_scale;
  std::vector<double> _data;
  std::vector<Gaussian> _mixture;

public:
  DPGMM(double alpha_0, double G0_mean, double G0_sigma2, double G0_shape, double G0_scale):
    alpha_0(alpha_0), G0_mean(G0_mean), G0_sigma2(G0_sigma2), G0_shape(G0_shape), G0_scale(G0_scale){
  }
  
  //データをセットする
  void set_data(const std::vector<double> &data){
    _data = data;
    _mixture.clear();
    _mixture.push_back(Gaussian(_data, normal_rand(G0_mean,G0_sigma2), invgamma_rand(G0_shape,G0_scale),
				G0_mean, G0_sigma2, G0_shape, G0_scale));
    for(int i=0; i<_data.size(); i++){
      _mixture[0].data_idx.insert(i); //最初は全て同じテーブル
    }
  }
  //推定
  // 注意:burn_in使ってない
  void inference(int iter, int burn_in){
    int M = 1; //テーブル拡張数
    for(int T=0; T<iter; T++){
      //if(T%100==0) std::cout << T << std::endl;

      //お客をリサンプリング
      for(int i=0; i<_data.size(); i++){
	//テーブルを拡張
	for(int j=0; j<M; j++){
	  _mixture.push_back(Gaussian(_data, normal_rand(G0_mean,G0_sigma2), invgamma_rand(G0_shape,G0_scale),
				      G0_mean, G0_sigma2, G0_shape, G0_scale));
	}
	//客iを削除
	int table_i = -1; //客iが座っていたテーブル
	for(int j=0; j<_mixture.size(); j++){
	  if(_mixture[j].isExist(i)){
	    table_i = j;
	    break;
	  }
	}
	_mixture[table_i].eraseIdx(i); //テーブルから削除
	//新たに座るテーブルを選択
	int select_table = -1; //新たに選ぶテーブル
	std::vector<double> p; //そのテーブルを選ぶ確率
	for(int j=0; j<_mixture.size(); j++){
	  double mu_j = _mixture[j].mu;
	  double sigma2_j = _mixture[j].sigma2;
	  double x_i = _data[i];
	  double f_x = exp(-(x_i-mu_j)*(x_i-mu_j)/(2*sigma2_j))/sqrt(2*PI*sigma2_j);
	  if(_mixture[j].data_idx.size()>0){ //既存のテーブル
	    double p_d = (_mixture[j].data_idx.size())/(_data.size()-1+alpha_0);
	    p.push_back(p_d*f_x);
	  }else{ //新しいテーブル
	    double p_d = (alpha_0/M)/(_data.size()-1+alpha_0);
	    p.push_back(p_d*f_x);
	  }
	}
	_mixture[prob_select(p)].data_idx.insert(i); //テーブルに座る

	//誰も座っていないテーブルを削除
	for(std::vector<Gaussian>::iterator itr = _mixture.begin(); itr != _mixture.end();){
	  if(itr->data_idx.size() == 0) itr = _mixture.erase(itr);
	  else ++itr;
	}  
      }
      //各テーブルのパラメータを更新
      for(int i=0; i<_mixture.size(); i++){
	_mixture[i].paramUpdate();
      }
    }
  }
  //結果の出力
  // 注意:最終的な結果のみを表示
  void output(){
    for(int i=0; i<_mixture.size(); i++){
      std::cout << "Table[" << i << "]" << std::endl;
      for(std::set<int>::iterator itr = _mixture[i].data_idx.begin(); itr != _mixture[i].data_idx.end(); itr++){
	std::cout << "     " << (*itr) << "\t" << _data[(*itr)] << std::endl;
      }
    }
  }
};

int main(){
  //1.データの作成
  std::vector<double> data;
  for(int i=0; i<200; i++) data.push_back(normal_rand(-50,10));
  for(int i=0; i<300; i++) data.push_back(normal_rand(0,10));
  for(int i=0; i<500; i++) data.push_back(normal_rand(50,10));

  //2.モデルの推定
  ///集中度パラメータ
  double alpha_0 = 1;
  ///基底測度G_0 for Gaussian Parameter
  double G0_mean = 0;
  double G0_sigma2 = 1000;
  double G0_shape = 0.01;
  double G0_scale = 0.01;

  DPGMM dpgmm(alpha_0, G0_mean, G0_sigma2, G0_shape, G0_scale);

  dpgmm.set_data(data);
  dpgmm.inference(10000, 3000);

  //3.推定結果
  dpgmm.output();

  return 0;
}

結果

f:id:jetbead:20120303142854p:image

iteration回した最後の結果。それっぽく分かれてくれてる。。。
本当は最後だけじゃなく、それまでの情報(平均値とか最頻indexとか)を使うべきだけど、、、

# データのindex  データの値
Table[0]
     2	-35.5214
     3	-36.3714
     4	-45.4479
     5	-48.0307
     6	-48.7147
     7	-47.8853
     8	-46.6952
     9	-53.3751
     10	-52.3716
     11	-48.7485
     12	-51.7934
     13	-44.3183
     14	-53.0196
     15	-50.5134
     16	-53.1481
     17	-48.1244
     18	-47.787
     19	-49.4366
     20	-47.6917
     21	-47.8654
     22	-50.9595
     23	-49.2512
     24	-51.4467
     25	-50.8832
     26	-40.1298
     27	-49.7144
     28	-51.6124
     29	-52.6363
     30	-43.0921
     31	-48.0059
     32	-50.2139
     33	-47.9583
     34	-47.411
     35	-50.9434
     36	-53.6684
     37	-44.6689
     38	-51.6912
     39	-53.5678
     40	-50.0215
     41	-48.204
     42	-48.2751
     43	-51.0128
     44	-52.8112
     45	-48.3422
     46	-43.4643
     47	-47.5391
     48	-50.4429
     49	-50.5254
     50	-54.2783
     51	-49.3236
     52	-50.7506
     53	-51.4321
     54	-50.5833
     55	-52.9856
     56	-53.77
     57	-46.8769
     58	-46.9255
     59	-52.4172
     60	-50.964
     61	-50.5665
     62	-52.0168
     63	-49.2695
     64	-43.7435
     65	-50.2724
     66	-50.1373
     67	-48.2466
     68	-53.1438
     69	-50.5922
     70	-46.4332
     71	-51.0936
     72	-48.5968
     73	-48.8805
     74	-49.1299
     75	-48.5043
     76	-48.0144
     77	-49.1753
     78	-57.0485
     79	-51.0807
     80	-47.7196
     81	-51.0695
     82	-48.8323
     83	-50.8741
     84	-47.0493
     85	-51.5255
     86	-56.2889
     87	-48.9515
     88	-46.2357
     89	-47.1358
     90	-48.5066
     91	-50.1665
     92	-44.1657
     93	-49.9785
     94	-50.5638
     95	-47.9809
     96	-51.3012
     97	-48.2492
     98	-54.1664
     99	-48.6446
     100	-54.9343
     101	-53.66
     102	-51.2418
     103	-46.872
     104	-52.0475
     105	-50.6886
     106	-49.4763
     107	-55.9026
     108	-54.7954
     109	-44.1464
     110	-48.1967
     111	-46.257
     112	-48.2709
     113	-49.4791
     114	-55.124
     115	-50.3997
     116	-51.4293
     117	-59.3333
     118	-56.8759
     119	-48.9705
     120	-48.9772
     121	-46.7002
     122	-48.4843
     123	-52.0493
     124	-49.4655
     125	-49.4618
     126	-42.5104
     127	-46.6317
     128	-50.0321
     129	-51.2214
     130	-56.5567
     131	-52.945
     132	-51.0209
     133	-53.2828
     134	-54.099
     135	-51.6668
     136	-52.5809
     137	-50.33
     138	-57.718
     139	-51.5968
     140	-50.1121
     141	-53.7579
     142	-57.0954
     143	-53.8787
     144	-45.7346
     145	-53.1578
     146	-54.0137
     147	-50.4715
     148	-56.1738
     149	-48.3563
     150	-49.7892
     151	-50.6667
     152	-49.8319
     153	-49.8551
     154	-55.0025
     155	-56.2075
     156	-54.1293
     157	-47.7355
     158	-51.6997
     159	-54.9316
     160	-51.1406
     161	-44.481
     162	-46.754
     163	-55.4517
     164	-48.4092
     165	-45.7508
     166	-45.7267
     167	-48.3592
     168	-49.5059
     169	-53.379
     170	-51.1818
     171	-54.5743
     172	-49.0745
     173	-49.0509
     174	-48.8907
     175	-48.5935
     176	-46.8176
     177	-44.4825
     178	-49.0936
     179	-51.0238
     180	-48.2486
     181	-51.271
     182	-46.9064
     183	-50.076
     184	-51.1754
     185	-47.8776
     186	-47.4335
     187	-51.4389
     188	-45.7688
     189	-46.7081
     190	-46.3274
     191	-44.9674
     192	-48.5076
     193	-53.2341
     194	-56.0897
     195	-49.455
     196	-49.2397
     197	-44.2719
     198	-47.3803
     199	-51.757
Table[1]
     200	-2.92195
     201	-4.34472
     202	7.9088
     203	1.59877
     204	-5.84475
     205	1.3718
     206	-0.811399
     207	1.28305
     208	-1.96917
     209	2.62953
     210	-1.51321
     211	-1.52876
     212	-2.60754
     213	-2.7589
     214	0.57361
     215	-1.13896
     216	-1.71683
     217	-1.68824
     218	1.47031
     219	-2.67293
     220	0.115342
     221	3.10107
     222	-0.908284
     223	-1.12871
     224	-0.837544
     225	0.203887
     226	1.16277
     227	2.16046
     228	-0.0544563
     229	1.28793
     230	8.0388
     231	-3.78242
     232	1.64791
     233	-1.93224
     234	-5.31764
     235	-0.985437
     236	-0.472555
     237	0.76844
     238	-2.90512
     239	2.37592
     240	1.42087
     241	-1.75345
     242	-5.02175
     243	3.81046
     244	0.369908
     245	0.948802
     246	0.662487
     247	-2.34372
     248	2.18163
     249	4.46669
     250	-0.329907
     251	-2.86627
     252	0.0536559
     253	0.228065
     254	1.71825
     255	5.49655
     256	2.21127
     257	-3.9613
     258	-3.55955
     259	3.63743
     260	-0.20251
     261	-5.9997
     262	5.7217
     263	-0.656522
     264	-3.95076
     265	1.50664
     266	-1.11293
     267	-2.80075
     268	-3.44006
     269	-3.09413
     270	6.09038
     271	0.976819
     272	1.77871
     273	-2.38342
     274	-0.305332
     275	0.905309
     276	4.18212
     277	2.60173
     278	-2.02888
     279	4.98905
     280	6.68565
     281	-2.6741
     282	4.84576
     283	2.68951
     284	2.6228
     285	-1.91521
     286	2.61355
     287	-1.01709
     288	-0.336861
     289	1.54685
     290	-2.86563
     291	-0.907062
     292	-2.11943
     293	1.89301
     294	-0.536777
     295	-8.17238
     296	-4.77668
     297	6.56567
     298	0.26962
     299	-1.68811
     300	-1.90068
     301	0.34508
     302	3.95192
     303	-0.728926
     304	1.56125
     305	0.0806611
     306	0.245732
     307	-2.0148
     308	-1.69024
     309	-2.7257
     310	-0.890567
     311	4.1226
     312	4.44221
     313	-8.04163
     314	3.41392
     315	-3.2167
     316	-1.42271
     317	-1.07614
     318	0.912057
     319	-1.21545
     320	2.0762
     321	-0.0225722
     322	1.20504
     323	-1.29817
     324	2.94811
     325	-1.51383
     326	2.13756
     327	3.99463
     328	3.12869
     329	-0.321061
     330	-2.73394
     331	-3.4137
     332	5.61272
     333	-2.16095
     334	-4.93361
     335	0.525902
     336	-1.73752
     337	-5.58911
     338	-1.7108
     339	0.154274
     340	1.55955
     341	0.409572
     342	-3.00971
     343	-4.13514
     344	1.9507
     345	2.71577
     346	-4.96309
     347	-0.110848
     348	-1.42427
     349	-3.30063
     350	-1.05048
     351	1.07738
     352	-1.46188
     353	0.922515
     354	4.33934
     355	-2.81494
     356	-2.10239
     357	2.08945
     358	0.477516
     359	5.00287
     360	-3.55963
     361	-2.05141
     362	3.57384
     363	1.56325
     364	3.85689
     365	-2.17199
     366	2.72356
     367	4.61206
     368	2.57464
     369	0.560422
     370	4.97427
     371	0.0478507
     372	-5.42315
     373	0.514088
     374	0.717833
     375	-0.758897
     376	-3.1755
     377	-5.72692
     378	-0.756768
     379	-3.5582
     380	-4.90604
     381	-6.10331
     382	-0.502797
     383	3.09598
     384	1.98153
     385	-2.60599
     386	-5.55868
     387	-1.9863
     388	2.45726
     389	-4.01541
     390	-0.525973
     391	-2.7322
     392	-0.206147
     393	-0.412995
     394	3.15333
     395	0.770543
     396	1.72107
     397	-2.65519
     398	2.15872
     399	0.729241
     400	-0.0905532
     401	-0.766385
     402	-2.67467
     403	-2.89218
     404	-1.63615
     405	-1.35973
     406	2.81353
     407	2.56714
     408	0.363525
     409	-2.73775
     410	-2.23986
     411	2.38708
     412	2.63184
     413	1.80615
     414	-1.4083
     415	2.17974
     416	0.0144896
     417	0.0092946
     418	-5.14309
     419	6.99352
     420	-2.38861
     421	-2.56485
     422	-6.94993
     423	-2.55867
     424	2.43975
     425	1.55157
     426	-1.73443
     427	2.12182
     428	-3.76086
     429	-0.489185
     430	1.27411
     431	-3.92487
     432	-5.36595
     433	1.80887
     434	-3.21689
     435	-0.658408
     436	4.25024
     437	-2.37971
     438	-0.700849
     439	-1.78248
     440	4.11031
     441	0.657586
     442	-4.8263
     443	2.03736
     444	-0.116081
     445	0.658695
     446	-5.39686
     447	3.66709
     448	2.70345
     449	-2.82548
     450	3.28067
     451	2.9049
     452	0.200489
     453	-0.339768
     454	-2.26415
     455	-0.859511
     456	1.48667
     457	-0.922154
     458	2.0078
     459	-3.4651
     460	-0.529343
     461	-0.202698
     462	1.66668
     463	-2.22484
     464	-1.43948
     465	-0.174596
     466	1.37627
     467	-2.47439
     468	2.74704
     469	-4.19825
     470	0.475422
     471	-0.902795
     472	-1.96388
     473	-0.842048
     474	2.10587
     475	1.27579
     476	0.714839
     477	1.67816
     478	4.14901
     479	1.87643
     480	2.25583
     481	-3.58776
     482	2.43742
     483	-3.99507
     484	1.47081
     485	1.87625
     486	1.23376
     487	-2.04462
     488	-0.612967
     489	3.00301
     490	1.3963
     491	-4.98965
     492	0.620377
     493	-4.25943
     494	2.91028
     495	-0.401532
     496	-1.61124
     497	2.67809
     498	-4.56348
     499	-2.44084
Table[2]
     500	47.2083
     501	50.1333
     502	49.0593
     503	53.7273
     504	47.6224
     505	49.257
     506	52.8288
     507	52.2066
     508	46.5897
     509	49.7019
     510	45.9157
     511	51.1717
     512	51.6923
     513	50.9892
     514	53.3416
     515	44.6859
     516	49.9599
     517	49.4448
     518	51.5572
     519	57.286
     520	46.1973
     521	53.1693
     522	52.7945
     523	49.1164
     524	50.6427
     525	50.8951
     526	46.6377
     527	46.7923
     528	51.7048
     529	48.3918
     530	50.0331
     531	48.0992
     532	51.5465
     533	52.9936
     534	49.6345
     535	46.3751
     536	48.6749
     537	49.0931
     538	50.5678
     539	51.719
     540	48.6139
     541	49.4551
     542	51.3585
     543	51.7726
     544	51.6984
     545	50.0278
     546	51.0227
     547	44.5653
     548	44.4812
     549	49.4814
     550	50.1419
     551	48.39
     552	49.0482
     553	46.9835
     554	56.8812
     555	52.0208
     556	52.2174
     557	50.2577
     558	56.1442
     559	49.1339
     560	49.8706
     561	47.0276
     562	51.5911
     563	48.5077
     564	53.4353
     565	46.4436
     566	56.6919
     567	48.1306
     568	51.5642
     569	46.1931
     570	52.4438
     571	49.3824
     572	51.4996
     573	48.6297
     574	48.2575
     575	55.4545
     576	49.863
     577	53.3934
     578	48.6832
     579	50.1713
     580	41.2899
     581	55.051
     582	48.3795
     583	51.2265
     584	53.7981
     585	47.3477
     586	53.0892
     587	51.0884
     588	46.7424
     589	47.2892
     590	51.3276
     591	52.846
     592	51.8075
     593	51.1482
     594	47.0198
     595	48.6325
     596	51.3365
     597	47.4598
     598	50.1481
     599	50.8701
     600	48.833
     601	52.3115
     602	50.2489
     603	51.5298
     604	48.3827
     605	50.7647
     606	47.5093
     607	49.6233
     608	46.8448
     609	49.0365
     610	51.4585
     611	52.4883
     612	49.5769
     613	51.2774
     614	52.7116
     615	43.2847
     616	52.3179
     617	52.643
     618	50.3444
     619	56.2518
     620	49.624
     621	44.7577
     622	54.6789
     623	51.8825
     624	44.4774
     625	53.2974
     626	46.8067
     627	50.6386
     628	53.064
     629	50.1626
     630	50.4756
     631	49.3542
     632	52.253
     633	43.6257
     634	50.8882
     635	50.6808
     636	52.8739
     637	50.1468
     638	47.2162
     639	53.8351
     640	49.0978
     641	43.4824
     642	49.6421
     643	50.7634
     644	44.3081
     645	47.4656
     646	48.4338
     647	48.6757
     648	57.5019
     649	50.6206
     650	46.8537
     651	48.3813
     652	51.7982
     653	48.424
     654	47.9283
     655	47.0337
     656	48.0934
     657	55.6896
     658	51.2767
     659	49.4945
     660	48.7087
     661	49.333
     662	46.3045
     663	49.2549
     664	49.2371
     665	50.2047
     666	49.6595
     667	45.2564
     668	50.1154
     669	48.3832
     670	46.3732
     671	52.6561
     672	52.6019
     673	47.3666
     674	47.2436
     675	48.1765
     676	52.4576
     677	46.0745
     678	51.215
     679	49.7578
     680	48.5021
     681	49.5451
     682	54.2666
     683	47.0515
     684	54.1405
     685	49.7892
     686	48.8238
     687	47.2997
     688	53.0064
     689	46.0247
     690	46.9157
     691	53.6858
     692	49.7301
     693	46.5576
     694	52.9799
     695	50.7041
     696	47.4629
     697	50.6511
     698	48.1133
     699	49.5037
     700	53.3731
     701	53.5353
     702	52.5137
     703	51.5942
     704	50.0586
     705	47.3795
     706	49.255
     707	49.0336
     708	45.1589
     709	46.653
     710	46.0391
     711	51.6132
     712	47.142
     713	46.0209
     714	50.1332
     715	47.2314
     716	48.2204
     717	50.8915
     718	49.3718
     719	50.5643
     720	57.0949
     721	53.9162
     722	50.0081
     723	47.6147
     724	49.8864
     725	50.7303
     726	51.9253
     727	49.2549
     728	48.3517
     729	52.6285
     730	50.6274
     731	50.4569
     732	52.1017
     733	52.6353
     734	53.4939
     735	47.1914
     736	46.435
     737	53.0841
     738	52.3783
     739	51.794
     740	53.5775
     741	47.12
     742	52.6386
     743	48.9676
     744	54.6517
     745	51.1574
     746	56.0696
     747	51.5449
     748	50.2992
     749	50.3166
     750	50.6935
     751	49.6394
     752	55.1366
     753	53.7863
     754	56.0628
     755	47.6918
     756	53.7175
     757	45.8798
     758	54.4283
     759	44.7231
     760	50.4676
     761	49.3085
     762	51.0676
     763	41.6511
     764	47.3657
     765	52.1909
     766	49.2877
     767	45.1008
     768	46.9269
     769	51.402
     770	53.2275
     771	50.0152
     772	51.4946
     773	51.6075
     774	54.9089
     775	52.0244
     776	46.7912
     777	52.2315
     778	47.2325
     779	49.543
     780	48.6524
     781	52.4153
     782	57.3265
     783	51.7749
     784	51.2529
     785	48.7739
     786	45.286
     787	50.803
     788	43.1362
     789	52.3826
     790	49.5456
     791	51.5356
     792	56.1188
     793	50.9528
     794	52.6905
     795	46.11
     796	49.708
     797	48.8488
     798	49.9329
     799	49.7791
     800	54.4539
     801	52.1736
     802	48.8191
     803	52.657
     804	45.2783
     805	49.6953
     806	50.6557
     807	43.0968
     808	50.3708
     809	46.758
     810	46.0757
     811	49.9069
     812	46.7772
     813	47.993
     814	50.7924
     815	51.1395
     816	46.9231
     817	45.6459
     818	57.4477
     819	48.6586
     820	49.6964
     821	52.6133
     822	46.233
     823	51.3046
     824	52.3569
     825	52.1927
     826	49.2518
     827	54.9871
     828	51.7533
     829	49.8286
     830	47.9469
     831	47.4476
     832	49.7752
     833	50.3336
     834	49.3951
     835	51.9032
     836	50.1919
     837	44.8847
     838	53.169
     839	52.465
     840	48.5866
     841	50.1508
     842	54.2019
     843	51.2948
     844	53.8337
     845	51.8688
     847	51.5235
     848	51.0933
     849	50.3949
     850	48.83
     851	49.6694
     852	51.8897
     853	54.6055
     854	48.4679
     855	46.5371
     856	48.4887
     857	49.6518
     858	50.8344
     859	51.7577
     860	52.6648
     861	49.59
     862	52.0392
     863	48.3213
     864	56.1552
     865	47.8037
     866	47.5686
     867	52.6098
     868	50.7106
     869	46.8837
     870	48.2928
     871	54.3736
     872	48.7201
     873	47.812
     874	49.4399
     875	50.5798
     876	48.4775
     877	50.0702
     878	49.4787
     879	52.8912
     880	45.0419
     881	54.1969
     882	50.5366
     883	48.0548
     884	53.0176
     885	50.3011
     886	51.2065
     887	49.026
     888	48.6936
     889	55.7794
     890	45.8426
     891	51.2555
     892	52.6447
     893	50.487
     894	52.2281
     895	54.8323
     896	50.5128
     897	53.5178
     898	49.4288
     899	51.2999
     900	50.493
     901	45.0754
     902	52.88
     903	52.8198
     904	44.4267
     905	50.2931
     906	47.0122
     907	46.9598
     908	47.0083
     909	48.7758
     910	46.3166
     911	45.1931
     912	53.9486
     913	47.4728
     914	48.1687
     915	42.162
     916	56.9081
     917	48.436
     918	50.1222
     919	55.0411
     920	49.3187
     921	45.4149
     922	53.0721
     923	49.8275
     924	45.1732
     925	45.2809
     926	45.6984
     927	51.3946
     928	50.3734
     929	51.4407
     930	48.4081
     931	55.5197
     932	52.9908
     933	49.5517
     934	44.6276
     935	50.1467
     936	52.5191
     937	47.1686
     938	44.0294
     939	51.7616
     940	48.5467
     941	49.3564
     942	45.8278
     943	50.4393
     944	48.3418
     945	50.4748
     946	50.1085
     947	49.3994
     948	51.1022
     949	49.1833
     950	54.722
     951	49.7215
     952	52.0236
     953	48.1955
     954	46.7201
     955	49.4768
     956	49.5986
     957	48.2752
     958	45.9778
     959	46.5852
     960	51.9321
     961	49.6427
     962	46.9841
     963	55.7479
     964	52.9344
     965	53.8702
     966	52.2877
     967	49.9191
     968	47.2089
     969	54.3887
     970	51.0311
     971	52.5679
     972	46.8479
     973	50.1765
     974	55.1074
     975	58.2637
     976	50.2971
     977	42.7912
     978	50.6414
     979	50.0517
     980	52.4977
     981	48.4231
     982	47.6988
     983	46.7652
     984	50.5972
     985	52.088
     986	55.6902
     987	48.426
     988	48.8053
     989	50.6716
     990	46.6228
     991	47.3814
     992	51.1737
     993	48.748
     994	50.0627
     995	55.2962
     996	45.3248
     997	59.8706
     998	50.0316
     999	44.0382
Table[3]
     0	-30.9712
     1	-31.307
Table[4]
     846	58.531

2012-02-29

ギブスサンプリングによるベイズ推定

はじめに

MCMCによるベイズ推定として、正規分布に従うデータが与えられたとき、その正規分布のパラメータ(平均と分散)が従う分布および推定値を求める。

尤度関数が正規分布の場合、共役事前分布はそれぞれ、平均は正規分布、分散は逆ガンマ分布になるので、ギブスサンプリングを使うことができ、これでパラメータの推定する。

コード

#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
static const double PI = 3.14159265358979323846264338;

//xorshift
// 注意: longではなくint(32bit)にすべき
unsigned long xor128(){
  static unsigned long x=123456789, y=362436069, z=521288629, w=88675123;
  unsigned long t;
  t=(x^(x<<11));
  x=y; y=z; z=w;
  return w=(w^(w>>19))^(t^(t>>8));
}
//[0,1)の一様乱数
// 注意: int_maxぐらいで割るべき
double frand(){
  return xor128()%ULONG_MAX/static_cast<double>(ULONG_MAX); 
}
//ガンマ乱数
double gamma_rand(double shape, double scale){
  double n, b1, b2, c1, c2;
  if(4.0 < shape) n = 1.0/sqrt(shape);
  else if(0.4 < shape) n = 1.0/shape + (1.0/shape)*(shape-0.4)/3.6;
  else if(0.0 < shape) n = 1.0/shape;
  else return -1;

  b1 = shape - 1.0/n;
  b2 = shape + 1.0/n;

  if(0.4 < shape) c1 = b1 * (log(b1)-1.0) / 2.0;
  else c1 = 0;
  c2 = b2 * (log(b2)-1.0) / 2.0;

  while(true){
    double v1 = frand(), v2 = frand();
    double w1, w2, y, x;
    w1 = c1 + log(v1);
    w2 = c2 + log(v2);
    y = n * (b1*w2-b2*w1);
    if(y < 0) continue;
    x = n * (w2-w1);
    if(log(y) < x) continue;
    return exp(x) * scale;
  }
  return -1;
}
//逆ガンマ乱数
double invgamma_rand(double shape, double scale){
  return 1.0/gamma_rand(shape, scale);
}
//正規乱数
double normal_rand(double mu, double sigma2){
  double sigma = sqrt(sigma2);
  double u1 = frand(), u2 = frand();
  double z1 = sqrt(-2*log(u1)) * cos(2*PI*u2);
  //double z2 = sqrt(-2*log(u1)) * sin(2*PI*u2);
  return mu + sigma*z1;
}




//正規分布に従うデータに対するギブスサンプリング
class Gibbs_Normal {
  double mu, sigma2; //未知パラメータの平均muと分散sigma2
  double n, x_bar; //データ数n、データの平均値x_bar
  //事後分布を見やすくするための自由度を与える変数など
  double m_0, n_0, S_0;
  double m_1, n_1, mu_1, nS_1; 

public:
  Gibbs_Normal(double mu_0, double sigma2_0, double alpha_0, double lambda_0){
    mu = mu_0;
    m_0 = 1.0;
    sigma2 = sigma2_0 / m_0;
    n_0 = alpha_0 * 2.0;
    S_0 = lambda_0 * 2.0 / n_0;
  }
  
  void set_data(const std::vector<double>& data){
    n = data.size();
    x_bar = 0.0;
    for(int i=0; i<n; i++) x_bar += data[i];
    x_bar /= n;
    m_1 = m_0 + n;
    n_1 = n_0 + n;
    mu_1 = (m_0 * mu + n * x_bar) / (m_0 + n);
    nS_1 = n_0 * S_0;
    for(int i=0; i<n; i++) nS_1 += (data[i] - x_bar) * (data[i] - x_bar);
    nS_1 += (m_0 * n) / (m_0 + n) * (x_bar - mu) * (x_bar - mu);

    //std::cerr << "m_0 : " << m_0 << std::endl;
    //std::cerr << "m_1 : " << m_1 << std::endl;
    //std::cerr << "n_1 : " << n_1 << std::endl;
    //std::cerr << "n_1*S_1 : " << nS_1 << std::endl;
    //std::cerr << "mu_1 : " << mu_1 << std::endl;
    //std::cerr << "x_bar : " << x_bar << std::endl;
  }

  void sampling(){
    mu = normal_rand(mu_1, sigma2/m_1);
    sigma2 = invgamma_rand((n_1+1)/2.0, 2.0/(nS_1+m_1*(mu-mu_1)*(mu-mu_1)));
                                                                  //精度を渡したいので逆数にする
  }

  double get_mu_mean(){ return mu; }
  double get_mu_var(){ return sigma2/m_1; }
  double get_sigma2_shape(){ return (n_1+1)/2.0; }
  double get_sigma2_scale(){ return (nS_1+m_1*(mu-mu_1)*(mu-mu_1))/2.0; }
};


int main(){

  Gibbs_Normal gn(0, 1000, 0.001, 0.001); //事前分布のパラメータ

  //データ生成(適当に平均5,分散4の正規分布に従うデータを1000個作成)
  std::vector<double> data;
  for(int i=0; i<1000; i++){
    double r = normal_rand(5, 4);
    data.push_back(r);
  }
  
  //生成したデータの詳細
  double ave_data = 0.0, sigma2_data = 0.0;
  for(int i=0; i<data.size(); i++){
    ave_data += data[i];
  }
  ave_data /= data.size();
  for(int i=0; i<data.size(); i++){
    sigma2_data += (data[i]-ave_data)*(data[i]-ave_data);
  }
  sigma2_data /= data.size();
  std::cerr << "data ave: " << ave_data << std::endl;
  std::cerr << "data sigma2: " << sigma2_data << std::endl;

  //ギブスサンプリング
  int cnt = 0;
  double ave_mu_mean = 0.0; //平均の事後分布の平均の平均値
  double ave_mu_var = 0.0; //平均の事後分布の分散の平均値
  double ave_sigma2_shape = 0.0; //分散の事後分布のshapeの平均値
  double ave_sigma2_scale = 0.0; //分散の事後分布のscaleの平均値

  int iter_N = 1000000; //イテレーション回数
  int burnin = 300000; //バーンイン期間

  gn.set_data(data); //データをセット

  for(int t=0; t<iter_N; t++){
    gn.sampling(); //サンプリング
    if(t >= burnin){
      ave_mu_mean += gn.get_mu_mean();
      ave_mu_var += gn.get_mu_var();
      ave_sigma2_shape += gn.get_sigma2_shape();
      ave_sigma2_scale += gn.get_sigma2_scale();
      cnt++;
    }
  }

  ave_mu_mean /= cnt;
  ave_mu_var /= cnt;
  ave_sigma2_shape /= cnt;
  ave_sigma2_scale /= cnt;

  //事後分布
  std::cerr << "-------" << std::endl;
  std::cerr << "N(" << ave_mu_mean << "," << ave_mu_var << ")" << std::endl;
  std::cerr << "IG(" << ave_sigma2_shape << "," << ave_sigma2_scale << ")" << std::endl;

  //事後分布の平均値(期待値)
  std::cerr << "-------" << std::endl;
  std::cerr << "mean : " << ave_mu_mean << std::endl;  
  std::cerr << "var : " << ave_sigma2_scale/(ave_sigma2_shape-1) << std::endl;

  return 0;
}



結果

#データの平均分散
data ave: 5.01414
data sigma2: 4.04538
-------
#パラメータの事後分布
N(5.00913,0.00407451)
IG(500.501,2037.29)
-------
#事後分布の代表値(平均値)
mean : 5.00913
var : 4.07864

参考文献