Hatena::ブログ(Diary)

ほくそ笑む このページをアンテナに追加 RSSフィード

2011-03-25

SVM のチューニングのしかた(2)

さて、前回は交差検証の説明で終わってしまいましたが、今回はちゃんと SVM のチューニングの話をします。

チューニングの手順としては、

  1. グリッドサーチで大雑把に検索する。
  2. 最適なパラメータがありそうなところを絞って再びグリッドサーチを行う。

という2段階のグリッドサーチを行います。

1段階目:グリッドサーチで大雑把に検索する

SVM のチューニングは tune.svm() という関数を用いて行います。

チューニングのやり方は、単純にグリッドサーチを行っているだけです。

パラメータの値をいろいろ変えてみて、正答率の一番いい値をベストパラメータとして出力します。

プログラムは下記のようになります。

gammaRange = 10^(-5:5)
costRange = 10^(-2:2)
t <- tune.svm(Species ~ ., data = iris, gamma=gammaRange, cost=costRange,
              tunecontrol = tune.control(sampling="cross", cross=8))
cat("- best parameters:\n")
cat("gamma =", t$best.parameters$gamma, "; cost =", t$best.parameters$cost, ";\n")
cat("accuracy:", 100 - t$best.performance * 100, "%\n\n")
plot(t, transform.x=log10, transform.y=log10)

まずはグリッドサーチするために、gamma と cost の範囲を決めています。ここでは gamma は 10^{-5}10^5 まで、cost は 10^{-2}10^2 までを指定しています。

次に tune.svm() でグリッドサーチを行います。svm() と同じような使い方ですが、交差検証を行うために tunecontrol = tune.control(sampling="cross", cross=8) と指定しています。これで 8-交差検証で評価を行うということになります。

あとはグリッドサーチの結果としてベストパラメータを出力しています。上記のプログラムを実行すると、結果は下記のように出ました。

- best parameters:

gamma = 0.1 ; cost = 1 ;

accuracy: 97.33187 %

gamma=0.1, cost=1 の組合せのとき、正答率 97.3 % を出していることがわかります。

ただし、1段階目ではベストパラメータの値はそれほど重視しません。

重要なのは最後の行、plot() 関数でグリッドサーチの結果を等高線図にしています。

等高線図は次のようになりました。

f:id:hoxo_m:20110325133036p:image

この図の中で色の濃い部分に最適なパラメータがありそうです。

そこで、gamma=10^{-1},cost=10^{0.5} のあたりと gamma=10^{-1.5},cost=10^{1.5} のあたりを再びグリッドサーチで調べてみましょう。これが2段階目となります。

2段階目:最適なパラメータがありそうなところを絞って再びグリッドサーチ

まずは gamma=10^{-1},cost=10^{0.5} のあたりをグリッドサーチしてみましょう。

gamma <- 10^(-1)
cost  <- 10^(0.5)
gammaRange <- 10^seq(log10(gamma)-1,log10(gamma)+1,length=11)[2:10]
costRange  <- 10^seq(log10(cost)-1 ,log10(cost)+1 ,length=11)[2:10]
t <- tune.svm(Species ~ ., data = iris, gamma=gammaRange, cost=costRange,
              tunecontrol = tune.control(sampling="cross", cross=8))
cat("[gamma =", gamma, ", cost =" , cost , "]\n")
cat("- best parameters:\n")
cat("gamma =", t$best.parameters$gamma, "; cost =", t$best.parameters$cost, ";\n")
cat("accuracy:", 100 - t$best.performance * 100, "%\n\n")
plot(t, transform.x=log10, transform.y=log10, zlim=c(0,0.1))

around [gamma = 0.1 , cost = 3.162278 ]

- best parameters:

gamma = 0.06309573 ; cost = 5.011872 ;

accuracy: 97.33187 %

gamma, cost, 正答率が出ました。

次は、gamma=10^{-1.5},cost=10^{1.5} のあたりをグリッドサーチしてみましょう。

gamma <- 10^(-1.5)
cost  <- 10^(1.5)
gammaRange <- 10^seq(log10(gamma)-1,log10(gamma)+1,length=11)[2:10]
costRange  <- 10^seq(log10(cost)-1 ,log10(cost)+1 ,length=11)[2:10]
t <- tune.svm(Species ~ ., data = iris, gamma=gammaRange, cost=costRange,
              tunecontrol = tune.control(sampling="cross", cross=8))
