書籍「オンライン機械学習」を買ったのでCommon Lispで実装してみた。(AROW編)

前回の記事ではパーセプトロンと線形SVMを実装したが、より高度な手法として、CW、AROW、SCWといったものがあるらしい。
パーセプトロン等では重みベクトルを直接更新していたのに対して、CWでは重みベクトルが正規分布に従って分布していると仮定して、その正規分布の平均や分散を更新していく。この平均と分散はそれぞれ重みベクトルと特徴量ごとの信頼度に対応し、これがCW(Confidence Weighted)の名前の由来となっている。更新にあたっては、誤分類の確率を一定値以下にするという制約を付けた上で、推定された正規分布とのカルバックライブラーダイバージェンス(分布間の距離的なもの)を最小化するような正規分布の平均と分散を求める。
今回実装するAROWやSCWはCWの派生で、CWがノイズが弱いという問題点を解決した手法である。
AROWやSCWについては解説した記事があるのでそちらを。

精度はSCW > AROWだけど、SCWの論文中の比較表を見るにそれほど大した差ではないし、AROWの方が優れている場合もある。SCWはハイパーパラメータが2つでグリッドサーチしなければならないのに対して、AROWはハイパーパラメータが1つで扱いやすい。

AROWの疑似コード


実装上の注意点としては、共分散行列の更新は行列演算なので素直にやると計算量がO(m^3)になるが、対角成分以外を0として対角行列に近似することでO(m)にできるということがある。実際には、対角行列は行列ではなくベクトルとして実装し、対角行列とベクトルの乗算をする関数diagonal-matrix-multiplicationを定義しておく
すると上記の疑似コードの対応する1ステップの更新部分はこうなる。

;; AROW、入力ベクトルに1次元足すバージョン(バイアスを分けない)
;; muとsigmaが破壊的に更新される。gammaがハイパーパラメータ
(defun train-arow-1step (input mu sigma gamma training-label tmp-vec1 tmp-vec2)
  (let ((loss (- 1d0 (* training-label (inner-product mu input))))) ; ヒンジ損失が0より大きいときに更新
    (if (> loss 0d0)
      (let* ((beta (/ 1d0 (+ (inner-product (diagonal-matrix-multiplication sigma input tmp-vec1) input) gamma)))
	     (alpha (* loss beta)))
	;; muの更新
	;;   betaの計算の時点で、tmp-vec1には Sigma_{t-1} x_t の結果が入っている
	;;   muの更新差分をtmp-vec2に入れる
	(v-scale tmp-vec1 (* alpha training-label) tmp-vec2)
	(v+ mu tmp-vec2 mu)	
	;; sigmaの更新
	;;   \Sigma_{t-1} x_t x_t^T \Sigma_{t-1} を対角行列に近似したものをtmp-vec1に入れる
	(diagonal-matrix-multiplication tmp-vec1 tmp-vec1 tmp-vec1)
	;;   betaをかけてtmp-vec1を更新
	(v-scale tmp-vec1 beta tmp-vec1)
	;;   sigmaを更新
	(v- sigma tmp-vec1 sigma)))
    (values mu sigma)))

このままだとバイアスに対応する平均と分散がmuとsigmaに含まれているため、データセットの各入力ベクトルに1次元、定数要素を追加する前処理が必要になりなんか嫌である。バイアスの平均と分散の更新は分けて、2要素のベクトルmu0-sigma0-vecに格納することにすると、上の関数はこうなる。

(defun train-arow-1step (input mu sigma mu0-sigma0-vec gamma training-label tmp-vec1 tmp-vec2)
  (let* ((mu0 (aref mu0-sigma0-vec 0))
	 (sigma0 (aref mu0-sigma0-vec 1))
	 (loss (- 1d0 (* training-label (f input mu mu0))))) ; ヒンジ損失が0より大きいときに更新
    (if (> loss 0d0)
      (let* ((beta (/ 1d0 (+ sigma0 (inner-product (diagonal-matrix-multiplication sigma input tmp-vec1) input) gamma)))
	     (alpha (* loss beta)))
	;; muの更新
	;;   betaの計算の時点で、tmp-vec1には Sigma_{t-1} x_t の結果が入っている
	;;   muの更新差分をtmp-vec2に入れる
	(v-scale tmp-vec1 (* alpha training-label) tmp-vec2)
	(v+ mu tmp-vec2 mu)

	;; mu0の更新
	(setf (aref mu0-sigma0-vec 0) (+ mu0 (* alpha sigma0 training-label)))

	;; sigmaの更新
	;;   Sigma_{t-1} x_t x_t^T Sigma_{t-1} を対角行列に近似したものをtmp-vec1に入れる
	(diagonal-matrix-multiplication tmp-vec1 tmp-vec1 tmp-vec1)
	;;   betaをかけてtmp-vec1を更新
	(v-scale tmp-vec1 beta tmp-vec1)
	;;   sigmaを更新
	(v- sigma tmp-vec1 sigma)

	;; sigma0の更新
	(setf (aref mu0-sigma0-vec 1) (- sigma0 (* beta sigma0 sigma0)))))
    (values mu sigma mu0-sigma0-vec)))

あとはデータセットにこれを適用する部分を書くだけである。

(defun train-arow-all (training-data mu sigma mu0-sigma0-vec gamma tmp-vec1 tmp-vec2)
  (loop for datum in training-data do
    (train-arow-1step (cdr datum) mu sigma mu0-sigma0-vec gamma (car datum) tmp-vec1 tmp-vec2))
  (values mu sigma mu0-sigma0-vec))

(defun train-arow (training-data gamma)
  (let* ((dim (length (cdar training-data)))
	 (mu       (make-dvec dim 0d0))
	 (sigma    (make-dvec dim 1d0))
	 (mu0-sigma0-vec (make-array 2 :element-type 'double-float :initial-contents '(0d0 1d0)))
	 (tmp-vec1 (make-dvec dim 0d0))
	 (tmp-vec2 (make-dvec dim 0d0)))
    (train-arow-all training-data mu sigma mu0-sigma0-vec gamma tmp-vec1 tmp-vec2)))

前回同様にして読み込んだデータセットで試してみると

(time
 (multiple-value-bind (mu sigma mu0-sigma0-vec)
     (train-arow a1a-train 10d0)
   (declare (ignore sigma))
   (test a1a-test mu (aref mu0-sigma0-vec 0))))
; Accuracy: 84.461815%, Correct: 26146, Total: 30956
; Evaluation took:
;   0.013 seconds of real time
;   0.013124 seconds of total run time (0.013124 user, 0.000000 system)
;   100.00% CPU
;   42,935,095 processor cycles
;   1,900,544 bytes consed

(time
 (multiple-value-bind (mu sigma mu0-sigma0-vec)
     (train-arow a9a-train 10d0)
   (declare (ignore sigma))
   (test a9a-test mu (aref mu0-sigma0-vec 0))))
; Accuracy: 84.94564%, Correct: 13830, Total: 16281
; Evaluation took:
;   0.032 seconds of real time
;   0.032499 seconds of total run time (0.032499 user, 0.000000 system)
;   100.00% CPU
;   107,560,608 processor cycles
;   9,076,736 bytes consed

多少計算時間が増えるものの、適当に設定したgammaでも良好な精度を出すことが分かる。線形SVMよりハイパーパラメータも減って精度もいいので常用できそうである。
今後は、SCW、平均化確率的勾配降下法(ASGD)、疎ベクトルの演算あたりをやるかもしれない。