Hatena::ブログ(Diary)

明日とロボット このページをアンテナに追加 RSSフィード

2011-06-28

Rでニューラルネットワークを用いた手書き数字認識

アイディア

  • 元々のアイディアはCodeZineの記事ニューラルネットワークを用いたパターン認識で公開されていたJavaアプレット
  • 学習精度を上げるために、自分で書いた文字を元に学習させる。
    • それぞれの数を10回ずつ手書き入力させ、データを用意し、それを元にニューラルネットワークを学習させる
  • Rで用意されているneuralパッケージを使用する

手書き入力データの用意

  • CodeZineさんの記事のように、手書き入力した数字を7×11のグリッドで表現

f:id:astrobot:20110629002141p:image

  • 作成したJavaアプレットを用い、RBFモデルのネットワークで学習するための、入力層のユニットは7×11の計77、出力層のユニットに0から9の数を設定したデータを用意する。
  • 0〜9間での数字をそれぞれ10回ずつ書いて、以下のような学習用のcsvファイル(data.csv)を作成
cell0cell1cell2...cell76ans0ans1ans2...ans9
0.10.10.1...0.10.90.10.1...0.1
..............................
0.10.10.1...0.10.10.90.1...0.1
..............................
    • 塗りつぶされたcellは0.9、空白だったセルは0.1で一次元配列にして保存
    • それぞれの値に対応するansの値が0.9、それ以外が0.1となるように
  • 同様に、0〜9までのテスト用のcsvファイル(test.csv)も作成

cell0cell1cell2...cell76
0.10.10.1...0.1
0.10.10.1...0.9
...............

Rを使ったRBFモデルのニューラルネットワーク

  • 今回のように高次の交互作用が存在する課題は、ユニット内の変換にラジアル基底関数(RBF)など釣り鐘式の関数を利用すると学習が容易に成立する事が知られている。
  • package neuralをRにインストール
options(CRAN="http://cran.r-project.org")
install.packages("neural")
    • 通常ならば以上のようにパッケージをインストールできるが、neuralは古くなっためか退避されていていたため、zipを直接ダウンロードしてインストールする。
データを読み込む
  • neural libraryを読み込んで、dataに学習用csvを読み込み
  • inputに7×11の配列データ、outputにそれぞれに対応する数を読み込む
library(neural)
data<-read.csv("data.csv",header=T)
input<-as.matrix(data[0:77])
output<-as.matrix(data[78:87])
neuralのGUIを使ってネットワークの分析
  • 初期値のシードを固定し、中間層を2層(5個と3個)で学習の最大繰り替えし数を300回に指定
set.seed(1)
learn0<-mlptrain(input, neurons=c(5,3), output$ans1, it=300)
learn1<-mlptrain(input, neurons=c(5,3), output$ans2, it=300)
...
learn9<-mlptrain(input, neurons=c(5,3), output$ans9, it=300)
  • すると以下のようなGUIが表示されるため、startをクリック

f:id:astrobot:20110629010708p:image

  • 自分のノートパソコンでだいたい、それぞれ10分程度。
    • 終了すると、出力ユニットボックスが赤くなるので、そうしたらexit

学習データのセーブ

  • 以下のようにセーブできる
save(learn0, file="learn0.dat") 
save(learn1, file="learn1.dat") 
...
save(learn9, file="learn9.dat") 

精度テスト

  • テスト用csvを読み込んで、学習の完了したネットワークで値予測を行う
    • それぞれの予測値を取り込み、ひとつに結合し、表示、csv出力
test<-read.csv("test.csv",header=T)
testdata<-as.matrix(test[0:77])

applied0<-mlp(testdata, learn0$weight, learn0$dist,learn0$neurons, learn0$actfns)
applied1<-mlp(testdata, learn1$weight, learn1$dist,learn1$neurons, learn1$actfns)
...
applied9<-mlp(testdata, learn9$weight, learn9$dist,learn9$neurons, learn9$actfns)

result<-cbind(applied0,applied1,applied2,applied3,applied4,applied5,applied6,applied7,applied8,applied9)
result

write.csv(result,"result.csv")
  • 結果は以下のようになった。
