Negative/Positive Thinking

2018-01-20

焼きなまし法の適用メモ

はじめに

焼きなまし法について、問題へ適用する際のメモ。

焼きなまし法とは

疑似コード

x:=初期解, T:=初期温度, R:=初期イテレーション回数
while 終了条件 do begin
  for i:=1 to R do begin
    y:=近傍解の一つ(y∈N(x))
    Δ:=score(y) - score(x)
    if Δ >= 0 then x:=y; else
      if exp(-Δ/T) >= frand() then x:=y
  end
  T:=CalcTemp(T)
  R:=CalcIter(R)
end
return 解;

c++でのコーディング例: 診断人さん, 焼きなまし法のコツ, http://shindannin.hatenadiary.com/entry/20121224/1356364040

重要なパラメータ

近傍定義
  • 解xの近傍集合N(x)をどのように定義するか?で解空間も変化しうる
  • 時間制約などがある場合は特に効率的な探索ができるよう解空間の構造、近傍サイズなども考慮
    • すべての実行解は、任意の実行解から到達可能、など
  • 効率的な探索を行えるようにするためには以下のような構造の近傍を避けたり、別な探索手法など検討
    • とげとげ、ぎざぎざ
    • 深いくぼみ
    • 平坦

  • chokudai先生の焼きなまし講座, https://togetter.com/li/607979
  • tsukammoさん, 競技プログラミングにおいて焼きなまし法に堕ちずに落とすコツ, https://qiita.com/tsukammo/items/b410f3202372fe87c919

冷却スケジュール
  • 初期温度T
    • 十分高い温度から始めるべき
    • ただし、場合によっては良い初期解&低い温度にしてもよい可能性もある?
  • 温度の減少関数CalcTemp
    • 十分ゆっくり冷却させる
    • 冷却速度が時間に依る/依らない関数、複雑な速度変化の関数など様々考えられる
    • 例:T_k = r^k * 初期温度(ただし、r=0.8〜0.99程度)、T_k = 初期温度 / log_2(k+2)、など
  • 各温度での反復回数R
    • 近傍をまんべんなく探索できることが求められるが、近傍定義に依るのでそっちのほうが重要
  • 終了条件
    • 制限時間、解のスコアの変動がなくなる(小さくなる)、など

細かいチューニング

焼きなまし法適用するべきかどうか

問題ごとの考察が重要だが、以下の概念や知見が役に立つ。

文脈
問題のタイプ
有名アルゴリズム適用に囚われない

その他、概念とか用語とか

疑似平衡
  • 理論的な解析として、解の移動系列を確率過程としてのマルコフ連鎖と捉えて収束性が議論されるらしい
    • 反復は定常状態になるまで行える場合、かつ、温度をT→0にしていく場合、確率的に最適解に収束することが示される
    • 反復回数の限度を与える場合、温度パラメータの系列が十分ゆっくり冷却されるならば、収束が保証される
  • しかし、理論的に必要となる計算時間は膨大になりすぎてしまうので、実用的には(理論的枠組みからは外れるが)「各温度で、できるかぎり平衡に近い状態(疑似平衡)」を得られるようにチューニングすることが重要なポイントになるよう

受理率
  • ある温度Tでの全反復数に対して、実際に受理され移動が発生した回数の比率
    • η(T) = (温度Tでの移動回数) / (温度Tでの反復回数)
  • 受理率が低い場合、山登りと同様に局所最適解からの脱出が困難になってしまう
  • 場合によると思われるが、0.3〜0.5ぐらいだと効率の良い探索が行われるとのこと

近接最適性原理
  • 「良い解同士は何らかの類似構造を持っている」という経験的な原理
    • ある良い解と別の良い解の共通要素が多い、みたいな場合もこの原理が成り立つ
  • この原理が成り立つような解の偏りがあるならば、この類似構造を活用して効率的に最適解を探索できる可能性がある

集中化と多様化
  • 集中化
    • 「良い解の近傍により良い解が存在する」をもとに、良い解の近くを重点的に探索する戦略
  • 多様化
    • より遠くにある良い解を見つける戦略
  • 焼きなましの場合、温度や近傍構造などでこの集中化・多様化のチューニングや工夫をすることができる

参考

2018-01-16

RangeCoderを試す

はじめに

RangeCoderのメモ。
中途半端な理解で適当に書き写そうとしたらひどいことになったので、まとめておく。。。

RangeCoderとは

  • エントロピー符号の一種
  • シンボル列をある数値で表現する

以下、桁上りありの場合について下記サイトを参考に作成してみた。
http://www.geocities.jp/m_hiroi/light/pyalgo36.html

他にも「桁上りなし」や「適応型」などがあるよう。
http://www.geocities.jp/m_hiroi/light/pyalgo37.html

アルゴリズム

動作原理

f:id:jetbead:20180116060906p:image

  • ある範囲に対し、シンボルの出現確率で区分けして、該当するシンボルの範囲を次の範囲として、シンボルを割り当て続ける
  • 最終的に、範囲の左端の数値さえわかれば、同じ処理でシンボルを復元できる

範囲の拡大
  • かなり大きい範囲を用意しておかないとすぐに範囲が狭まってしまってシンボルの出現確率で区分けするに十分な整数がなくなってしまう
  • そこで、ある程度の大きさの範囲からスタートし、ある程度の大きさ以下になったら全体をx倍するような処理を繰り返すことで、これに対処する

f:id:jetbead:20180116060907p:image

実装方法
  • 上記は多倍長などを用いなくても、普通の整数型を使って、都度、上位byteを出力してしまって、範囲を表す変数を常にある程度の大きさに収まるようにすることで実装できる


  • 範囲の最大値max_range、拡大する範囲の閾値min_range、範囲の左端をlow、範囲の幅をrangeとすると、上記で言っているのは「lowをmax_range以下に保つように実装する」ということ
  • ここでは、以下のように定める
    • max_rangeのバイト数をrange_byteとする
    • min_rangeのバイト数はrange_byte-1
    • 拡大率は0x100(256)倍

拡大処理に伴う問題
  • 拡大処理を行うときに上位byteの出力に伴う問題がいくつかある

f:id:jetbead:20180116060908p:image

  • 図で、low=0x123456のときrangeが小さくなりすぎて拡大したとする(lowとrangeをそれぞれ0x100倍)
  • lowは0x12345600となるが、max_range以下に収まるようにしたいので、上位byteの0x12を出力してしまって、下位の0x345600を次のlowとする

  • 問題は「出力済みの桁の繰り上げ(桁上り)が発生しうる」こと

