ラティスのNbestを求める

はじめに

形態素解析とかの解析時に使うラティス(形態素候補をグラフにしたもの)のうち、1番ベストな解だけが欲しい場合が多いが、2番目以降の解も欲しい場合がある。
N番目までの解を効率よく求める方法があり、使いたいケースが出てきたのに書いてみる。

Nbestとは

  • 解析したときに、スコアが良い順に、第一候補(1best)の解だけでなく、第二候補、第三候補・・・の解が考えられるとき、第N候補までのことをNbestという
  • Nbest解
前向きDP後ろ向きA*アルゴリズム
  • http://ci.nii.ac.jp/naid/110002725063
  • ラティスにおいて、効率よくNbestを求める方法が提案されている
  • 最初に、1bestを求める方法と同じようにスタートノードからあるノードまでの最小コストhをすべてのノードについて求めておく
    • 1bestの時は、ゴールノードからスタートノードに向かって、最小コストが選ばれるノードを逆に辿る事で求められた
    • Nbestの時は、「A*アルゴリズム」の考え方を利用して求める
  • ゴールノードからスタートノードに向かって、今考えているノードから(最適とは限らない、これまで取ったパスによって変わる)ゴールまでのコストgを求めていく
  • 前側のコスト合計hと後ろ側のコスト合計gを使うと、スタートから今考えているノードを通って(今まで通ってきた最適とは限らない)ゴールまでのパスの合計値fが求められる
  • fを使って小さい順に考えるノードを増やしながら、スタートノードにたどり着くパスを見つけると、Nbestの解を順次求められる

書いてみたけどわかりにくい。
徳永さんの「日本語入力を支える技術」の中で丁寧に紹介されている。

コード

スタートノード(BOS)とゴールノード(EOS)が1つずつあり、必ずスタートノードからゴールノードにパスを持つようなDAGであれば大丈夫だと思うので、それで書いてみる。
(雑なので、パスのノードを別に持ったり、トポロジカル順序を求めてる無駄がある・・・)

#include <iostream>
#include <vector>
#include <queue>
#include <algorithm>
#include <map>
#include <set>

#define INF 99999999

struct Node {
  int id;
  int w;
  //id,w
  std::map<int,int> forward;
  std::map<int,int> backward;

  int h;
};

struct Path {
  int id;
  int g;
  int f;
  std::vector<int> rev; //現在idからゴールノードまでの逆順パス(注意:無駄にメモリを使っている)
};
bool operator>(const Path& a, const Path& b){
  return a.f > b.f;
}

//スタートとゴールとなるノードがそれぞれ1つずつあり、ゴールノード以外で終わるようなノードが無いDAG
class stDAG {
  int n; //ノード数
  int s, t; //スタートノード番号、ゴールノード番号
  std::vector<Node> nodes; //ノード情報

  //トポロジカル順序を求めるため
  bool topological_sort(std::vector<int>& order){
    std::vector<int> color(n);
    for(int i=0; i<n; i++){
      if(!color[i] && !topological_visit(i, order, color)) return false;
    }
    std::reverse(order.begin(), order.end());
    return true;
  }
  bool topological_visit(int v, std::vector<int>& order, std::vector<int>& color){
    color[v] = 1;
    for(std::map<int,int>::iterator itr = nodes[v].forward.begin();
        itr != nodes[v].forward.end();
        ++itr){
      if(color[itr->first] == 2) continue;
      if(color[itr->first] == 1) return false;
      if(!topological_visit(itr->first, order, color)) return false;
    }
    order.push_back(v);
    color[v] = 2;
    return true;
  }

public:
  //ノード数、スタートノードID、ゴールノードID
  stDAG(int n, int s, int t):
    n(n),
    nodes(n),
    s(s),
    t(t)
  {
    for(int i=0; i<n; i++){
      nodes[i].id = i;
      nodes[i].h = INF;
    }
  }

  //ノード情報の追加
  void add_node(int id, int w){
    nodes[id].w = w;
  }
  //エッジ情報の追加
  void add_edge(int id, int next_id, int w){
    nodes[id].forward[next_id] = w;
    nodes[next_id].backward[id] = w;
  }
  
