ロジスティック回帰+確率的勾配降下法

次やってみたいことにロジスティック回帰を使おうとしているので、PRML 4章に従ってさらっと実装してみる。


最終的には Python + numpy で実装し直すことになると思うけど、R の手触り感が好きなので、今回は R。
データセットには R なら簡単に扱える iris を使う。iris は3クラスあるので、多クラスロジスティック回帰 (PRML 4.3.4) を実装することになる。
推論は IRLS (PRML 4.3.3) が PRML 的にはあうんだろうけど、ちょっと考えがあって確率的勾配降下法 (PRML 3.1.3, 5.2.4) を使うことにする。


まずは、こちらで説明した方法で iris のデータを正規化&行列化して使いやすくしておく。

xlist <- scale(iris[1:4])
tlist <- cbind(
	ifelse(iris[5]=="setosa",1,0),
	ifelse(iris[5]=="versicolor",1,0),
	ifelse(iris[5]=="virginica",1,0)
)
N <- nrow(xlist) # データ件数
K <- ncol(tlist) # クラス数


次は基底関数を決める。手始めに1次項+バイアスとしてみよう。
下のように書けば、phi のように定義した特徴関数を xlist の各行に適用して、N×M のいわゆる計画行列 PHI を生成することができる。

phi <- function(x) c(1, x[1], x[2], x[3], x[4])
PHI <- t(apply(xlist, 1, phi))  # NxM - design matrix
M <- ncol(PHI)  # 特徴数(特徴空間の次元)


パラメータとなる重み w を乱数で初期化しよう。
w は M×K の行列なので、正規乱数を使って初期化するなら次のように書けばいい。

w <- matrix(rnorm(M * K), M)


これで準備が整ったので、ロジスティック回帰を実装できる。
一番重要なのが、ロジスティック回帰の事後確率の仮定。多クラスなので、ソフトマックス関数の形をしている。

  • \displaystyle p(C_k|\boldsymbol{\phi})=y_k(\boldsymbol{\phi})=\frac{\exp(a_k)}{\sum_{j}\exp(a_j)}   (PRML (4.104) 式)

ただし a_k=\boldsymbol{w}^T\boldsymbol{\phi}_k である。
これは k=1, ..., K と動くから y がベクトルになることに気をつけると、次のように書ける。

# y_k = p(t=k|phi, w)  (k=1, ..., K)
y <- function(phi, w) {
	y <- exp(c(phi %*% w))
	return(y / sum(y))
}

# 例: n 番目の特徴ベクトルについて、その事後確率を求める
y_n <- y(PHI[n, ], w)


だいたいはこれで大丈夫なのだが、特徴を増やしたりするとうまくいかなくなったりし始める。
exp の結果が浮動小数点の範囲をオーバーフローして Inf に落ちてしまう現象が起きうる。具体的には exp の中身が 700 を超えるあたりから Inf になってしまうので、ちょっと大きめの特徴量があるだけで結構簡単に発生する。やっかい。
ロジスティック回帰に限らず exp を使うアルゴリズムで、学習後のパラメータがなぜか NaN だらけ……!? という現象が発生したら、まずは exp のオーバーフローを疑うといい。


ソフトマックス関数は正規化のために exp の総和で割り算するので、a_k たちに定数を足したり引いたりしても結果は変わらない。だから適当な数を全体から引いて、オーバーフローしない範囲に収めてしまえばこの問題を解決することができる。
しかし定数の加減ではまだオーバーフローする可能性は残ってしまう。というわけで、一番簡単には最大値を引いてしまえばいい。
というわけで、以下がその対応版。

y <- function(phi, w) {
	y <- c(phi %*% w)
	y <- exp(y - max(y))  # exp の中身から、その最大値を引く(オーバーフロー対策)
	return(y / sum(y))
}


次は誤差関数とその勾配を求める。
ロジスティック回帰の誤差関数は交差エントロピー誤差関数であり、これはおなじみの負の対数尤度と同等。

  • E(\boldsymbol{w}_1, \cdots, \boldsymbol{w}_K)=-\ln p(\boldsymbol{T}|\boldsymbol{w}_1, \cdots, \boldsymbol{w}_K)=-\sum_{n=1}^N\sum_{k=1}^K t_{nk}\ln y_{nk}   (PRML (4.108) 式)