桁上り
  • 図で、単純に上位byteの0x12を出力してしまうと、次の範囲処理でmax_range=0x1000000を超えた場合、実は出力すべきは0x13だった、、、ということが起きうる
  • また、拡大処理した際に出力byteが0xffだとさらに出力をさかのぼって桁上りを処理しなければならない
    • 出力済みが「0x12, 0xff, 0xff, 0xff」で桁上りが発生したら「0x13,0x00,0x00,0x00」にしなければならない

  • これの解決方法の一つとして、実際に出力はせずに「バッファにためる」ことで確定するまで出力を保留する方法がある
    • 出力候補の先頭byteをbuff、後続の0xffの個数をcntとする

  • 一つ目の「次の範囲処理max_rangeを超える場合」は、超えた場合にはrangeがmax_range以下であるため、桁上り分は+1しかならないので、buff++をして、lowはmax_range以下の部分だけにマスクすればよい
  • このとき、cntが1個以上の場合は「0x12, 0xff, ..., 0xff」を+1すると「0x13, 0x00, ..., 0x00」のような状態に変化するので、最後の「0x00」以外は出力してしまって問題ない

  • 二つ目の「拡大処理した際」は、出力候補のlowの上位bit部分が0xffなら桁上りがまだあり得るので出力はせずにcnt+1だけ、0xff未満ならbuffとcnt個分の0xffを出力してbuffにその値をいれてあげればよい

終了処理
  • 残っているbuff,cnt,lowをすべて出力して終了

シンボルの出現頻度テーブル
  • デコード処理ではシンボルの出現確率で区分けするために出現頻度テーブルも保存しておく

  • 注意として、出現頻度の合計値はmin_range以下である必要がある
    • そうでないとrangeを区分に分割するときに整数に割り当てられない場合ができてしまう

  • 実装上は、頻度値をshort型などで持たせることで、頻度合計値をmin_range以下にしたり、テーブル保存時の容量削減をすると良いよう

デコード処理
  • 出現頻度テーブルと出力した数値を頭から読み込んでいけばよい
  • エンコード処理と同様に、rangeが小さくなりすぎたら拡大しながら読み込んでいく

コード

いくつかのデータで復元できているので、おそらく大丈夫。

#include <iostream>
#include <vector>

class RangeCoder {
  const int64_t range_byte = 6; //max_rangeのバイト数
  const int size = 0x10000; //出現するデータの種類数の最大値(入力シンボルをuint16_tにしているため)
  
  const int64_t max_range = 0x1LL << (8 * range_byte); //rangeの最大幅
  const int64_t min_range = 0x1LL << (8 * (range_byte-1)); //rangeを拡大する幅の閾値
  const int64_t mask = max_range - 1; //max_range分のマスク(0xfff...fff)
  const int64_t shift = 8 * (range_byte-1); //出力分計算用(バッファ処理用「0xff00...00」生成と1byteずつの出力するため用)

  std::vector<int32_t> count, count_sum; //頻度分布、頻度累積分布
  int32_t sum; //頻度合計値

  int32_t orig_size; //入力シンボル数保存用
  int pos; // 出力位置保存用

  void init(){
    for(int i=0; i<size; i++){
      count[i] = 0;
      count_sum[i] = 0;
    }
    pos = 0;
  }

  //頻度分布、頻度累積分布の作成
  void make_count(const std::vector<uint16_t>& in){
    init();

    for(uint16_t ch : in){
      int chi = ch & 0xffff;
      count[chi]++;
    }

    //頻度値を16bitに抑えるため、頻度値全体を1/2^n倍する
    int32_t max_count = 0;
    for(int i=0; i<size; i++){
      max_count = std::max(max_count, count[i]);
    }
    if(max_count > 0xffff){
      int n = 0;
      while(max_count > 0xffff){
        max_count >>= 1;
        n++;
      }
      for(int i=0; i<size; i++){
        if(count[i] > 0){
          count[i] >>= n;
          count[i] |= 1;
        }
      }
    }
    //頻度合計値はmin_range以下にしなければいけないので、
    //抑えるための処理(1/2を繰り返す)
    //注意: 無限ループに入りうる
    while(true){
      int32_t c_sum = 0;
      for(int i=0; i<size; i++){
        c_sum += count[i];
      }
      if(c_sum < min_range) break;
      for(int i=0; i<size; i++){
        if(count[i] > 0){
          count[i] >>= 1;
          if(count[i] == 0) count[i] = 1;
        }
      }
    }

    //頻度累積分布の作成
    sum = 0;
    for(int i=0; i<size; i++){
      count_sum[i+1] = sum + count[i];
      sum += count[i];
    }
  }

  //1byte分出力
  void putc(std::vector<uint8_t>& ret, uint8_t c){
    ret.push_back(c);
  }

  //1byte分読込
  int getc(const std::vector<uint8_t>& in){
    if(pos >= in.size()) return 0;
    return in[pos++] & 0xff;
  }

  //ヘッダ情報の出力
  void save_header(std::vector<uint8_t>& ret, int32_t orig_size){
    //シンボル数
    putc(ret, (orig_size>>24) & 0xff);
    putc(ret, (orig_size>>16) & 0xff);
    putc(ret, (orig_size>>8) & 0xff);
    putc(ret, orig_size & 0xff);

    //頻度分布
    int32_t num = 0; //シンボルのユニーク数
    for(int i=0; i<size; i++) if(count[i] > 0) num++;
    putc(ret, (num>>24) & 0xff);
    putc(ret, (num>>16) & 0xff);
    putc(ret, (num>>8) & 0xff);
    putc(ret, num & 0xff);

    for(int i=0; i<size; i++){
      if(count[i] > 0){
        //シンボル番号
        putc(ret, (i>>8) & 0xff);
        putc(ret, i & 0xff);
        //シンボルの出現頻度
        putc(ret, (count[i]>>8) & 0xff);
        putc(ret, count[i] & 0xff);
      }
    }
  }

  //ヘッダ情報の読込
  void load_header(const std::vector<uint8_t>& in){
    init();

    //シンボル数
    orig_size = getc(in); orig_size <<= 8;
    orig_size |= getc(in); orig_size <<= 8;
    orig_size |= getc(in); orig_size <<= 8;
    orig_size |= getc(in);

    //頻度分布
    int32_t num = getc(in); num <<= 8;
    num |= getc(in); num <<= 8;
    num |= getc(in); num <<= 8;
    num |= getc(in);

    for(int i=0; i<num; i++){
      //シンボル番号
      int32_t id = getc(in); id <<= 8;
      id |= getc(in);
      //シンボルの出現頻度
      int32_t cnt = getc(in); cnt <<= 8;
      cnt |= getc(in);

      count[id] = cnt;
    }

    //頻度累積分布の作成
    sum = 0;
    for(int i=0; i<size; i++){
      count_sum[i+1] = sum + count[i];
      sum += count[i];
    }
  }

  //デコード用範囲から該当するシンボルの探索
  int32_t search_code(int32_t val){
    int32_t i = 0;
    int32_t j = size-1;
    while(i < j){
      int32_t k = (i + j) / 2;
      if(count_sum[k+1] <= val){
        i = k + 1;
      }else{
        j = k;
      }
    }
    return i;
  }

public:
  RangeCoder():count(size+1, 0), count_sum(size+1, 0){}