  //前向きDP
  void forward_dp(){
    std::vector<int> order;
    if(!topological_sort(order)){
      std::cerr << "topological sort error" << std::endl;
      return;
    }
    if(nodes[order[0]].id != s || nodes[order[n-1]].id != t){
      std::cerr << "s,t id error" << std::endl;
      return;
    }

    nodes[order[0]].h = 0;
    for(int i=0; i<n; i++){
      int id = order[i];
      //std::cout << id << "\t" << nodes[id].h << std::endl;
      for(std::map<int,int>::iterator itr = nodes[id].forward.begin();
          itr != nodes[id].forward.end();
          ++itr){
        int next_id = itr->first;
        int next_w = nodes[next_id].w;
        int edge_w = itr->second;
        nodes[next_id].h = std::min(nodes[next_id].h, nodes[id].h + edge_w + next_w);
      }      
    }
  }

  //後向きA*
  void backward_astar(int N){
    int no = 1;
    std::priority_queue<Path, std::vector<Path>, std::greater<Path> > que;

    Path path;
    path.id = t;
    path.g = 0;
    path.f = nodes[t].h;
    path.rev.push_back(t);
    que.push(path);

    while(!que.empty()){
      path = que.top(); que.pop();
      int id = path.id;
      int bestf = path.f;

      if(id == s){
        std::cout << no << "best: ";
        for(int i=path.rev.size()-1; i>=0; i--){
          std::cout << path.rev[i] << " ";
        }
        std::cout << " => " << path.f + nodes[s].w << std::endl;
        if(no == N) return;

        no++;
      }

      for(std::map<int,int>::iterator itr = nodes[id].backward.begin();
          itr != nodes[id].backward.end();
          ++itr){
        int gg = nodes[id].w + itr->second + path.g;
        int ff = gg + nodes[itr->first].h;

        Path new_path;
        new_path.id = itr->first;
        new_path.g = gg;
        new_path.f = ff;
        new_path.rev = path.rev;
        new_path.rev.push_back(itr->first);
        que.push(new_path);
      }
    }
  }
};


int main(){
  //graph1
  stDAG dag(6, 0, 5);
  dag.add_node(0, 0);
  dag.add_node(1, 1);
  dag.add_node(2, 2);
  dag.add_node(3, 3);
  dag.add_node(4, 4);
  dag.add_node(5, 5);
  dag.add_edge(0, 1, 1);
  dag.add_edge(0, 2, 2);
  dag.add_edge(1, 3, 2);
  dag.add_edge(1, 4, 1);
  dag.add_edge(2, 3, 1);
  dag.add_edge(2, 4, 2);
  dag.add_edge(3, 5, 2);
  dag.add_edge(4, 5, 1);
  dag.forward_dp();
  dag.backward_astar(4);

  //graph2
  /*
  stDAG dag(7, 0, 6);
  dag.add_node(0, 3);
  dag.add_node(1, 1);
  dag.add_node(2, 2);
  dag.add_node(3, 1);
  dag.add_node(4, 2);
  dag.add_node(5, 3);
  dag.add_node(6, 2);
  dag.add_edge(0, 1, 2);
  dag.add_edge(0, 2, 1);
  dag.add_edge(1, 3, 3);
  dag.add_edge(2, 3, 1);
  dag.add_edge(3, 4, 1);
  dag.add_edge(3, 5, 2);
  dag.add_edge(4, 6, 4);
  dag.add_edge(5, 6, 1);
  dag.forward_dp();
  dag.backward_astar(4);
  */

  return 0;
}

結果

グラフ1

$ ./a.out
1best: 0 1 4 5  => 13
2best: 0 1 3 5  => 14
3best: 0 2 3 5  => 15
4best: 0 2 4 5  => 16
グラフ2

$ ./a.out
1best: 0 2 3 5 6  => 16
2best: 0 2 3 4 6  => 17
3best: 0 1 3 5 6  => 18
4best: 0 1 3 4 6  => 19

とりあえず合ってそうなので、大丈夫そう。