Hatena::ブログ(Diary)

チョッキーの黒歴史

2009年12月31日

[][][]マハラノビス距離によるクラス分類

マハラノビス距離を計算することで入力データのクラスを分類するプログラムをRで実装してみた。

多クラスでしかも非線形に分類できました。


方法は超単純。

クラスの重心を求めて、入力データと重心のマハラノビス距離を求めればいいだけ。

距離が最も短いクラスに分類する。

多クラスへの拡張が超簡単。しかも非線形!!!

パーセプトロンの学習規則やらWidorow-Hoffの学習規則だと線形な分離面しか描けなかった。)


クラスcの重心を¥mu_{c}、共分散行列をS_{c}とすると、

入力データxとクラスcのマハラノビス距離は次の式から求められる。

D_{c}^2 = ¥{ (x-¥mu_{c}) S_{c}^{-1} (x-¥mu_{c}) ¥}

Rでの実装

#マハラノビス距離で入力データを分類する
# x:入力データ, data:学習データ, t:学習データのクラスベクトル
my.mahalanobis <- function(x,data,t){
  tmax <- max(t)
  tmin <- min(t)
  if( tmin != 1 ){
   cat("クラスは1から始まる1以上の整数にすること。")
   return (0)
  }
  d.list <- c() # 各クラスとのマハラノビス距離を格納
  for( c in tmin:tmax ){          # 工夫次第ではこのfor文は消せる?
    data.c <- data[t==c,]         # クラスcのデータのみ抽出
    mean.c <- apply(data.c,2,mean)# 平均ベクトルを求める
    var.c <- var(data.c)          # 共分散行列を求める
    inv.c <- solve(var.c)         # 逆行列を求める
    # 各クラスとのマハラノビス距離を求める( sqrt は省略 )
    x.mean <- x - mean.c          # 平均ベクトルとの差
    dis <- ( x.mean %*% inv.c %*% x.mean )
    d.list<- c(d.list,dis[1,1])   # 行列になっているのでスカラーに変更して格納
  }
  return ( order(d.list)[1] )
}

Rには関数mahalanobis(x,mu,S)があるけどこれ使うと楽すぎるのでしない。

ただし、逆行列はsolve(S)で求めます。

逆行列を自力で求めるのはちょっとめんどい。

分類したい入力データ数が少ないときは上のでもいいけど、

たくさんある場合は毎回平均ベクトルと分散求めているのがうざい。

ってことで、各クラスの平均ベクトルと共分散行列の逆行列を格納したリストをモデルとして

保存するようにする。