  std::vector<uint8_t> encode(const std::vector<uint16_t>& in){
    std::vector<uint8_t> ret;

    make_count(in);
    save_header(ret, in.size());

    //桁上りバッファ用
    // buff: 出力候補
    // cnt: buff以降に続く出力候補0xffの数
    // => [0x12, 0xff, 0xff, ..., 0xff]のような情報
    int64_t buff = 0, cnt = 0;
    //範囲計算用
    int64_t low = 0, range = max_range; //下限と範囲

    
    for(int i=0; i<in.size(); i++){
      //入力シンボル
      int32_t ch = in[i] & 0xffff;
      //該当範囲の計算
      int64_t temp = range / sum;
      low += count_sum[ch] * temp;
      range = count[ch] * temp;

      //桁上りの処理
      //該当範囲がmax_rangeを超えてしまった場合
      if(low >= max_range){
        buff++;
        low &= mask;
        //もしcntが0より大きければ、[0x12,0xff,0xff...,0xff]のような状態なので、
        //buffが+1されたことで[0x13,0x00,0x00...,0x00]のようになり、最後の0x00以外は出力してよい
        if(cnt > 0){
          putc(ret, buff & 0xff); //buffの出力
          for(int j=0; j<cnt-1; j++) putc(ret, 0x00); //0x00をcnt-1個分出力、最後の0x00はbuffとする
          buff = 0x00;
          cnt = 0x00;
        }
      }
      //範囲が小さくなったら全体をを拡大(256倍)
      while(range < min_range){
        //拡大することでmax_rangeを超える上位1バイト分が、
        // - 0xffより小さければ上位8bitは0xffではないので、
        //   バッファを出力して、その上位8bit分をbuffにいれる(rangeは範囲内なので0xfe以下なら絶対に0xffまでしかいかない)
        // - 0xffの場合は、まだ桁上りがあるかもしれないので、0xffを増やす意味でcntを+1する
        if(low < (0xffLL << shift)){ //low < 0xff000...000
          putc(ret, buff & 0xff);
          for(int j=0; j<cnt; j++) putc(ret, 0xff);
          buff = (low >> shift) & 0xff;
          cnt = 0;
        }else{
          cnt++;
        }
        //全体を256倍に拡大する
        low = (low << 8) & mask;
        range <<= 8;
      }
    }
    //最後に残っている情報(buff, cnt, low)をすべて出力
    int32_t ch = 0xff;
    if(low >= max_range){
      buff++;
      ch = 0;
    }
    putc(ret, buff & 0xff);
    for(int j=0; j<cnt; j++) putc(ret, ch & 0xff);
    for(int j=shift; j>=0; j-=8){
      putc(ret, (low >> j) & 0xff);
    }

    return ret;
  }

  std::vector<uint16_t> decode(const std::vector<uint8_t>& in){
    std::vector<uint16_t> ret;

    load_header(in);
    getc(in);

    int64_t range = max_range;
    int64_t low = 0;
    for(int i=shift; i>=0; i-=8){
      low |= getc(in);
      if(i>0) low <<= 8;
    }

    while(ret.size() < orig_size){
      int64_t temp = range / sum;
      int32_t ch = search_code(low / temp);
      low -= temp * count_sum[ch];
      range = temp * count[ch];
      
      while(range < min_range){
        range <<= 8;
        low = ((low << 8) + getc(in)) & mask;
      }
      
      ret.push_back( ch & 0xffff );
    }
    return ret;
  }
};

int main(){
  RangeCoder encoder, decoder;

  //入力シンボル列
  std::vector<uint16_t> v;
  for(int i=0; i<10; i++){
    v.push_back(4);
    v.push_back(3);
    v.push_back(2);
    v.push_back(2);
    v.push_back(1);
    v.push_back(1);
    v.push_back(1);
    v.push_back(1);
  }

  //エンコード
  std::cout << "encoding" << std::endl;
  std::vector<uint8_t> encode_code = encoder.encode(v);

  std::cout << v.size() * 16 << " => " << encode_code.size() * 8 << std::endl;
  
  //デコード
  std::cout << "decoding" << std::endl;
  std::vector<uint16_t> decode_code = decoder.decode(encode_code);

  //結果の確認
  for(int i=0; i<decode_code.size(); i++){
    std::cout << i << "\t" << v[i] << "\t" << decode_code[i] << std::endl;
  }

  return 0;
}

結果

encoding
1280 => 384
decoding
0       4       4
1       3       3
2       2       2
3       2       2
4       1       1
5       1       1
6       1       1
7       1       1
8       4       4
...

1280bitのシンボル列を384bitに圧縮できているよう。
上記、書き換え不可能な出力ではなく、入れた後で書き換え可能な出力用vectorに入れているので、素直に桁上げしにいく処理にした方がコードは短くなりそう。

参考

2017-08-20

Elias-Fano Encodingで遊ぶ

はじめに

読んでる論文に出てきてたElias-Fano Encodingをちょっと書いて遊んでみた。

Elias-Fano Encodingとは

  • 単調増加整数列の表現方法のひとつ
  • 厳密にはsuccinctではないが、succinctに近い表現
    • quasi-succinct representationと言っている

  • 整数列の各値を「上位bit列」と「下位bit列」に分割し、整数列全体では、上位bit列を「(negated) unary code表現」、下位bit列を「各下位bitを連結したbit列」で表現したもの
    • 上位、下位のbit数は等分ではなく、上位bitがceil(lg(n))個
    • 例: {3,4,5}という整数列の場合
    • ceil(lg(3))=2として、上位2bitと下位1bitに分ける
    • 3は011、4は100、5は101なので、上位bitは01と10と10、下位bitは1と0と1
    • 上位bit列はnegated unary codeで表現するので、00が0回、01が1回、10が2回、11が0回なので、「0101100」と表現する
      • 1^{00の個数}0 1^{01の個数}0 1^{10の個数}0 ...と表現
    • 下位bit列は、連結するので「101」と表現

  • 転置インデクスやグラフ、Trieなどの表現として利用する論文が発表されているよう
  • 有用なコードもgithubで公開されているので実用する場合はそちらを参照

コード

http://pages.di.unipi.it/pibiri/slides/seminar_ef.pdf
を参考に「正の整数の単調増加列(1以上の整数、同じ整数はなし、ソート済み)」を想定した場合のElias-Fano Encodingするコードを書いてみた。
操作は「access: S[i]の値」と「successor(x): min{S[i] | S[i] >= x}」の2つ。
いくつかのランダムケースで試して問題なさそうなことまでしか確認していない。

#include <vector>
#include <map>
#include <cstdint>
#include <algorithm>
#include <iostream>
#include <cmath>