prediction\actual #0123456789
00.8380.0880.0870.0900.0940.0910.0910.0990.0850.302
10.0900.5940.1230.1400.0930.1030.0870.0960.0850.091
20.1420.1270.9070.0920.1090.0990.1080.0930.1090.764
30.0880.0960.1460.7680.1060.0960.0940.1080.0930.129
40.1040.1100.0910.0890.8820.1000.1460.0950.1140.106
50.0920.0990.1050.0880.1010.9370.1000.0910.1070.090
60.1080.1020.1020.2370.1220.3490.8600.0920.0900.090
70.1070.1840.1860.0940.1590.2520.0970.7600.1020.236
80.1000.0930.0890.0890.5940.0870.0880.0960.1060.093
90.0880.2970.0910.1300.1770.1020.0890.1130.0940.852
  • 8を除いた全ての値でそれなりの精度の予測ができた。
  • 改善点はこれらの作業をJAVAから全自動化

参考

2011-06-27

統計処理ソフトウェアRをEclipseでJavaから使う

Rとは

  • 統計用途に特化されたオープンソースのプログラミング言語。
  • 数多くの統計パッケージが無償で公開されている
  • 欠点は、GUIが乏しかったり、大きすぎるデータ処理は難しい点

ダウンロード・インストール

  • こちらより最新版をダウンロード
    • 32bit/64bitは後に選択できる。
  • 以下のように下の方にインストールしとくと後々便利
C:\R\

Javaから使うために

  • rJavaというパッケージをRにインストールする。
install.packages('rJava')
  • ダウンロード元を聞かれるので、適当に一つ選ぶ。(日本ならtsukubaとか)

環境変数の設定

  • マイコンピュータを右クリックし、properties
  • advanced system settingsをクリックし、advancedタブ下にあるEnviroment variablesをクリック
  • 以下を参考に環境変数を設定

R_HOME: C:\R\R-2.13.0
PATH: %R_HOME%\bin\i386かx64 (使っているR.dllがあるところ)
R_DOC_DIR: %R_HOME%\doc
R_INCLUDE_DIR: %R_HOME%\include
R_SHARE_DIR: %R_HOME%\share


Eclipseの設定

  • Eclipseを管理者権限で起動。
  • File -> new -> Java Projectで、プロジェクト名を指定してnext
  • LibrariesのタブでAdd External Jarsで、先ほどインストールしたパッケージの中からIRI.jarを選択しFinish
    • 場所:インストールフォルダ\R-2.13.0\library\rJava\jri\IRI.jar
  • 作成したプロジェクトにC:\R\R-2.13.0\library\rJava\jriにあるjri.dllをコピー。

テスト(こちらのコードを改変)

  • JavaプログラムをEclipseで実行して、コンソールからRのコードを入力、実行
package jritest;
import java.util.Scanner;

import org.rosuda.JRI.REXP;
import org.rosuda.JRI.Rengine;
import org.rosuda.JRI.RMainLoopCallbacks;

public class nlr {
	static void print(String s) {
		System.err.println(s);
	}

	static void eval(Rengine r, String s) {
		r.eval(s, false);
	}

	public static void main(String[] args) {
		Scanner stdIn = new Scanner(System.in);
		String input;

		if (!Rengine.versionCheck()) {
			System.err.println("** Version mismatch - Java files don't match library version.");
			System.exit(1);
		}

		print("creating Rengine");

		Rengine re = new Rengine(args, false, null);

		print("Rengine created, waiting for R");

		// the engine creates R is a new thread, so we should wait until it's
		// ready
		if (!re.waitForR()) {
			System.out.println("Cannot load R");
			return;
		}

		print("R is running");
		print("Enter command for R (exit by typing \"end\")");
		while(true){
			  input = stdIn.nextLine();
			  if(input.equals("end"))
				  break;
			  System.out.println(re.eval(input));
		}
		
		print("stopping Rengine");

		re.end(); // stop the Rengine

		return;
	}
}
  • 入力例
data(iris)
pdf('/tmp/graph.pdf')
print(stripchart(iris[, 1:4], method = 'stack', pch = 16, cex = 0.4, offset = 0.6))
dev.off()
  • プロジェクトのフォルダにtmpフォルダが生成され、その中に、graph.pdfができる。

f:id:astrobot:20110628000026p:image

参考