#判別用のモデルを返す
# data:学習データ, t:学習データのクラスベクトル
md.makemodel <- function(data,t){
  tmax <- max(t)
  tmin <- min(t)
  if( tmin != 1 ){
   cat("クラスは1から始まる1以上の整数にすること。")
   return (0)
  }
  mean.list <- list() # 各クラスの平均ベクトルを格納
  inv.list <- list()  # 各クラスの共分散行列の逆行列を格納
  for( c in tmin:tmax ){          # 工夫次第ではこのfor文は消せる?
    data.c <- data[t==c,]         # クラスcのデータのみ抽出
    mean.c <- apply(data.c,2,mean)# 平均ベクトルを求める
    var.c <- var(data.c)          # 共分散行列を求める
    inv.c <- solve(var.c)         # 逆行列を求める
    mean.list[[c]] <- mean.c
    inv.list[[c]] <- inv.c
  }
  return ( list(mean.list,inv.list) )
}
#判別用のモデルを用いてクラス分類
# x:入力データ, modle:md.makemodelで作ったモデル
md.predict <- function(x,model){
  mean.list <- model[[1]] # 各クラスの平均ベクトルを格納したリスト
  inv.list <- model[[2]]  # 各クラスの共分散行列の逆行列を格納したリスト
  cls <- length(mean.list)# クラスの数
  d.list <- numeric(cls)  # 各クラスとのマハラノビス距離を格納
  # 各クラスとのマハラノビス距離を求める( sqrt は省略 )
  for( c in 1:cls ){
    x.mean <- x - mean.list[[c]]
    dis <- ( x.mean %*% inv.list[[c]] %*% x.mean )
    d.list[c] <- dis[1,1]
  }
  return ( order(d.list)[1] )
}
#プロット関数 2次元2クラス用
# model:学習結果, data:学習データ, t:学習データのクラスベクトル
myMD.plot <- function(model,data,t){
  a <- 200 # プロットデータの個数,大きくすると時間がかかる
  tx <- seq(0,9,length=a)
  ty <- tx 
  rapMyMD <- function(x,y){
    return ( md.predict(c(x,y),model) )
  }
  z <- matrix(0,length(ty),length(tx))
  for( y in 1:length(tx) ){
    for( x in 1:length(ty) ){
      z[x,y] <- rapMyMD(tx[x],ty[y])
    }
  }
  # 分類結果の領域を描画
  image(tx,ty,z,col=c("#00FF80","#FFBF00"),axes=TRUE,xlab="x",ylab="y")
  # 学習データ点をプロット
  points( data[,1], data[,2], col=ifelse(t==1,"red","blue"), pch=ifelse(t==1,16,17) )
  title("myMD:[testData0.csv]")
  return (z)
}
#プロット関数 2次元3クラス用
# model:学習結果, data:学習データ, t:学習データのクラスベクトル
myMD.plot2 <- function(model,data,t){
  a <- 200 # プロットデータの個数,大きくすると時間がかかる
  tx <- seq(0,1,length=a)
  ty <- tx 
  rapMyMD <- function(x,y){
    return ( md.predict(c(x,y),model) )
  }
  z <- matrix(0,length(ty),length(tx))
  for( y in 1:length(tx) ){
    for( x in 1:length(ty) ){
      z[x,y] <- rapMyMD(tx[x],ty[y])
    }
  }
  # 分類結果の領域を描画
  image(tx,ty,z,col=c("#FFBF00","#00FF80",8),axes=TRUE,xlab="x",ylab="y")
  # 学習データ点をプロット
  points( data[,1], data[,2], col=ifelse(t==1,"blue",ifelse(t==2,"red",6)), pch=ifelse(t==1,17,ifelse(t==2,16,15)) )
  title("myMD:[testData3.csv]")
  return (z)
}

まずは、いつも通り「testData0.csv」で試す。

data.src <- read.csv("testData0.csv")
data <- data.src[,1:2]
t <- data.src[,3]
model <- md.makemodel(data,t)
z <- myMD.plot(model,data,t)

実行するとこんな感じのプロット図ができた。

f:id:tyokyNN:20091231143902j:image

非線形な分離面キタ━━━━(゚∀゚)━━━━ !!!!!

多クラス分類もいけるはず!

次のデータ「testData3.csv」で試す!!(2次元3クラス)

x,y,class

0.05,0.682,1

0.166,0.682,1

0.126,0.78,1

0.286,0.764,1

0.25,0.622,1

0.056,0.596,1

0.392,0.726,1

0.18,0.618,1

0.32,0.7,1

0.208,0.744,1

0.444,0.1,2

0.288,0.292,2

0.658,0.466,2

0.774,0.194,2

0.458,0.232,2

0.852,0.322,2

0.532,0.462,2

0.622,0.274,2

0.722,0.39,2

0.452,0.36,2

0.786,0.486,3

0.59,0.652,3

0.378,0.478,3

0.41,0.636,3

0.712,0.606,3

0.514,0.518,3

0.53,0.716,3

0.752,0.716,3

0.538,0.848,3

0.862,0.618,3

data.src <- read.csv("testData3.csv")
data <- data.src[,1:2]
t <- data.src[,3]
model <- md.makemodel(data,t)
z <- myMD.plot2(model,data,t)

実行するとこんなのが出た。

f:id:tyokyNN:20091231144052j:image

多クラス非線形キタ━━━━(゚∀゚)━━━━ !!!!!

スパム対策のためのダミーです。もし見えても何も入力しないでください
ゲスト


画像認証

トラックバック - http://d.hatena.ne.jp/tyokyNN/20091231/1262238368