class FID {
  static const int BIT_SIZE = 64;
  using BLOCK_TYPE = uint64_t;
  int size;
  int block_size;
  std::vector<BLOCK_TYPE> blocks;
  std::vector<int> s_rank, s0_rank;
public:
  BLOCK_TYPE popcount(BLOCK_TYPE x){
    x = ((x & 0xaaaaaaaaaaaaaaaaULL) >> 1) + (x & 0x5555555555555555ULL);
    x = ((x & 0xccccccccccccccccULL) >> 2) + (x & 0x3333333333333333ULL);
    x = ((x & 0xf0f0f0f0f0f0f0f0ULL) >> 4) + (x & 0x0f0f0f0f0f0f0f0fULL);
    x = ((x & 0xff00ff00ff00ff00ULL) >> 8) + (x & 0x00ff00ff00ff00ffULL);
    x = ((x & 0xffff0000ffff0000ULL) >> 16) + (x & 0x0000ffff0000ffffULL);
    x = ((x & 0xffffffff00000000ULL) >> 32) + (x & 0x00000000ffffffffULL);
    return x;
  }
public:
  FID(int size):
  size(size),
  block_size(((size + BIT_SIZE - 1) / BIT_SIZE) + 1),
  blocks(block_size, 0),
  s_rank(block_size, 0),
  s0_rank(block_size, 0)
  {}

  void init(int sz){
    blocks.clear();
    s_rank.clear();
    size = sz;
    block_size = ((size + BIT_SIZE - 1) / BIT_SIZE) + 1;
    blocks.resize(block_size, 0);
    s_rank.resize(block_size, 0);
    s0_rank.resize(block_size, 0);
  }
  
  void set(int i){
    blocks[i/BIT_SIZE] |= 1ULL << (i%BIT_SIZE);
  }
  
  void finalize(){
    s_rank[0] = 0;
    for(int i=1; i<block_size; i++){
      s_rank[i] = s_rank[i-1] + popcount(blocks[i-1]);
    }
    s0_rank[0] = 0;
    for(int i=1; i<block_size; i++){
      s0_rank[i] = s0_rank[i-1] + (BIT_SIZE - popcount(blocks[i-1]));
    }
  }
  
  bool access(int i){
    return (blocks[i/BIT_SIZE] >> (i%BIT_SIZE)) & 1ULL;
  }

  int rank(int i){
    BLOCK_TYPE mask = (1ULL << (i%BIT_SIZE)) - 1;
    return s_rank[i/BIT_SIZE] + popcount(mask & blocks[i/BIT_SIZE]);
  }

  int select(int x){
    if(rank((block_size-1) * BIT_SIZE) <= x) return -1;
    int lb = 0, ub = block_size-1;
    while(ub-lb>1){
      int m = (lb+ub)/2;
      if(s_rank[m]<=x) lb = m;
      else ub = m;
    }
    int lbb = lb*BIT_SIZE, ubb = (lb+1)*BIT_SIZE;
    while(ubb-lbb>1){
      int m = (lbb+ubb)/2;
      if(rank(m)<=x) lbb = m;
      else ubb = m;
    }
    return lbb;
  }

  int select0(int x){
    if((block_size-1) * BIT_SIZE - rank((block_size-1) * BIT_SIZE) <= x) return -1;
    int lb = 0, ub = block_size-1;
    while(ub-lb>1){
      int m = (lb+ub)/2;
      if(s0_rank[m]<=x) lb = m;
      else ub = m;
    }
    int lbb = lb*BIT_SIZE, ubb = (lb+1)*BIT_SIZE;
    while(ubb-lbb>1){
      int m = (lbb+ubb)/2;
      if(m-rank(m)<=x) lbb = m;
      else ubb = m;
    }
    return lbb;
  }
};

class EliasFano {
  uint32_t n, u, lgn, lgun;
  FID low, high;

  //ceil(lg(x))
  uint32_t ceillg(double x){
    return ceil(log(x)/log(2));
  }

  uint32_t low_get(size_t i){
    uint32_t ret = 0;
    for(uint32_t p=0; p<lgun; p++){
      ret |= (low.access(lgun*i+p)) << (lgun-1-p);
    }
    return ret;
  }
  
public:
  EliasFano():low(0),high(0){}

  size_t size(){ return n; }
  
  void build(const std::vector<uint32_t>& v){
    if(v.size() == 0) return;
    n = v.size();
    u = v.back();
    lgun = ceillg(u/(double)n);
    lgn = ceillg(u+1) - lgun;

    {//low bits
      low.init(lgun * n);
      int pos = lgun * n - 1;
      for(auto it=v.rbegin(); it!=v.rend(); ++it){
        for(uint32_t b=0; b<lgun; b++){
          if(((*it)>>b)&1){
            low.set(pos);
          }
          pos--;
        }
      }
      low.finalize();
    }
    {//high bits
      high.init((1<<lgn) + n);
      std::vector<uint32_t> cnt(1<<lgn, 0);
      for(uint32_t x : v){
        cnt[(x>>lgun)]++;
      }
      int pos = 0;
      for(uint32_t x : cnt){
        for(uint32_t i=0; i<x; i++){
          high.set(pos);
          pos++;
        }
        pos++;
      }
      high.finalize();
    }
  }
  
  uint32_t access(int i){
    if(n == 0) return 0;
    uint32_t ret = 0;
    ret |= (high.select(i) - i) << lgun;
    ret |= low_get(i);
    return ret;
  }

  uint32_t successor(uint32_t x){
    if(n == 0) return 0;
    if(u < x) return 0;
    uint32_t h = (x>>lgun);
    uint32_t l = x&((1<<lgun)-1);
    uint32_t p1 = (h==0)?0:(high.select0(h-1)-h+1);
    uint32_t p2 = high.select0(h)-h;

    if(x <= access(p1)) return access(p1);
    
    while(p2-p1>1){
      uint32_t m = (p2+p1)/2;
      if(low_get(m) < l) p1 = m;
      else p2 = m;
    }

    return access(p2);
  }

  void dump(){
    std::cout << "n = " << n << ", ";
    std::cout << "u = " << u << ", ";
    std::cout << "ceil(lg(n)) = " << lgn << ", ";
    std::cout << "ceil(lg(u/n)) = " << lgun << std::endl;

    std::cout << "L = ";
    for(uint32_t i=0; i<lgun*n; i++){
      if(low.access(i)) std::cout << 1;
      else std::cout << 0;
    }
    std::cout << std::endl;
    std::cout << "H = ";
    for(uint32_t i=0; i<(1<<lgn)+n; i++){
      if(high.access(i)) std::cout << 1;
      else std::cout << 0;
    }
    std::cout << std::endl;
  }
};

int main(){
  //正の整数の単調増加列(1以上の整数, 同じ整数はなし, ソート済み)
  std::vector<uint32_t> v{3,4,7,13,14,15,21,43};
    
  EliasFano ef;
  ef.build(v);

  ef.dump();

  std::cout << "[access()]" << std::endl;
  for(size_t i=0; i<ef.size(); i++){
    std::cout << i << "\t" << v[i] << "\t" << ef.access(i) << std::endl;
  }

  std::cout << "[successor()]" << std::endl;
  for(uint32_t x=0; x<50; x++){
    std::cout << x << "\t" << ef.successor(x) << std::endl;
  }
  
  return 0;
}