cat("[gamma =", gamma, ", cost =" , cost , "]\n")
cat("- best parameters:\n")
cat("gamma =", t$best.parameters$gamma, "; cost =", t$best.parameters$cost, ";\n")
cat("accuracy:", 100 - t$best.performance * 100, "%\n\n")
plot(t, transform.x=log10, transform.y=log10, zlim=c(0,0.1))

around [gamma = 0.03162278 , cost = 31.62278 ]

- best parameters:

gamma = 0.05011872 ; cost = 5.011872 ;

accuracy: 98.64766 %

こうして2回グリッドサーチをやりましたが、1回目の正答率は 97.3%、2回目の正答率は 98.6% と2回目の方が高いです。

というわけで、最適なパラメータとして、gamma = 0.05011872, cost = 5.011872 を選ぶことにします。

おわりに

以上で、SVM のチューニングが終わりました。上記のパラメータで SVM を作成するには下記のようにします。

gamma = 0.05011872 ; cost = 5.011872 ;
model <- svm(Species ~ ., data = iris, gamma=gamma, cost=cost)

これで、学習データを判別してみると

pred <- predict(model, iris)
table(pred, iris[,5])
pred         setosa versicolor virginica
  setosa         50          0         0
  versicolor      0         49         1
  virginica       0          1        49

本に載っているより、判別精度が上がっていることもわかります。

以上です。

追記

Wikipedia の項目が交差検定から交差検証に変更されたため、それに合わせて記事内の用語も交差検証に変更しています。

みたにみたに 2012/05/24 12:55 大変有用な記事ありがとうございます。
1つ教えていただけないでしょうか。
「Rで学ぶデータサイエンス6 マシンラーニング」のサンプルプログラムでエラーがでてしまい対処できず困っています。
以下に、プログラムを書きます。
trains <- read.csv("R:\\train_diabetes.csv", header=F)
tests <- read.csv("R:\\test_diabetes.csv", header=F)
library(e1071)

model <- svm(factor(V8)~V1+V2+V3+V4+V5+V6+V7, data=trains, gamma=0.0000000001, cost=0.001, probabilitiy=TRUE)
pred <- predict(model, trains, probability=TRUE)
p <- attr(pred, "probabilities")
ee <- data.frame(pred)
write.table(ee, "R:\\out(訓練標本).txt", append=T, quote=F, col.names=F)
kpred <- predict(model, newdata=tests, probability=TRUE)
pp <- attr(kpred, "probabilities")
dd <- data.frame(kpred)
write.table(dd, "R:\\out(検証標本).txt", append=T, quote=F, col.names=F)

-------------------------
train_diabetes.csv
-------------------------
npreg,glu,bp,skin,bmi,ped,age,type
5,86,68,28,30.2,0.364,24,0
7,195,70,33,25.1,0.163,55,1
5,77,82,41,35.8,0.156,35,0
0,165,76,43,47.9,0.259,26,0
0,107,60,25,26.4,0.133,23,0
5,97,76,27,35.6,0.378,52,1
3,83,58,31,34.3,0.336,25,0
1,193,50,16,25.9,0.655,24,0
3,142,80,15,32.4,0.200,63,0
2,128,78,37,43.3,1.224,31,1
-------------------------


-------------------------
test_diabetes.csv
-------------------------
npreg,glu,bp,skin,bmi,ped,age,type
6,148,72,35,33.6,0.627,50,1
1,85,66,29,26.6,0.351,31,0
1,89,66,23,28.1,0.167,21,0
3,78,50,32,31.0,0.248,26,1
2,197,70,45,30.5,0.158,53,1
5,166,72,19,25.8,0.587,51,1
0,118,84,47,45.8,0.551,31,1
1,103,30,38,43.3,0.183,33,0
3,126,88,41,39.3,0.704,27,0
9,119,80,35,29.0,0.263,29,1
1,97,66,15,23.2,0.487,22,0
5,109,75,26,36.0,0.546,60,0
3,88,58,11,24.8,0.267,22,0
0,122,78,31,27.6,0.512,45,0
4,103,60,33,24.0,0.966,33,0
9,102,76,37,32.9,0.665,46,1
2,90,68,42,38.2,0.503,27,1
4,111,72,47,37.1,1.390,56,1
3,180,64,25,34.0,0.271,26,0
7,106,92,18,22.7,0.235,48,0
-------------------------

上記のプログラムで、以下のようなエラーになります。
以下にエラー predict.svm(model, newdata = tests, probability = TRUE) :
test data does not match model !

どうすればいいでしょうか?

スパム対策のためのダミーです。もし見えても何も入力しないでください
ゲスト


画像認証