これを微分した勾配は次のようになる。

  • \nabla_{\boldsymbol{w}_j}E(\boldsymbol{w}_1, \cdots, \boldsymbol{w}_K)=\sum_{n=1}^N(y_{nj}-t_{nj})\phi_n   (PRML (4.109) 式)

どちらの式にも y がでてくるが、これは上で実装した関数が使えるので、それぞれ次のように簡潔に書ける。
ただし n についての和は関数の外でとるようにしている(確率的勾配降下法にあわせて)。

En <- function(phi, t, w) -log(sum(y(phi, w) * t))
dEn <- function(phi, t, w) outer(phi, y(phi, w) - t)


R ではベクトルや行列に対して一度に演算ができる関係から、このように簡潔に書ける反面、元となる数式と項の順番を変えないといけなかったり、outer*1 とかにうまく置き換える必要があったりする。
このあたりはいきなりやれと言われても出来ないので、最初のうちは要素ごとに考えて、少し慣れてきてからでいいと思う。


今回学習に用いる「確率的勾配降下法」は次のようにパラメータ w を更新していくことで学習する手法。

  • \boldsymbol{w}^{(\tau+1)}=\boldsymbol{w}^{(\tau)}-\eta\nabla E_n   (PRML (3.22) 式)

ここで ηは「学習率」、E_n は特徴ベクトル φ_n に対する勾配。
学習率ηは、適当な値(0.1 とか 0.01 とか)からはじめて、徐々に小さくしていく。
この更新式を全てのデータ点について回していく。順序はシャッフルすることが望ましい。
というわけでデータ点を一周回して更新するのは次のような実装になる。

eta <- 0.1  # 学習率
for(n in sample(N)) {
	w <- w - eta * dEn(PHI[n,], tlist[n,], w)  # 確率的勾配降下法
}


勾配さえ計算できてしまえば、確率的勾配降下法はこのようにめちゃめちゃ簡単に実装できて実に嬉しい。が、簡単すぎて、こんなので学習とか本当にできるの? と少し不安になる。
そういう場合は可視化してみればいい。

ylist <- t(apply(PHI, 1, function(phi) y(phi, w)))
error <- sum(sapply(1:N, function(n) En(PHI[n,], tlist[n,], w)))
pairs(xlist, col=rgb(ylist), main=sprintf("Negative Log Likelihood = %.3f", error))


このコードで、試しに「初期化しただけの、全く学習させていない w (つまり乱数)」を使って分布図を描くと次のようになった。

次に、学習率η=0.1 でデータ点を一周だけ学習させたパラメータ w を使って描いた。

1周回しただけで、なかなかいい感じに分類が行われていることがわかる。誤差(負の対数尤度)も大きく下がっている。


いや、分類できてるっぽく見えるけど、正しく分類できているかどうかはこの分布図を見ただけではわからない。
そこで、正解(真のラベル)がちゃんとわかるように、「真に setosa」である点は▲という形状に、一方「setosa と予測」した点は赤でプロット、同様に「versicolorは ●-緑」、「virginicaは、□-青」と指定しつつ、もう少し細かいところを見られるように 2軸だけ取り出して可視化してみた。

plot(xlist[,c(1,2)], 
	col=rgb(ylist),
	pch=(tlist %*% c(17,16,22)),
	main=sprintf("Negative Log Likelihood = %.3f", error)
)


いくつか「緑っぽい□」と「青っぽい●」は見受けられるものの、正しく分類が進んでいる。
というわけで、ロジスティック回帰に確率的勾配降下法を使えば、「こんなに簡単でいいの?」という実装でちゃんと学習できることがわかる。
もうちょっと学習を進めたらどうなるか、特徴関数をもう少し増やしたらどうなるか、というあたりはまた次回。

*1:outer は「外積」ではなくて「直積」。ややこしいけど間違えないでね!