結果

プレゼンの例と同じになっていることを確認。

n = 8, u = 43, ceil(lg(n)) = 3, ceil(lg(u/n)) = 3
L = 011100111101110111101011
H = 1110111010001000
[access()]
0       3       3
1       4       4
2       7       7
3       13      13
4       14      14
5       15      15
6       21      21
7       43      43
[successor()]
0       3
1       3
2       3
3       3
4       4
5       7
6       7
7       7
8       13
9       13
10      13
11      13
12      13
13      13
14      14
15      15
16      21
17      21
18      21
19      21
20      21
21      21
22      43
23      43
24      43
25      43
26      43
27      43
28      43
29      43
30      43
31      43
32      43
33      43
34      43
35      43
36      43
37      43
38      43
39      43
40      43
41      43
42      43
43      43
44      0
45      0
46      0
47      0
48      0
49      0

参考

2017-06-18

XBWを試す

はじめに

XBWをWaveletMatrixを使って、試しに実装してみた。

XBWとは

コード

XBWでの表のソートはそのまま文字列同士のソートをしている。
チェックは、コード中にあるように、適当にkey文字列とkey文字列じゃないのを生成してtrieと結果が一緒になるかだけ。
WaveletMatrixの方で理解のためにいくつか関数を書いているけど、rank()ぐらいしか使っていないので、それ以外はあまりVerifyできていない。
(g++はバージョン「5.4.0」、オプション「-std=gnu++1y -O2」で実行してる)

#include <vector>
#include <map>
#include <cstdint>
#include <algorithm>
#include <iostream>
#include <queue>
#include <random>

//完備辞書(Fully Indexable Dictionary)
// 【使いまわすときの注意】
// - 全部set()したら、最後にfinalize()を呼ぶこと
// - select(x)の実装で、xが要素数よりも多い場合-1を返す実装にしている
// - 32bit/64bit書き換えは、BIT_SIZE,BLOCK_TYPE,popcount,整数リテラルのサフィックスなどを書き換えること
class FID {
  static const int BIT_SIZE = 64;
  using BLOCK_TYPE = uint64_t;
  int size;
  int block_size;
  std::vector<BLOCK_TYPE> blocks;
  std::vector<int> s_rank;
public:
  //for BIT_SIZE == 32
  /*
  BLOCK_TYPE popcount(BLOCK_TYPE x){
    x = ((x & 0xaaaaaaaa) >> 1) + (x & 0x55555555);
    x = ((x & 0xcccccccc) >> 2) + (x & 0x33333333);
    x = ((x & 0xf0f0f0f0) >> 4) + (x & 0x0f0f0f0f);
    x = ((x & 0xff00ff00) >> 8) + (x & 0x00ff00ff);
    x = ((x & 0xffff0000) >> 16) + (x & 0x0000ffff);
    return x;
  }
   */
  //__builtin_popcount()
  
  //for BIT_SIZE == 64
  BLOCK_TYPE popcount(BLOCK_TYPE x){
    x = ((x & 0xaaaaaaaaaaaaaaaaULL) >> 1) + (x & 0x5555555555555555ULL);
    x = ((x & 0xccccccccccccccccULL) >> 2) + (x & 0x3333333333333333ULL);
    x = ((x & 0xf0f0f0f0f0f0f0f0ULL) >> 4) + (x & 0x0f0f0f0f0f0f0f0fULL);
    x = ((x & 0xff00ff00ff00ff00ULL) >> 8) + (x & 0x00ff00ff00ff00ffULL);
    x = ((x & 0xffff0000ffff0000ULL) >> 16) + (x & 0x0000ffff0000ffffULL);
    x = ((x & 0xffffffff00000000ULL) >> 32) + (x & 0x00000000ffffffffULL);
    return x;
  }
  //__builtin_popcountll()
public:
  FID(int size):
  size(size),
  block_size(((size + BIT_SIZE - 1) / BIT_SIZE) + 1),
  blocks(block_size, 0),
  s_rank(block_size, 0){}
  
  void set(int i){
    blocks[i/BIT_SIZE] |= 1ULL << (i%BIT_SIZE);
  }
  
  void finalize(){
    s_rank[0] = 0;
    for(int i=1; i<block_size; i++){
      s_rank[i] = s_rank[i-1] + popcount(blocks[i-1]);
    }
  }
  
  bool access(int i){
    return (blocks[i/BIT_SIZE] >> (i%BIT_SIZE)) & 1ULL;
  }

  //iより前のビットが立っている個数
  int rank(int i){
    BLOCK_TYPE mask = (1ULL << (i%BIT_SIZE)) - 1;
    return s_rank[i/BIT_SIZE] + popcount(mask & blocks[i/BIT_SIZE]);
  }

  //x番目にビットが立っている位置
  int select(int x){
    if(rank((block_size-1) * BIT_SIZE) <= x) return -1; //注意
    int lb = 0, ub = block_size-1;
    while(ub-lb>1){
      int m = (lb+ub)/2;
      if(s_rank[m]<=x) lb = m;
      else ub = m;
    }
    int lbb = lb*BIT_SIZE, ubb = (lb+1)*BIT_SIZE;
    while(ubb-lbb>1){
      int m = (lbb+ubb)/2;
      if(rank(m)<=x) lbb = m;
      else ubb = m;
    }
    return lbb;
  }
};

//ウェーブレット行列(Wavelet Matrix)
// 【使いまわすときの注意】
// - 全部set()したら、最後にfinalize()を呼ぶこと
// - 32bit/64bit書き換えは、BIT_SIZE,VAL_TYPEなどを書き換えること
class WaveletMatrix {
  static const int BIT_SIZE = 8;
  using VAL_TYPE = uint8_t;
  int size;
  std::vector<VAL_TYPE> v;
  std::vector<FID> matrix;
  std::vector<int> sep;

  struct mytuple {
    int b, s, e;
    mytuple(int b, int s, int e):b(b),s(s),e(e){}
    bool operator<(const mytuple& x) const {
      return e-s < x.e-x.s;
    }
  };
public:
  WaveletMatrix(int size):
  size(size),
  v(size, 0),
  matrix(BIT_SIZE, FID(size)),
  sep(BIT_SIZE, 0){}

  void set(int i, VAL_TYPE val){
    v[i] = val;
  }

  void finalize(){
    std::vector<VAL_TYPE> w(v.size(), 0);
    for(int b=BIT_SIZE-1; b>=0; b--){
      for(int i=0; i<size; i++){
        if((v[i] >> b) & 1ULL) matrix[b].set(i);
        else sep[b]++;
      }
      int b1=0, b2=sep[b];
      for(int i=0; i<size; i++){
        if((v[i] >> b) & 1ULL) w[b2++] = v[i];
        else w[b1++] = v[i];
      }
      for(int i=0; i<size; i++){
        v[i] = w[i];
      }
      matrix[b].finalize();
    }
  }

