Negative/Positive Thinking

2011-09-23

PA-IIとSVMの比較

はじめに

Passive-AggressiveアルゴリズムLIBSVMでどの程度の性能差がでるのか試してみた。

使用したデータ

LIBSVMの分類精度

#学習
$ svm-train -t 0 a9a.txt a9a.model
#予測
$ svm-predict a9a.t a9a.model out
Accuracy = 84.9764% (13835/16281) (classification)」

自分で書いたPA-IIのコードでの分類精度

  • アグレッシブパラメータCを変えて、学習データは全体を10回学習した場合
$ perl PA-II.pl a9a.txt 1.0 < a9a.t
Result : 0.762483876911738 (12414/16281)

$ perl PA-II.pl a9a.txt 0.1 < a9a.t
Result : 0.796449849517843 (12967/16281)

$ perl PA-II.pl a9a.txt 0.01 < a9a.t
Result : 0.83483815490449 (13592/16281)

$ perl PA-II.pl a9a.txt 0.001 < a9a.t
Result : 0.84835083840059 (13812/16281)

$ perl PA-II.pl a9a.txt 0.0001 < a9a.t
Result : 0.850930532522572 (13854/16281)

$ perl PA-II.pl a9a.txt 0.00001 < a9a.t
Result : 0.840181807014311 (13679/16281)

$ perl PA-II.pl a9a.txt 0.000001 < a9a.t
Result : 0.764879307167864 (12453/16281)

  • C=0.0001の時、何も調整していないLIBSVMの結果をギリギリ上回った
    • がんばれば、なかなかの精度がでるが

PAのコード

#! /usr/bin/perl
# Usage : perl PA-II.pl train_file parameter_C < test_file
use strict;
use warnings;

#学習ファイル名
my $train_file = shift;
#パラメータ
my $C = shift;

#訓練回数(学習データの個数*$loop個の学習を行う)
my $loop = 10;

#重みベクトル
my $w = {};

## 学習データの読み込み
my @x_list;
my @t_list;
open IN, $train_file;
while(<IN>){
    chomp;
    my @list = split(/\s+/, $_);
    push(@t_list, $list[0]);
    my $hash;
    for(my $i=1; $i<@list; $i++){
	my ($a, $b) = split(/:/,$list[$i]);
	$hash->{$a} = $b;
	$w->{$a} = 0;
    }
    push(@x_list, $hash);
}

## 訓練
while($loop--){
    for(my $i = 0; $i < @x_list; $i++){
	train($w, $x_list[$i], $t_list[$i]);
    }
}

## 推定
my $num = 0;
my $success = 0;
while(<>){
    chomp;
    my @list = split(/\s+/, $_);
    my $hash;
    for(my $i=1; $i<@list; $i++){
	my ($a, $b) = split(/:/,$list[$i]);
	$hash->{$a} = $b;
    }

    my $t = predict($w, $hash);
    $num++;
    if($t * $list[0] >= 0){
	$success++;
    }
}

## 結果の出力
print "Result : ",($success/$num)," (",$success,"/",$num,")\n";

##################################################
#予測
sub predict {
    my ($w, $x) = @_;
    
    my $y = 0;
    foreach my $f (keys %$x){
	if($w->{$f}){
	    $y += ($w->{$f} * $x->{$f});
	}
    }
    return $y;
}

#損失関数
sub loss {
    my ($w, $x, $t) = @_;

    my $y = 0;
    foreach my $f (keys %$x){
	if($w->{$f}){
	    $y += ($w->{$f} * $x->{$f});
	}
    }
    return 0.0 if(1.0 - $t*$y <= 0);
    return 1.0 - $t*$y;
}

#学習
sub train {
    my ($w, $x, $t) = @_;
    
    my $l = loss($w, $x, $t);
    my $sq_x = 0.0;
    foreach my $f (keys %$x){
	$sq_x += ($x->{$f} * $x->{$f});
    }

    ## PA
    #my $tau = $l / $sq_x;

    ## PA-I
    #my $tau = $l / $sq_x;
    #$tau = $C if($tau > $C);

    ## PA-II
    my $tau = $l / ($sq_x + 1.0 / (2 * $C));

    # 更新
    foreach my $f(keys %$x){
	$w->{$f} += $tau * ($t * $x->{$f});
    }
}

スパム対策のためのダミーです。もし見えても何も入力しないでください
ゲスト


画像認証

トラックバック - http://d.hatena.ne.jp/jetbead/20110923/1316769132