Hatena::ブログ(Diary)

My Life as a Mock Quant このページをアンテナに追加 RSSフィード Twitter

2013-02-03

R言語でスライスサンプリング(Slice Sampling)を実装してみた

| 19:48 | R言語でスライスサンプリング(Slice Sampling)を実装してみたを含むブックマーク

スライスサンプリング(Slice Sampling)というサンプリング手法

Slice sampling, Radford M. Neal, Source: Ann. Statist. Volume 31, Number 3 (2003), 705-767.

[physics/0009028] Slice Sampling

についてお勉強していたのでまとめる。以下の文章で「原論文」としてこの論文を参照する。

ちなみに1P程度ではあるものの皆大好きPRML・下の11章にも記載がある。

概要&アルゴリズム

手法としてはMCMCの一種と考えられるものであるので、まずサンプリング対象として欲しい確率分布(の規格化定数を除いた部分) f(x)を設定する。

この時、 f(x)に従うサンプル系列 x=¥{x_1,x_2,¥dots¥}を得るために、スライスサンプリングでは以下のようなアルゴリズムを考える。

  • 初期時点 t=0として、その時の初期値 x_0を適当に設定
  •  y ¥sim Unif(0 , f(x_t))となるような一様分布 Unifから yサンプリングし、"スライス" S = ¥left¥{ x;y < f(x)¥}を決定
  • 区間 I = (L,R)をなんらかの方法で決める
  •  x_tを領域 S ¥cap I から一様にサンプリングして持ってくる

というものである。もう少し細かい実装のお話は次の節に書く。

ざっくりでいうと

 p(x,y) = ¥frac{1_{¥{0<y<f(x)¥}}}{Z},Z=¥int f(x)dx

という同時確率分布 p(x,y)からのサンプリングギブスサンプラーを用いて実施し、直接は必要のない yからのサンプリング結果を無視することで

周辺分布[p(x)]として

p(x)=¥int ¥frac{1_{¥{0<y<f(x)¥}}}{Z} dy = ¥int_0^{f(x)} ¥frac{1}{Z} dy = ¥frac{f(x)}{Z}

となるので目標とする分布からのサンプリングが達成されるというものである。

ここでZは規格化定数、 1_xxがtrueの時に1,そうでなければ0を取るような関数である。

より詳細については原論文を参照していただきたい。

R言語での実装

基本的に原論文、特にFigure3〜5あたりに合わせて実装してある。

注意すべき点としては

  • アンダーフローしないように f(x)のlogを取って計算するための方法で実装(原論文P8半ば参照)
  • 区間 Iの選び方として、"doubling"手順ではなく"stepping out"手順を使用している
  • 区間 Iを無限大まで伸ばして良いと仮定したので、Figure3のJとKに関する処理は無視(原論文P10半ば参照)
  • 区間 Iを決める際の幅は、過去の点間の距離の平均値で決定(原論文P16最終段落参照)*1

である。

これを踏まえた上で以下のようにスライスサンプリングを以下のように実装した。

#The "stepping out" procedure for finding an interval around x
stepping.out <- function(x, w, is.in.S)
{
  #initial range
  L <- x - w * runif(1)
  R <- L + w
  #find inerval around x
  while(is.in.S(L)){
    L <- L - w
  }
  while(is.in.S(R)){
    R <- R + w
  }
  list(L=L, R=R)
}
#The "shrinkage" procedure for sampling from the interval 
shrinkage <- function(x0, I, is.in.S)
{
  L.bar <- I$L
  R.bar <- I$R
  repeat
  {
    #select new point from the interval between L.bar and R.bar
    x1 <- L.bar + runif(1) * (R.bar - L.bar)
    if(is.in.S(x1)){break}
    #shrinkage the interval
    if(x1 < x0){
      L.bar <- x1
    }
    else{
      R.bar <- x1
    }
  }
  x1
}
#Slice sampling function
slice.sample <- function(n, x0, f, w=1.0)
{  
  g <- function(x){log(f(x))}
  make.is.in.S <- function(z, g){function(x){z < g(x)}}
  sum.dist <- 0
  result <- rep(x0,n)
  for(i in 2:n)
  { 
    z <- g(x0) - rexp(1)
    is.in.S <- make.is.in.S(z,g)
    #calc interval
    I <- stepping.out(x0, w, is.in.S)
    #sample next point
    x1 <- shrinkage(x0, I, is.in.S)
    #update results
    sum.dist <- sum.dist + abs(x1 - x0)
    w <- sum.dist / (i - 1)
    result[i] <- x1
    x0 <- x1
  }
  result
}  

実際に動かしてみる

まずは平均0・標準偏差1の正規分布を生成してみる。

MCMC同様(というかこの手法自体がその一部なので)、確率密度関数の定数項は無視して良い。

SIZE <- 10^4
points <- slice.sample(SIZE, 0, function(x)exp(-0.5*x^2), w=1)

作成したサンプリングポイント(点列)の平均と標準偏差

> mean(points)
[1] 0.009999361
> sd(points)
[1] 1.018237

のようにそれぞれ0と1に近い値を取っている。更にサンプリングした点列のヒストグラムと密度関数を重ねてPLOTすると

hist(points, SIZE^0.5, freq=FALSE)
x<-seq(-3 ,3 ,0.01)
lines(x, dnorm(x), col=2, lwd=3)

f:id:teramonagi:20130203194625p:image

となる。ほぼ重なっているので答えとしても良さそうだ。


次に混合正規分布を作成してみる。

まずは平均が−3と3の位置にある5:5の混合正規分布を作成し、ヒストグラムと密度関数を重ねてPLOTしてみる。

mu <- 3
SIZE <- 10^4
points <- slice.sample(SIZE, 0, function(x)exp(-0.5*(x+mu)^2)*0.8 + exp(-0.5*(x-mu)^2)*0.2, w=1)
hist(points, SIZE^0.5, freq=FALSE,xlim=c(-(mu+3), mu+3))
x<-seq(-(mu+3), mu+3 ,0.01)
lines(x, 0.2*dnorm(x-mu)+0.8*dnorm(x+mu), col=2, lwd=3)

f:id:teramonagi:20130203194626p:image

確かに混合正規分布もうまくできてそうだ。

ここからもう少し分布間の幅を広げ、平均を−6と6へ変更してみると・・・

mu <- 6
SIZE <- 10^4
points <- slice.sample(SIZE, 0, function(x)exp(-0.5*(x+mu)^2)*0.8 + exp(-0.5*(x-mu)^2)*0.2, w=1)
hist(points, SIZE^0.5, freq=FALSE,xlim=c(-(mu+3), mu+3))
x<-seq(-(mu+3), mu+3 ,0.01)
lines(x, 0.2*dnorm(x-mu)+0.8*dnorm(x+mu), col=2, lwd=3)

f:id:teramonagi:20130203194627p:image

やはり幅が広い(≒多峰性の系)だとこの手法でもうまくいかないことがわかる。

汎用的なサンプリング手法への道は険しい。。。

参考

*1:これは単峰系のみという注釈が原論文にあったので、やらない方がいいかも

maemae 2017/02/23 20:05 基礎的な質問で恐縮ですが、`z <- g(x0) - rexp(1)`の` - rexp(1)`では、なぜ指数分布からの乱数を引いているのでしょうか。
評価したい関数がg(x)だとすると、 z ~ Uniform(0, g(x)) を計算する必要があると思うのですが、なぜこのような形になるのでしょうか。