  //元の配列のi番目の要素
  VAL_TYPE access(int i){
    VAL_TYPE ret = 0;
    for(int b=BIT_SIZE-1; b>=0; b--){
      if(matrix[b].access(i)){
        i = sep[b] + matrix[b].rank(i);
        ret = (ret << 1) + 1ULL;
      }else{
        i = i - matrix[b].rank(i);
        ret = (ret << 1);
      }
    }
    return ret;
  }

  //[0,i)の範囲にxが何個存在するか
  int rank(int i, VAL_TYPE x){
    int lb = 0, ub = i;
    for(int b=BIT_SIZE-1; b>=0; b--){
      if((x >> b) & 1ULL){
        lb = matrix[b].rank(lb);
        ub = matrix[b].rank(ub);
        lb += sep[b];
        ub += sep[b];
      }else{
        lb = lb - matrix[b].rank(lb);
        ub = ub - matrix[b].rank(ub);                
      }
    }
    return ub - lb;
  }

  //i番目(0-index)のxが出現する位置
  int select(int i, VAL_TYPE x){
    int lb = 0, ub = size;
    while(ub-lb>1){
      int m = (lb+ub)/2;
      if(rank(m, x)<=i) lb = m;
      else ub = m;
    }
    return lb;
  }

  //[s,e)の範囲を切り出してソートしたときのn番目(0-index)の要素
  VAL_TYPE quantile(int s, int e, int n){
    for(int b=BIT_SIZE-1; b>=0; b--){
      int zn = (e - s) - (matrix[b].rank(e) - matrix[b].rank(s));
      if(zn <= n){
        s = matrix[b].rank(s);
        e = matrix[b].rank(e);
        s += sep[b];
        e += sep[b];
        n = n - zn;
      }else{
        s = s - matrix[b].rank(s);
        e = e - matrix[b].rank(e);                
      }      
    }
    return v[s];
  }

  //[s,e)の範囲で出現回数が多い数値順に、その数値と出現回数のTop-K
  std::vector<std::pair<VAL_TYPE,int>> top_k(int s, int e, int k){
    std::vector<std::pair<VAL_TYPE,int>> ret;
    std::priority_queue<mytuple> que;
    que.push(mytuple(BIT_SIZE-1,s,e));
    while(!que.empty()){
      mytuple q = que.top(); que.pop();
      int b = q.b, st = q.s, en = q.e;
      if(b < 0){
        ret.push_back(std::make_pair(v[st], en-st));
        if((int)ret.size() >= k) break;
      }else{
        int os = matrix[b].rank(st) + sep[b];
        int oe = matrix[b].rank(en) + sep[b];
        int zs = st - matrix[b].rank(st);
        int ze = en - matrix[b].rank(en);
        if(ze-zs > 0) que.push(mytuple(b-1,zs,ze));
        if(oe-os > 0) que.push(mytuple(b-1,os,oe));
      }
    }
    return ret;
  }

  //[s,e)の範囲でx<=c<yを満たすような数値cの合計出現数
  int rangefreq(int s, int e, VAL_TYPE x, VAL_TYPE y){
    int ret = 0;
    std::queue<std::pair<mytuple,VAL_TYPE>> que;
    que.push(std::make_pair(mytuple(BIT_SIZE-1,s,e),0));
    while(!que.empty()){
      std::pair<mytuple,VAL_TYPE> q = que.front(); que.pop();
      int b = q.first.b, st = q.first.s, en = q.first.e;
      VAL_TYPE mn = q.second;
      VAL_TYPE mx = q.second | ((b>=0)?0:((-1ULL) >> (BIT_SIZE - 1 - b)));
      if(x <= mn && mx < y){
        ret += en-st;
      }
      else if(mx < x || y <= mn){
        continue;
      }
      else {
        if(b < 0) continue;
        int os = matrix[b].rank(st) + sep[b];
        int oe = matrix[b].rank(en) + sep[b];
        int zs = st - matrix[b].rank(st);
        int ze = en - matrix[b].rank(en);
        if(ze-zs > 0) que.push(std::make_pair(mytuple(b-1,zs,ze), q.second));
        if(oe-os > 0) que.push(std::make_pair(mytuple(b-1,os,oe), q.second | (1ULL << b)));
      }
    }
    return ret;
  }
};

//XBW
class XBW {
  using VAL_TYPE = uint8_t;
  const char LAST_CHAR = (char)(0xff);
  
  struct Trie {
    bool flg;
    std::string rpp;
    std::map<char,Trie> next;
    Trie(){ flg = false; }
    void insert(const std::string &str){
      Trie *r = this;
      for(size_t i=0; i<str.length(); i++){
        r = &(r->next[str[i]]);
      }
      r->flg = true;
    }
    bool find(const std::string &str){
      Trie *r = this;
      for(size_t i=0; i<str.length(); i++){
        if(r->next.count(str[i]) == 0) return false;
        r = &(r->next[str[i]]);
      }
      return r->flg;
    }
  };
  struct ST {
    std::string children;
    std::string rpp;
    ST(std::string children, std::string rpp):children(children),rpp(rpp){}
    bool operator<(const ST& x) const { return rpp < x.rpp; }
  };

  Trie root;
  int xbw_size;
  std::string xbw_str;

  WaveletMatrix wm;
  FID fid;
  std::map<char,int> C;

  void build(std::vector<ST>& v){
    //XBWのサイズと文字の出現数のカウント
    std::map<char,int> cnt;
    xbw_size = 0;
    for(size_t i=0; i<v.size(); i++){
      xbw_size += v[i].children.length();
      for(size_t j=0; j<v[i].children.length(); j++){
        cnt[v[i].children[j]]++;
      }
    }

    //構築
    wm = WaveletMatrix(xbw_size);
    fid = FID(xbw_size);
    int idx = 0;
    for(size_t i=0; i<v.size(); i++){
      fid.set(idx);
      for(size_t j=0; j<v[i].children.length(); j++){
        wm.set(idx, (VAL_TYPE)(v[i].children[j]));
        idx++;
      }
    }
    wm.finalize();
    fid.finalize();

    C[(char)(0)] = 1;
    for(int i=1; i<256; i++){
      C[(char)(i)] = C[(char)(i-1)] + cnt[(char)(i-1)];
    }

    //trieの削除
    //root.next.clear();
  }

  int rank(int i, VAL_TYPE x){
    int pos = fid.select(i);
    return wm.rank(((pos<0)?xbw_size:pos),x);
  }  
public:
  XBW():wm(1),fid(1){}

  void add(const std::string& key){
    root.insert(key);
  }

