how to code something このページをアンテナに追加 RSSフィード

2012-04-27

決定木を計算する - ID3アルゴリズムの実装

決定木には以下のようなアルゴリズムがある。
(1)ID3
(2)C4.5
(3)CART
(4)複数分岐ID3
http://www.ccn.yamanashi.ac.jp/~munehisa/kenkyuu/soturon_2004/suzuki.pdf
アルゴリズムの特徴は以下のとおり。

CARTは計算できる範囲のデータ数で, 時間を気にしないならば最も良いアルゴリズムである
・C4.5は短い計算時間である程度よいの正解率を導けるアルゴリズムである
・ID3は CART ほどの正解率は望めないが C4.5 よりはよい正解率を得られる

ということで、とりあえずID3を実装してみます。
ID3は、「最低限の仮説による事象の決定=オッカムの剃刀」を行うアルゴリズム
流れを以下に示す。

1.ルートノード N を作成して全ての例題集合を N に所属させる。
2.もしN に所属する例題が全て同じ決定 X を与えるなら N に X とラベル付けして処理を終了する。
3.例題集合 C に対する平均情報量を求める、即ち以下の式を計算する。

 M(C)=-¥sum_{x ¥subset D} p_x(C) log p_x(C)

4.C をある独立変数 ai の値に応じて分割する。
      ai が v1 … vm の m 通りの値を持つ場合は以下の通りに分割する。

 C_{ij}  ¥subset  C (a_i = v_j)

5.分割した Cij に応じて平均情報量を求める。

 M(C)=-¥sum_{x ¥subset D} p_x(C_{ij}) log p_x(C_{ij})

6.計算した平均情報量から独立変数 ai の平均情報量の期待値 Mi を以下の式で求める。

 Mi= M(C)  -¥sum_{j = 1}^m M(C_{ij}) ¥times ¥frac{¥mid C_{ij} ¥mid }{ ¥mid C¥mid}

7.Mi が最大となる独立変数を ak とする。
8.N のラベルを ak として、N の子ノード Nj を作成し、それぞれにCkj を所属させる。
9.それぞれの子ノードに対して N = Nj, C = Ckj として、2 以下の操作を再帰的に行う。

とりあえずwikipediaの例題を使う。
http://ja.wikipedia.org/wiki/ID3

f:id:seinzumtode:20120427155052p:image
この図で、食性、発生形態、体温から分類を決定するルールを作りたい。
とりあえずM1を計算するところまで書いた。
あとはイテレータを回すだけだが、少し考えないといけない。

#! /usr/bin/env python
# -*- coding: utf-8 -*-

import math
from collections import defaultdict

#data = ['名称','食性(a1)','発生形態(a2)','体温(a3)','分類']
data = [ ['penguin','carnivorous','oviparity','isothermal','bird'],
       ['lion','carnivorous','viviparity','isothermal','mammal'],
       ['cow','hervivory','viviparity','isothermal','mammal'],
       ['lizard','carnivorous','oviparity','poikilothermal','reptile'],
       ['java sparrow','hervivory','oviparity','poikilothermal','bird'] ]

d = defaultdict(int)
for dat in data:
    d[dat[-1]] += 1
px_c = {}
for key in d.keys():
    px_c['p_'+key+'(C)'] = d[key] / float(len(data))
#print px_c

MC = 0
for p in px_c.iteritems():
     MC += -1.0 * p[1] * math.log(p[1],3) 
print "MC = ",MC

c = defaultdict(int)
#for i in range(1,len(data[0])-1): #このへんの繰り返しをあとで考える
    #print i #a_iは1から3まで
for i in range(len(data)):  #a1(食性について)
    c[data[i][1]] += 1
#print c
px_c2 = defaultdict(list)
i = 0
for key in c.keys():
    e = defaultdict(int) #eを初期化
    for dat in data:
        if key in dat:
            e[dat[-1]] += 1
    px_c2[key].append(e)
#print px_c2
M = []
for p in px_c2.iteritems():
    Mi = 0
    for pp in p[1][0].iteritems():
        tmp = pp[1] / float(len(p[1][0]))
        Mi += -1.0 * tmp * math.log(tmp,3)
    M.append( (Mi, len(p[1][0])))
#print M
bunshi = 0
total = 0
for m in M:
    bunshi += m[0] * m[1]
    total += m[1]
M1 = MC - bunshi / (total*1.0)
print "M1 = ",M1

スパム対策のためのダミーです。もし見えても何も入力しないでください
ゲスト


画像認証

トラックバック - http://d.hatena.ne.jp/seinzumtode/20120427/1335508261
リンク元