LDA の Collapsed Gibbs サンプリングの全条件付分布を導出してみる

にも出てくるこの式

  • P(z_{mn}=k|\boldsymbol{z}^{-mn}, \boldsymbol{w})\;\propto\;(n_{mk}^{-mn}+\alpha)\cdot\frac{n_{tk}^{-mn}+\beta}{n_{k}^{-mn}+V\beta}

を導出してみる。


この式は LDA の Collapsed Gibbs sampling で使う全条件付分布(full conditional)。
もし普通のギブスサンプリングだったら、観測されていない全ての確率変数、つまり Z だけではなくθやφについても同様に全条件付分布を構成して、そこからサンプリングを繰り返すことが必要になる。*1
そこで、θとφについては積分消去してしまうことで、Z だけをサンプリングすればよいようにしたのが Collapsed Gibbs sampling。"collapsed" は積分消去して「つぶした」ということと、素の Gibbs sampling から「崩した」ということと、両方かかっているんだろうか?


導出に必要な道具は次の2つ。

  • ガンマ関数
    • \Gamma(n+1)=n\Gamma(n)
  • 多項ベータ関数の積分
    • \int{\prod_i x_i^{a_i-1}d\boldsymbol{x}}=\frac{\prod_i\Gamma(a_i)}{\Gamma(\sum_i a_i)}, \hspace{1em}\text{where}\;x_i\geq0,\sum_i x_i=1


LDA に限らず、ディリクレ分布の積分にはここらへんが大活躍。ディリクレ分布のパラメータが0のとき - 木曜不足 でも使っている。


後先になったが、登場する確率変数を整理しておく。
ちなみに m はドキュメントのインデックス、n はドキュメント内の単語のインデクス、k はトピック。肩に -mn が付くときは「m,n 要素を除外して考える」を意味する。

  • θ:ドキュメント-トピック分布, P(\boldsymbol{\theta}_m;\boldsymbol{\alpha})=\text{Dir}(\boldsymbol{\alpha})
  • φ:トピック-単語分布, P(\boldsymbol{\varphi}_k;\boldsymbol{\beta})=\text{Dir}(\boldsymbol{\beta})
  • z:単語のトピック, P(z_{mn}|\boldsymbol{\theta}_m)=\text{Multi}(\boldsymbol{\theta}_m)
  • w:単語, P(w_{mn}|z_{mn}=k,\boldsymbol{\varphi})=\text{Multi}(\boldsymbol{\varphi}_k)


この記号のもと、天下りになるがまずは \int P(\boldsymbol{\theta}_m)\prod_n P(z_{mn}|\boldsymbol{\theta}_m) d\boldsymbol{\theta}_m を計算しておこう。


\int P(\boldsymbol{\theta}_m)\prod_n P(z_{mn}|\boldsymbol{\theta}_m)d\boldsymbol{\theta}_m
=\int\frac{\Gamma\left(\sum_k \alpha_k\right)}{\prod_k\Gamma(\alpha_k)}\prod_k\theta_{mk}^{\alpha_k-1}\prod_n\theta_{mz_{mn}}d\boldsymbol{\theta}_m
=\frac{\Gamma\left(\sum_k \alpha_k\right)}{\prod_k\Gamma(\alpha_k)}\int\prod_k\theta_{mk}^{\alpha_k-1}\prod_k\theta_{mk}^{n_{mk}}d\boldsymbol{\theta}_m
=\frac{\Gamma\left(\sum_k \alpha_k\right)}{\prod_k\Gamma(\alpha_k)}\int\prod_k\theta_{mk}^{n_{mk}+\alpha_k-1}d\boldsymbol{\theta}_m
=\frac{\Gamma\left(\sum_k \alpha_k\right)}{\prod_k\Gamma(\alpha_k)}\;\cdot\;\frac{\prod_k\Gamma(n_{mk}+\alpha_k)}{\Gamma\left(n_{m}+\sum_k \alpha_k\right)}


\prod_n\theta_{mz_{mn}}\prod_k\theta_{mk}^{n_{mk}} に書き換えるところは、k=z_mn となる z_mn の個数 n_mk だけ θ_mk が現れることを考えるとわかる。
同様の計算が \int P(\boldsymbol{\varphi}_k)\prod_{m,n:z_{mn}=k} P(w_{mn}|z_{mn}=k,\boldsymbol{\varphi}_k) d\boldsymbol{\varphi}_k積分にも行うことが出来て、


\int P(\boldsymbol{\varphi}_k)\prod_{m,n:z_{mn}=k} P(w_{mn}|z_{mn}=k,\boldsymbol{\varphi}_k) d\boldsymbol{\varphi}_k=\frac{\Gamma\left(\sum_t \beta_t\right)}{\prod_t\Gamma(\beta_t)}\;\cdot\;\frac{\prod_t\Gamma(n_{tk}+\beta_t)}{\Gamma\left(n_{k}+\sum_t \beta_t\right)}


が得られる。ただし t は語彙のインデックスであり、n_tk はトピック k を持つ単語 t の個数としている。


いよいよ全条件付分布を計算。ターゲットとなるトピックのインデックスを z_ji、トピックを k' とし、k'を動かしても変わらない項を比例定数としてじゃんじゃん削っていくと、