  void finalize(){
    //chilren, reverse prefix pathの表の作成
    std::vector<ST> v;
    std::queue<Trie*> que;
    que.push(&root);
    while(!que.empty()){
      Trie *r = que.front(); que.pop();
      std::string children;
      std::string rpp = r->rpp;
      for(std::map<char,Trie>::iterator it=(r->next).begin(); it!=(r->next).end(); ++it){
        children += it->first;
        (it->second).rpp = rpp + it->first;
        que.push(&(it->second));
      }
      if(r->flg){
        children += LAST_CHAR;
      }
      std::reverse(rpp.begin(), rpp.end());
      v.push_back(ST(children, rpp));
    }

    std::sort(v.begin(), v.end());

    //XBW文字列の作成
    for(size_t i=0; i<v.size(); i++){
      if(i>0) xbw_str += "|";
      for(size_t j=0; j<v[i].children.length(); j++){
        if(v[i].children[j] == LAST_CHAR) xbw_str += "__LAST__"; //表示の都合上
        else xbw_str += v[i].children[j];
      }
    }

    build(v);
  }

  std::string get_xbw_string(){
    return xbw_str;
  }
  
  bool trie_find(const std::string& key){
    return root.find(key);
  }
  
  bool find(const std::string& key){
    int r = 0;
    for(size_t i=0; i<key.length(); i++){
      if(rank(r+1,(VAL_TYPE)(key[i])) - rank(r,(VAL_TYPE)(key[i])) == 0) return false;
      r = C[key[i]] + rank(r,(VAL_TYPE)(key[i]));
    }
    if(rank(r+1,(VAL_TYPE)(LAST_CHAR)) - rank(r,(VAL_TYPE)(LAST_CHAR)) == 0) return false;
    return true;
  }
};

int main(){
  std::mt19937 rnd{ std::random_device()() };
  std::map<std::string,int> keys, no_keys;
  int max_size = 500000; //keyとno_keysの要素数
  int max_len = 40; //文字列の長さの最大値
  
  int turn = 0;
  while(keys.size() < max_size || no_keys.size() < max_size){
    //generate key string

    int len = rnd() % max_len + 1;
    std::string key = "";
    for(int j=0; j<len; j++){
      key += (char)(' ' + rnd()%95);
    }

    if(keys.count(key) != 0 || no_keys.count(key) != 0) continue;
    
    if(turn == 0 && keys.size() < max_size){
      keys[key] = 1;
      turn = 1 - turn;
    }
    else if(turn == 1 && no_keys.size() < max_size){
      no_keys[key] = 1;
      turn = 1 - turn;
    }
  }
  std::cout << "key generated..." << std::endl;
  
  XBW xbw;
  for(const auto& x : keys){
    xbw.add(x.first);
  }
  xbw.finalize();
  std::cout << "XBW built..." << std::endl;
  
  //std::cout << "XBW = " << xbw.get_xbw_string() << std::endl;
  
  bool error = false;
  for(const auto& x : keys){
    if(xbw.trie_find(x.first) != xbw.find(x.first)){
      std::cout << "error : " << x.first << std::endl;
      error = true;
    }
  }
  for(const auto& x : no_keys){
    if(xbw.trie_find(x.first) != xbw.find(x.first)){
      std::cout << "error : " << x.first << std::endl;
      error = true;
    }
  }
  if(!error) std::cout << "no error" << std::endl;
  
  return 0;
}


確認のために、解説ページのTrieからXBWを出力してみる。
main()の内容を以下のように変更すると確認できる。

int main(){
  std::vector<std::string> v{"to","tea","ten","i","in","inn","we"};
  XBW xbw;
  for(const auto& x : v){
    xbw.add(x);
  }
  xbw.finalize();
  std::cout << xbw.get_xbw_string() << std::endl;
  return 0;
}

結果。

itw|__LAST__|an|__LAST__|n__LAST__|__LAST__|n__LAST__|__LAST__|__LAST__|eo|e

参考

2017-02-23

Kneser-Ney smoothingで遊ぶ

はじめに

100-nlp-papersで紹介されてた一番最初の論文に、クナイザーネイスムージングのスッキリな実装が載っていたので書いてみる。

Joshua Goodman: A bit of progress in language modeling, MSR Technical Report, 2001.

Kneser-Ney smoothingとは

  • 言語モデルのスムージング(平滑化)手法一種で、高い言語モデル性能を実現している
    • ニューラル言語モデルでも比較によく使われる
  • アイデアとしては「(n-1)-gramが出現した文脈での異なり数」を使うこと
    • 頻度を使うと、高頻度なn-gramではその(n-1)-gramも多くなってしまうため、特定文脈でしかでないような(n-1)-gramに対しても高い確率値ことになっていて、歪んだ結果になってしまう
      • 「San Francisco」の頻度が多いと「Francisco」の頻度も高くなるが、P(Francisco|on)とかはあまり出現しないので低くなってほしいところ、「Franciscoの頻度」を使って確率値を推定すると高くなってしまう
    • 頻度ではなく、異なり数で(n-1)-gramの確率を推定することで、補正する
  • 上のレポートでは、Interpolatedな補間方法での実装例を紹介している
    • back-offな方法も考えらえる
    • discount(割引値)パラメータをn-gramごとに分けた方法は「modified Kneser-Ney smoothing」と呼ばれている

UNKの扱い

  • レポートのAppendixのFigure17と18はそのままだと学習データに出現しない単語UNKが出てくると、unigramが0なので、確率も0になってしまう
  • レポートの8ページ目では、一様分布1/|V|(Vは語彙集合)を使ってスムージングしてこれを避けると紹介されている
  • λはどうするの?というのは、以下のページで議論されているように、λ(ε)と考えると、「λ==discountパラメータ」としてもよいかなと思うので、コードではそのようにした

コード

UNKのためにちょっと修正した。あんまりちゃんとチェックできていないけど、それっぽい数値を返しているのでおそらく大丈夫。

#include <iostream>
#include <fstream>
#include <vector>
#include <string>
#include <cmath>
#include <unordered_map>

class InterpolatedKneserNeyTrigram {
  const std::string delim = "\t";
  double discount; //(0,1)の範囲で指定する
  std::unordered_map<std::string,int> TD, TN, TZ, BD, BN, BZ, UN;
  int UD;
public:
  InterpolatedKneserNeyTrigram():discount(0.1),UD(0){}
  InterpolatedKneserNeyTrigram(double d):discount(d),UD(0){}

