ラティスのNbestを求める
はじめに
形態素解析とかの解析時に使うラティス(形態素候補をグラフにしたもの)のうち、1番ベストな解だけが欲しい場合が多いが、2番目以降の解も欲しい場合がある。
N番目までの解を効率よく求める方法があり、使いたいケースが出てきたのに書いてみる。
Nbestとは
- 解析したときに、スコアが良い順に、第一候補(1best)の解だけでなく、第二候補、第三候補・・・の解が考えられるとき、第N候補までのことをNbestという
- Nbest解
前向きDP後ろ向きA*アルゴリズム
- http://ci.nii.ac.jp/naid/110002725063
- ラティスにおいて、効率よくNbestを求める方法が提案されている
- 最初に、1bestを求める方法と同じようにスタートノードからあるノードまでの最小コストhをすべてのノードについて求めておく
- ゴールノードからスタートノードに向かって、今考えているノードから(最適とは限らない、これまで取ったパスによって変わる)ゴールまでのコストgを求めていく
- 前側のコスト合計hと後ろ側のコスト合計gを使うと、スタートから今考えているノードを通って(今まで通ってきた最適とは限らない)ゴールまでのパスの合計値fが求められる
- fを使って小さい順に考えるノードを増やしながら、スタートノードにたどり着くパスを見つけると、Nbestの解を順次求められる
書いてみたけどわかりにくい。
徳永さんの「日本語入力を支える技術」の中で丁寧に紹介されている。
コード
スタートノード(BOS)とゴールノード(EOS)が1つずつあり、必ずスタートノードからゴールノードにパスを持つようなDAGであれば大丈夫だと思うので、それで書いてみる。
(雑なので、パスのノードを別に持ったり、トポロジカル順序を求めてる無駄がある・・・)
#include <iostream> #include <vector> #include <queue> #include <algorithm> #include <map> #include <set> #define INF 99999999 struct Node { int id; int w; //id,w std::map<int,int> forward; std::map<int,int> backward; int h; }; struct Path { int id; int g; int f; std::vector<int> rev; //現在idからゴールノードまでの逆順パス(注意:無駄にメモリを使っている) }; bool operator>(const Path& a, const Path& b){ return a.f > b.f; } //スタートとゴールとなるノードがそれぞれ1つずつあり、ゴールノード以外で終わるようなノードが無いDAG class stDAG { int n; //ノード数 int s, t; //スタートノード番号、ゴールノード番号 std::vector<Node> nodes; //ノード情報 //トポロジカル順序を求めるため bool topological_sort(std::vector<int>& order){ std::vector<int> color(n); for(int i=0; i<n; i++){ if(!color[i] && !topological_visit(i, order, color)) return false; } std::reverse(order.begin(), order.end()); return true; } bool topological_visit(int v, std::vector<int>& order, std::vector<int>& color){ color[v] = 1; for(std::map<int,int>::iterator itr = nodes[v].forward.begin(); itr != nodes[v].forward.end(); ++itr){ if(color[itr->first] == 2) continue; if(color[itr->first] == 1) return false; if(!topological_visit(itr->first, order, color)) return false; } order.push_back(v); color[v] = 2; return true; } public: //ノード数、スタートノードID、ゴールノードID stDAG(int n, int s, int t): n(n), nodes(n), s(s), t(t) { for(int i=0; i<n; i++){ nodes[i].id = i; nodes[i].h = INF; } } //ノード情報の追加 void add_node(int id, int w){ nodes[id].w = w; } //エッジ情報の追加 void add_edge(int id, int next_id, int w){ nodes[id].forward[next_id] = w; nodes[next_id].backward[id] = w; } //前向きDP void forward_dp(){ std::vector<int> order; if(!topological_sort(order)){ std::cerr << "topological sort error" << std::endl; return; } if(nodes[order[0]].id != s || nodes[order[n-1]].id != t){ std::cerr << "s,t id error" << std::endl; return; } nodes[order[0]].h = 0; for(int i=0; i<n; i++){ int id = order[i]; //std::cout << id << "\t" << nodes[id].h << std::endl; for(std::map<int,int>::iterator itr = nodes[id].forward.begin(); itr != nodes[id].forward.end(); ++itr){ int next_id = itr->first; int next_w = nodes[next_id].w; int edge_w = itr->second; nodes[next_id].h = std::min(nodes[next_id].h, nodes[id].h + edge_w + next_w); } } } //後向きA* void backward_astar(int N){ int no = 1; std::priority_queue<Path, std::vector<Path>, std::greater<Path> > que; Path path; path.id = t; path.g = 0; path.f = nodes[t].h; path.rev.push_back(t); que.push(path); while(!que.empty()){ path = que.top(); que.pop(); int id = path.id; int bestf = path.f; if(id == s){ std::cout << no << "best: "; for(int i=path.rev.size()-1; i>=0; i--){ std::cout << path.rev[i] << " "; } std::cout << " => " << path.f + nodes[s].w << std::endl; if(no == N) return; no++; } for(std::map<int,int>::iterator itr = nodes[id].backward.begin(); itr != nodes[id].backward.end(); ++itr){ int gg = nodes[id].w + itr->second + path.g; int ff = gg + nodes[itr->first].h; Path new_path; new_path.id = itr->first; new_path.g = gg; new_path.f = ff; new_path.rev = path.rev; new_path.rev.push_back(itr->first); que.push(new_path); } } } }; int main(){ //graph1 stDAG dag(6, 0, 5); dag.add_node(0, 0); dag.add_node(1, 1); dag.add_node(2, 2); dag.add_node(3, 3); dag.add_node(4, 4); dag.add_node(5, 5); dag.add_edge(0, 1, 1); dag.add_edge(0, 2, 2); dag.add_edge(1, 3, 2); dag.add_edge(1, 4, 1); dag.add_edge(2, 3, 1); dag.add_edge(2, 4, 2); dag.add_edge(3, 5, 2); dag.add_edge(4, 5, 1); dag.forward_dp(); dag.backward_astar(4); //graph2 /* stDAG dag(7, 0, 6); dag.add_node(0, 3); dag.add_node(1, 1); dag.add_node(2, 2); dag.add_node(3, 1); dag.add_node(4, 2); dag.add_node(5, 3); dag.add_node(6, 2); dag.add_edge(0, 1, 2); dag.add_edge(0, 2, 1); dag.add_edge(1, 3, 3); dag.add_edge(2, 3, 1); dag.add_edge(3, 4, 1); dag.add_edge(3, 5, 2); dag.add_edge(4, 6, 4); dag.add_edge(5, 6, 1); dag.forward_dp(); dag.backward_astar(4); */ return 0; }