P(z_{ji}=k'|\boldsymbol{z}^{-ji}, \boldsymbol{w})
=\frac{P(z_{ji}=k',\boldsymbol{z}^{-ji},\boldsymbol{w})}{P(\boldsymbol{z}^{-ji}, \boldsymbol{w})}
\propto P(z_{ji}=k',\boldsymbol{z}^{-ji},\boldsymbol{w})
=\int P(z_{ji}=k',\boldsymbol{z}^{-ji},\boldsymbol{w},\boldsymbol{\theta},\boldsymbol{\varphi})d\boldsymbol{\theta}d\boldsymbol{\varphi}
=\int \prod_m P(\boldsymbol{\theta}_m)\prod_{mn} P(z_{mn}|\boldsymbol{\theta}_m)\prod_k P(\boldsymbol{\varphi}_k)\prod_{mn} P(w_{mn}|z_{mn},\boldsymbol{\varphi})d\boldsymbol{\theta}d\boldsymbol{\varphi}
=\prod_m\left\{\int P(\boldsymbol{\theta}_m)\prod_n P(z_{mn}|\boldsymbol{\theta}_m)d\boldsymbol{\theta}_m\right\}\;\prod_k\left\{\int P(\boldsymbol{\varphi}_k)\prod_{m,n:z_{mn}=k} P(w_{mn}|z_{mn}=k,\boldsymbol{\varphi}_k) d\boldsymbol{\varphi}_k\right\}
=\prod_m\frac{\Gamma\left(\sum_k \alpha_k\right)}{\prod_k\Gamma(\alpha_k)}\;\cdot\;\frac{\prod_k\Gamma(n_{mk}+\alpha_k)}{\Gamma\left(n_{m}+\sum_k \alpha_k\right)} \prod_k\frac{\Gamma\left(\sum_t \beta_t\right)}{\prod_t\Gamma(\beta_t)}\;\cdot\;\frac{\prod_t\Gamma(n_{tk}+\beta_t)}{\Gamma\left(n_{k}+\sum_t \beta_t\right)}
\propto\prod_k\Gamma(n_{jk}+\alpha_k)\prod_k\frac{\prod_t\Gamma(n_{tk}+\beta_t)}{\Gamma\left(n_{k}+\sum_t \beta_t\right)}


いったん k' が表面上消えてしまっているが、z_{ji}=k' や w_{ji} に関係ある項のみを抽出するよう変形していくと、


=\Gamma(n_{j{k'}}+\alpha_{k'})\frac{\prod_t\Gamma(n_{t{k'}}+\beta_t)}{\Gamma\left(n_{{k'}}+\sum_t \beta_t\right)}\prod_{k\neq k'}\Gamma(n_{jk}+\alpha_k)\prod_{k\neq k'}\frac{\prod_t\Gamma(n_{tk}+\beta_t)}{\Gamma\left(n_{k}+\sum_t \beta_t\right)}
\propto\Gamma(n_{j{k'}}+\alpha_{k'})\frac{\Gamma(n_{{w_{ji}}{k'}}+\beta_{w_{ji}})\prod_{t\neq{w_{ji}}}\Gamma(n_{t{k'}}+\beta_t)}{\Gamma\left(n_{{k'}}+\sum_t \beta_t\right)}
\propto\Gamma(n_{j{k'}}+\alpha_{k'})\frac{\Gamma(n_{{w_{ji}}{k'}}+\beta_{w_{ji}})}{\Gamma\left(n_{{k'}}+\sum_t \beta_t\right)}


ここまで簡単になった。
これでもう十分簡単に見えるかもしれないが、さらに共通の比例項がある。
n_{jk'} などを考えるときに、z_{ji}=k' を除外してみると、n_{jk'}=n_{jk'}^{-ji}+1 という関係式が得られる。


=\Gamma(n_{j{k'}}^{-ji}+\alpha_{k'}+1)\frac{\Gamma(n_{{w_{ji}}{k'}}^{-ji}+\beta_{w_{ji}}+1)}{\Gamma\left(n_{{k'}}^{-ji}+\sum_t \beta_t+1\right)}


さらに \Gamma(n+1)=n\Gamma(n) を使うと、


=(n_{j{k'}}^{-ji}+\alpha_{k'})\Gamma(n_{j{k'}}^{-ji}+\alpha_{k'})\frac{(n_{{w_{ji}}{k'}}^{-ji}+\beta_{w_{ji}})\Gamma(n_{{w_{ji}}{k'}}^{-ji}+\beta_{w_{ji}})}{\left(n_{{k'}}^{-ji}+\sum_t \beta_t\right)\Gamma\left(n_{{k'}}^{-ji}+\sum_t \beta_t\right)}


ここで n_{jk'}^{-ji} は z_{ji}=k' によらず変化しないので、


\propto(n_{j{k'}}^{-ji}+\alpha_{k'})\frac{n_{{w_{ji}}{k'}}^{-ji}+\beta_{w_{ji}}}{n_{{k'}}^{-ji}+\sum_t \beta_t}


このαとβが対称(全てのα_kとβ_tがそれぞれ同じ値)のとき、冒頭の全条件付分布が得られる。めでたし。


というわけで、次回はここで練習したことを使って HDP-LDA の全条件付分布を導出してみる。
ちなみに今回の導出は英語版 Wikipedia の LDA の項でも行われているので、実は LDA だけの話ならそっちを見てもらえばOK。

*1:できなくはないはずなので、ちょっと試してみたい気もする