  //ファイルに書き出し
  void save(const std::string& filename){
    std::ofstream fout(filename);
    if(!fout){ std::cerr << "cannot open file" << std::endl; return; }
    fout << discount << std::endl;
    fout << TD.size() << std::endl;
    std::unordered_map<std::string,int>::iterator it;
    for(it=TD.begin(); it!=TD.end(); ++it) fout << it->first << std::endl << it->second << std::endl;
    fout << TN.size() << std::endl;
    for(it=TN.begin(); it!=TN.end(); ++it) fout << it->first << std::endl << it->second << std::endl;
    fout << TZ.size() << std::endl;
    for(it=TZ.begin(); it!=TZ.end(); ++it) fout << it->first << std::endl << it->second << std::endl;
    fout << BD.size() << std::endl;
    for(it=BD.begin(); it!=BD.end(); ++it) fout << it->first << std::endl << it->second << std::endl;
    fout << BN.size() << std::endl;
    for(it=BN.begin(); it!=BN.end(); ++it) fout << it->first << std::endl << it->second << std::endl;
    fout << BZ.size() << std::endl;
    for(it=BZ.begin(); it!=BZ.end(); ++it) fout << it->first << std::endl << it->second << std::endl;
    fout << UN.size() << std::endl;
    for(it=UN.begin(); it!=UN.end(); ++it) fout << it->first << std::endl << it->second << std::endl;
    fout << UD << std::endl;    
    fout.close();
  }
  //ファイルから読み込み
  void load(const std::string& filename){
    std::ifstream fin(filename);
    if(!fin){ std::cerr << "cannot open file" << std::endl; return; }
    fin >> discount;
    int td, tn, tz, bd, bn, bz, un;
    std::string s;
    int c;
    fin >> td;
    for(int i=0; i<td; i++){ getline(fin,s); getline(fin, s); fin >> c; TD[s] = c; }
    fin >> tn;
    for(int i=0; i<tn; i++){ getline(fin,s); getline(fin, s); fin >> c; TN[s] = c; }
    fin >> tz;
    for(int i=0; i<tz; i++){ getline(fin,s); getline(fin, s); fin >> c; TZ[s] = c; }
    fin >> bd;
    for(int i=0; i<bd; i++){ getline(fin,s); getline(fin, s); fin >> c; BD[s] = c; }
    fin >> bn;
    for(int i=0; i<bn; i++){ getline(fin,s); getline(fin, s); fin >> c; BN[s] = c; }
    fin >> bz;
    for(int i=0; i<bz; i++){ getline(fin,s); getline(fin, s); fin >> c; BZ[s] = c; }
    fin >> un;
    for(int i=0; i<un; i++){ getline(fin,s); getline(fin, s); fin >> c; UN[s] = c; }
    fin >> UD;
    fin.close();
  }

  void set_discount(double d){ discount = d; }
  double get_discount() const { return discount; }
  
  void add_sentence(const std::vector<std::string>& sentence){
    std::string w2 = "", w1 = "";

    for(size_t i=0; i<sentence.size(); i++){
      std::string w0 = sentence[i];
      TD[ w2 + delim + w1 ]++;
      if(TN[ w2 + delim + w1 + delim + w0 ]++ == 0){
        TZ[ w2 + delim + w1 ]++;

        BD[ w1 ]++;
        if(BN[ w1 + delim + w0 ]++ == 0){
          BZ[ w1 ]++;

          UD++;
          UN[ w0 ]++;
        }
      }
      w2 = w1;
      w1 = w0;
    }
  }

  double prob(const std::vector<std::string>& sentence){
    std::string w2 = "", w1 = "";
    double ret = 0;

    for(size_t i=0; i<sentence.size(); i++){
      std::string w0 = sentence[i];
      double prob = 0;

      //そのままだとUN[w0]==0のときprob==0になるため、1/|V|を使うように変更
      double uniform = 1.0 / UN.size();
      double unigram = 0.0;
      if(UN.count( w0 ) > 0){
        unigram = (UN[ w0 ] - discount) / (double)UD;
      }
      unigram += discount * uniform;
      if(BD.count( w1 ) > 0){
        double bigram = 0;
        if(BN.count( w1 + delim + w0 ) > 0){
          bigram = (BN[ w1 + delim + w0 ] - discount) / BD[ w1 ];
        }
        bigram += BZ[ w1 ] * discount / BD[ w1 ] * unigram;

        if(TD.count( w2 + delim + w1 ) > 0){
          double trigram = 0;
          if(TN.count( w2 + delim + w1 + delim + w0 ) > 0){
            trigram = (TN[ w2 + delim + w1 + delim + w0 ] - discount) / TD[ w2 + delim + w1 ];
          }
          trigram += TZ[ w2 + delim + w1 ] * discount / TD[ w2 + delim + w1 ] * bigram;
          prob = trigram;
        }else{
          prob = bigram;
        }
      }else{
        prob = unigram;
      }
      ret += log(prob);      
      w2 = w1;
      w1 = w0;
    }
    return ret;
  }
};


int main(){
  InterpolatedKneserNeyTrigram lm;  
  std::vector< std::vector<std::string> > train_v, valid_v;
  
  {//ファイルの読み込み
    std::ifstream trainfs("train.txt");
    std::ifstream validfs("valid.txt");
    std::string w;
    std::vector<std::string> tmp;
    while(trainfs >> w){
      tmp.push_back(w);
      if(w == "EOS"){
        train_v.push_back(tmp);
        tmp.clear();
      }
    }
    
    tmp.clear();
    while(validfs >> w){
      tmp.push_back(w);
      if(w == "EOS"){
        valid_v.push_back(tmp);
        tmp.clear();
      }
    }
  }
  
  {//学習用の文を全部入れる
    for(size_t i=0; i<train_v.size(); i++){
      lm.add_sentence(train_v[i]);
    }
  }
  
  {//よさそうなdを探す
    double best = log(0), best_d = 0;
    double prec = 0.001;
    for(double d=prec; d<1; d+=prec){
      lm.set_discount(d);
      double logq = 0.0;
      for(size_t i=0; i<valid_v.size(); i++){
        logq += lm.prob(valid_v[i]);
      }
      std::cerr << d << "\t" << logq << std::endl;
      if(best < logq){
        best = logq;
        best_d = d;
      }
    }
    lm.set_discount(best_d);
    std::cerr << "best: " << best << " (d = " << best_d << ")" << std::endl;
  }

  lm.save("lm.data");

  return 0;
}

実験

データの準備

「坊ちゃん」の言語モデルを作ってみる。
青空文庫から「坊ちゃん」のテキストを取得し、「≪≫」などで囲まれた部分を削除したものを用意。
全部で470行で、10行ごとをdiscount係数確認用にする。

さらにそれを、mecab+ipadicで1行1単語にした以下のようなテキストを準備する。
(学習用train.txt 424行分、確認用valid.txt 46行分)

親譲り
の
無鉄砲
で
小
供
の
時
から
損
...
と
答え
た
。
EOS
親類
の
もの
...

(1行分の終わりには「EOS」を含む)

最適なパラメータの探索

確認用のデータで一番尤度が高くなるパラメータを採用する。

f:id:jetbead:20170222021821p:image
最適なのは、discount=0.897のとき、対数尤度が-29116ぐらい。


事例

いくつかの例で確率を見てみる。

s=「親譲り の 無鉄砲 だ EOS」 → log(P(s)) = -23.7238
s=「親譲り の ブレイブハート だ EOS」 → log(P(s)) = -33.8758
s=「吾輩 は 猫 だ EOS」 → log(P(s)) = -36.8097
s=「無鉄砲 な フレンズ だ EOS」 → log(P(s)) = -38.3098


参考