Hatena::ブログ(Diary)

驚異のアニヲタ社会復帰への道

Prima Project

2017-02-04

薬の副作用の有無をAIが100% 的中

読んだ。

Machine learning-based prediction of adverse drug effects: an example of seizure-inducing compounds

プレスリリース

 

COI:なし

薬学も電気生理学情報科学も何ひとつ専門ではないけど、100%予測に釣られたので書く。

 

seizure (けいれん)を引き起こす副作用のある薬物を、Caffe を用いたdeep learning とSVM による機械学習で100% 予測しました、というweb ニュースがでているが、実際には、論文でけいれんが報告されているジフェンドラミン、エノキサシン、ストリキニーネ、テオフィリンの4剤とけいれんを絶対に起こすピクロトキシン1剤、けいれんの副作用はないということになっているアスピリン、シメチジン、デキストラン、ジアゼパムイブプロフェン、イミプラミン、ケタミンメタンフェタミン、オセルタミビルの9剤を判別するタスクをしている。

 

ここで、Mg の話だが、Mg細胞膜の電気的活動の安定性に関わっており、脳脊髄液(CSF) では1.2nM らしいが、0.1nM まで下げると勝手にseizure-like event SLE が起きやすくなるらしい。以下の実験ではMg 濃度は0.1nM でやっており、SLE は起きやすい状況になっているようである(Fig.1)。

 

クロトキシンの話だが、GABAという物質があって、これは神経接合部での神経活動の惹起を抑える働きがあるが、ピクロトキシンはこのGABAの働きを抑えることで、結果として興奮が起きやすくなる、ということで、ピクロトキシンの濃度を上げていくと確かにSLE が生じてけいれんが起きるということになっている(Fig.2)。

 

マウスから海馬を摘出し、上記の薬剤の臨床的な血中濃度から最大30倍くらいのオーダーの濃度まで5段階で適宜、脳脊髄液(CSF)還流液を調整して、SLE の信号が測定されるか観測している(Table.1)。その信号は、ピークとして0.1uV/ms を超えた前後 -200 から2070ms までを切り出して、227x227 ピクセルの画像にして、caffe に投げるということらしい。ということで、入力データは5濃度、14薬剤の70画像である。

 

ひとまずSLE の起きる様子をベタにプロットすると、けいれんを起こすと言われている次フェンドラミン、エノキサシン、ストリキニーネ、テオフィリンは濃度依存性にSLE の発生が多かった(Fig.3)。ケタミンメタンフェタミン、オセルタミビルでは電気活動があるように見えるけれども、発火がオシレーション様ではないのでSLE とは言わないらしい。これは専門でないのでよくわからん。

 

肝心の判定性能だが、総計70枚の、227x227ピクセルの画像を、2012年のLarge Scale Visual Recognition Challenge で事前に学習された6層判別器に入力している。結果として4096の長さの特徴量ベクトル(70*4096 行列)ができて、それをPCA することでごしゃっとデータをまとめている。PC1(37%)とPC2(21%)の軸だけ取り出し、SVM (線形)をやると、けいれんを起こす薬剤で、SLEを引き起こす濃度(ここがよくわからなかったがFig.4)の空間が線形分離できる、これが判別性能100% と言っている。

一応、LOOCV で1:13 の交差検証をやっても100% だし、けいれんを起こすことがわかっているイソニアジドとメトクロプラミドを判別しても、陽性判定だったので、外的妥当性も検証している、と言っていそう。

PCA空間をみると、けいれんを起こす薬剤のなかでジフェンドラミンだけが別クラスターになっており、これはSLE を見ると下がって上がるパターン(Fig.4)であり、ほかの4剤の上がって下がるパターンとは異なるので、おそらく薬理学的な違いが出ていて、これを考えるうえでも役に立ちそう、と言っている。

 

感じたこととして、ド素人だけど情報科学生物学医学の見地からまとめてみる。

 

情報科学の見地からは、まず、時刻t で観測され続けているであろうSLE の時系列データE(t) を、画像にしてしまうことにいったいどれほどの意味があるのだろうか。現象的には、ベクター画像をラスター画像にして解析しているような感覚だが、それでもけいれんの有無を判別できるというのは、それはそれでこのやり方はできないことはないのだろう。なんとなく、音声信号をスペクトログラムにしてdeep learning にぶちこめば例えば声帯結節あり/なしがわかる、みたいなゴリ押し感はある。

ただし、これは「見た目が違っていて、なんとなく違いが判別できる」という定性的な概念になってしまうので、できれば、定量的にSLE のスパイクの形成具合のどことどことで数値的にどうなったらこれはてんかん、と言いたい感じはある。また、事前に学習したモデルって、背景とか馬とかだと思うけど、SLE シグナルみたいな「ただの線」にその学習モデルを使うというのは許容されているアルゴリズムなんだろうか。

 

入力が70ってどうなんですかね。また、caffe で4096の長さの特徴量ベクトルを作ったならそのままdeep learning で判別器作ればいいと思うし、PCA で次元を落として、って、それならば画像データにして無駄に次元を変えたりデータ量を変えたりして削減するという流れがいまいちよくわからなかった。有名な論文みたいだけど、わざわざ問題を難しく解かなくても…という気はする。プレリミ的には全次元で判別したらSVM でうまく分離できなかったらしい。PCA でノイズが減るのは減ると思うけど、難しいことをしすぎではないか。

「けいれんを起こす薬剤」という判定基準が雑な気がする。ジフェンドラミン、エノキサシン、ストリキニーネ、テオフィリンは論文ベースで探したようだが、例えばケタミンでも報告はあるし、何を基準に論文サーベイしたのかわからなかった。普通に考えて、精神に作用するメタンフェタミンがけいれんを起こさない薬物に最初から分類されているのが(薬学は専門でないので)よくわからなかったし、ケタミンメタンフェタミンの2剤がSLE を起こしていないというのも(生理学は専門でないので)よくわからなかった。究極的に意味がわからなかったのが、イミプラミンは0.3-0.6% で臨床的にけいれんを起こす、と書いてあるのに、SLE を起こさなかったのでけいれんを起こさない薬剤と判定していることだった。

けいれんを起こすか、Yes/No の二択での判定の限界というか別にいい手法はあるのか、というのは議論で指摘されていた。

 

生物学的な見地からは、おそらく、どんな薬物でも高濃度にすれば、けいれんというか細胞の異常発火は引き起こしうる。今回の濃度設定は、ヒトの臨床的な血中濃度の報告を元に、それをいい感じで含むように濃度調整をして、マウスで実験している。設定した濃度の範囲外でSLE が生じる可能性があることは議論でも述べられているが、この手の毒性/副作用予測論文ではやはり、ヒトとモデル生物間での飛躍がどうしても避けられない。例えば細胞レベルやマウスレベルのでの毒性試験がヒトでも毒性があるかというとそうでもないことは多いし、「けいれん予測」が「その薬剤を内服したヒトで起こるか」ではなく、「マウス海馬を溶液にひたしてSLE が起こるか」ということの予測であることは注意しておくべきだろう。

あと、マウス海馬を切って浸してSLE を測定するというこの実験系、ハイスループットにできるのか謎だった。

 

医学的な見地からは、メトクロプラミドがけいれんを起こす薬剤の外的妥当性の検証として使われているが、この薬、けいれんを起こすと判定されたから危ない薬かというと、死ぬほど処方されているプリンペラン(吐き気止め)である。添付文書を読んでみるとたしかにけいれんは書いてあるのだが、神経学的副作用には頻度すら書いていない。では世の中の医者がけいれんのリスクを認識してメトクロプラミドを処方しているかというと、たぶんしていない。クスリはリスクというくらいではあるが、リスクベネフィットを勘案してクスリは処方されるわけで、元論文をみてもoverdose の例らしいので、このあたり、情報科学的見地でも述べたように、陽性例の設定基準があやふやと言わざるをえない。

 

(元論文を見たが、「イソニアジド(抗結核薬)とメトクロプラミドは両者ともけいれんを起こしうると報告されている」と引用されていたのに、引用にはイソニアジドのことしか書かれていなくてメトクロプラミドをgrep したけど書いていなかったようである、読んでいないのでもしかしたら別名で書いてあるのかもしれない。また、メトクロプラミドの濃度を決めたという論文も抄録ではメトクロプラミドがけいれんを引き起こすことは書いていない。本文は知らない)

(やっぱり気になって元論文を読んだが、17歳男性が失恋時に30錠くらいのイソニアジドを飲んだ時の管理という症例報告で、イソニアジドの肝機能障害は有名だけどけいれんや昏睡などの精神症状に気をつけてね、という話で、補助療法で活性炭など使うといいよとあったが消化管症状など読んでもメトクロプラミドは出てきていなかったし、メトクロプラミドとけいれんの話はどこに書いてあるのかわからなかった。)

 

なのでここでも、「in virto でAI がけいれんが起こると判定する」ことと、「実臨床でけいれんが起こる」ことがびっくりするくらい乖離している。また、「けいれんが起こった薬剤」というのは、因果関係がわからないけれども、「少なくとも投与した後にけいれんが生じた」場合には、珍しい/重症な副作用は報告する義務があるので、けいれんを起こすという薬剤が本当にけいれんを起こしているのかは実臨床レベルでは不明な事が多い(と思う)。

また、薬物単体で副作用を起こすだけではなく、相互作用で起きる場合や、上の例でいれば「相互作用で起きた、もしくは犯人でもないのに巻き添えで犯人扱いされている」場合もないことはないと思われるので、けいれんを起こす/起こさない薬剤の選定は非常に難しい、はず。

 

という感じで、モヤモヤ感がはんぱなかった。

 

見かけたツイート

前者はそうだけど、後者は必ずしも(現実問題には)当てはまりません。例えば正規分布は、取りうる値は(-¥infty,¥infty)であるが、現実的には99.9% くらいの区間で適当な有限値になるはず。仮に、「新生児と20歳の男性を身長で判別する」問題を考えるとき、新生児はたかだか平均50cm に適当な分散なので、20歳男性の身長からは100% 判別されるはず。もちろん、20歳男性が例えば低形成症だとしても、普通にサンプリングすれば平均170cm に適当な分散になるので、新生児は実データ上は完全に分離される。もちろん、170cm と入力すべきところを1.7m と勘違いして1.7 と入力したり、本当にすごい症例報告レベル(確率10^{-1000}くらいな)の身長がいたりすれば別だけど。ということで、今回のけいれんで言えば、「けいれん薬剤のSLE は、けいれんでない薬剤のSLE とはもう本当に違いうる」ということがあれば、100% 判別ニキでもおかしくはない。これは結局安全域が狭い薬剤なのか広い薬剤なのかという話だと思うけど。

ただ、けいれん有無の薬剤の選び方とか、SLE の判定とかよくわからんことが多いので、交差検証とかしてモデルの妥当性を解析していても、モヤモヤ感は残る。

 

既存の薬剤で、副作用としてけいれんがある、と言われているものを判別しています。ただ、2つだけで、どちらも臨床的には普通に使われている薬剤なので、何をもってけいれんの副作用を言っているのかはちょっとよくわかりませんでした。

2017-01-21

機械学習プロフェッショナルシリーズを何冊か読んだ

読んだ。

関係データ学習 (機械学習プロフェッショナルシリーズ)

関係データ学習 (機械学習プロフェッショナルシリーズ)

ラボにあった。

関係データということでテンソル分解みたいなことから始まるのかなと思ったら、グラフ関係のスペクトルクラスタリングから始まってて、ふんふんと読んでいた。

中盤ではinfinite ralational model推しだった。Chinese restaurant process を利用して、もともとある卓に着席するか、新たに卓を設けるかを最適化すると、クラスター数が勝手に定まるのでkmeans のようにユーザーがクラスター数を決めるのに悩まくていいでしょ? って感じが売りだった。

後半はテンソルの話だった。

関係データ解析とInfinite Relational Model - Qiita

PDF

 

ラボにあった。

正則化といった流行りのスパース性の話だった。冒頭ではBias-variance 分解について述べてモデルの汎化性能をきちんとしましょう的なことが書いてあったし、序盤では正則化によって解の範囲がどう制限されるか、そしてL1やL2 で全体の性能がどうなるかとか、0 に落ち着くパラメータがどれくらいになるかといったシミュレーションの結果が出ていた。

中盤では正則化問題をどう最適化するかを数式でゴリゴリしていたのでちょっとついていっていないが、グループ正則化、trace 正則化、アトミック正則化ロバストPCA など知らなかった手法が紹介されていたので勉強になった。

生物医学界隈でもデータが大きくなって、全遺伝子はデータとったけどサンプル数が足りない、いわゆるp>>n 問題になるので、

p167:次元dがサンプル数nよりもずっと大きい学習/推定問題を考えている

から頭の片隅に置いとくべきである。また、GWAS みたいに多因子疾患を考えるとき、いくつかのSNP がちょっとずつ関与してなんとなくもやっと

p167:予測性能だけでなく、なぜ予測できるのかを説明できることが重要である

という感じで使いそう。

 

変分ベイズ学習 (機械学習プロフェッショナルシリーズ)

変分ベイズ学習 (機械学習プロフェッショナルシリーズ)

数式が追い切れなくて消化不良。

2017-01-19

MikuHatsune2017-01-19

性なる夜に

こんなツイートがあった。

日本人の生年月日の数から順位付けしたらしい。

この画像をみたときに、出生数の月日分布なんてデータとしてあるの? と思っていたら、estat の人口動態データとしてあった。

このうち、出生数,出生年月日時・出生の場所別 というデータが、月日時刻別に集計しているのでこれを整形する。なぜか1995年がPDF だったり、csv も全角数値が入っていたり行数が年度ごとに揃っていなかったりとネ申エクセルではないけどデータ管理の下手くそさを感じさせるが、せっかくのデータを公開してくれているので文句は言わない(言うけど。

これをPython で整形して、2015年から1996年まで、1月1日から12月31日までデータを取ってくる。

# ターミナルで文字コード変換
for i in `ls *csv`; do nkf -w --overwrite  $i; done

# Python
import os
import glob
import re
import codecs
input_wd = "/birth/"

files = glob.glob(input_wd + "*csv")
digit = re.compile("\d+")

month = re.compile(".月")
month = "月"

w0 = open("birth_estat.csv", "w")
for f in files:
  g = open(f, "rU")
  res = []
  hoge = ""
  while not (month in hoge and "総" in hoge):
    hoge = g.readline()
    print hoge
  while "病" not in hoge:
    hoge = g.readline()
    if "日" in hoge:
      res += [ hoge.rstrip().split(",")[1] ]
  year = digit.findall(f)
  dat = year + res
  text = ",".join(dat) + "\n"
  w0.write(text)

w0.close()

20年の間に年間出生数は年々減ってきているため、比率にして全体を揃える。

365日間の、全出生数に対する各日の出生比率だが、下向きのスパイクがいくつかある。これはツイートでも言われているように、ハッピーマンデーに寄らない祝日である。各年度は1周間単位で振動しているが、これは週末に出生数が少ないことによるもので、これを20年分重ねれば週末分は消える。これでも残る成分は、例えば2月11日だったり、4月1日だったり、年末正月、GW、お盆である。

f:id:MikuHatsune:20170118221444p:image

各年度の出生比率分布をカレンダープロットすると、国民休日や、特に週末日曜日に出生比率が低いことがわかる。出生は自然に生まれるパターンと、緊急で帝王切開で生まれるパターンと、予定手術の帝王切開で生まれるパターンがあるが、近年は帝王切開が増えていることと、帝王切開を予定手術でやるならば、普通は、人手の多い平日にやろうというのが普通なので、祝日週末は避ける。

4月1日も学年の変わり目になるので普通は避けるだろう。

f:id:MikuHatsune:20170118221926g:image

 

さて、クリスマスになると、声優監視スレがあるように、クリスマスに何が行われているかというと、ナニである。ここで、出生日と出生までの妊娠週数データがあるので、これからナニの日を推定したい。いま、妊娠週数は厚生労働省のデータを取ってきて、満N週になっているので線形に補間して日数にする。

38-39週が最も多いが、妊娠最小限界の22週から、過産期の43週までこんな感じになる。

厳密には、妊娠週数の数え方は最終月経からということになっているが、生物学的な胎児の週数は受精してからになるので、2週間のずれが生じる。ここでは、受精とナニは同時にあったとして2週間引いている。

f:id:MikuHatsune:20170118222359p:image

 

出生月日から単純に上の週数をランダムサンプリング(ただし-2週)して、生まれた児の受精日を逆算したら、それつまりナニの日ということをやってみる。20年間の出生数データがあるので、最初と最後以外の18年分をやってみる。

結果としては、年末年始が最も多いようだった。6-8月の夏は少なめ。意外と開放感がないが、5月のGW っぽいところでヒト山あるので、連休中はアレなのだろう。

f:id:MikuHatsune:20170118222917p:image

 

同じようにカレンダープロットを作った。出生の生データでは週末や祝日の影響があったが、妊娠週数の分布データからサンプリングするとこれは消えて、連続する日はほとんど滑らかになる。

f:id:MikuHatsune:20170118222916g:image

 

結局のところ、妊娠38週(胎児は36週)が最も頻度が高いので、9月に出生数が多いことから単純に36週引いたら年末年始になる、という簡単な結果になった。年末年始はもちろんクリスマスを含むので、まあアレだね!

 

本当ならば、出生以外にも、残念ながら死産人工中絶というものがあるので、本来の受精数の365日ベクトル¥theta というものを念頭に置いて、死産中絶の分布を考慮して真のナニ分布というものを推定したかった。特に中絶については、この20年で40-20万件あり、出生数に対して20% 程度を占めることから、真の分布に及ぼす影響は小さくないだろうと思われたが、件数と週分布はあっても月日分布はさすがに自治体、病院レベルでも見当たらなかったので、出産によるデータのみによって推定せざるを得なかった。

あとは、出生120万/年くらいとして、図では最大と最小の差でも0.0004 くらいしかないが、これは400-500くらいの差でしかない。365日均一だとしても1/365=0.0027 なので、絶対的な差としてはそんなにない。

 

結論:命は愛おしい

 

# 妊娠週数
bw <- 22:43
bwn <- c(38, 68, 106, 130, 142, 196, 208, 259, 265, 268, 351, 444, 620, 767, 1399, 4103, 10002, 16123, 14954, 7324, 798, 29)
pbw <-bwn/sum(bwn) 

# 妊娠日数
pbd <- rep(pbw, each=7) + rep((c(pbw[-1], 0) - pbw)/7, each=7)*(0:6)
pbd <- pbd/sum(pbd)
bd <- c(t(outer(bw*7, 0:6, "+")))
plot(bd, pbd, type="o", pch=16, lwd=2, xlab="Birth day", ylab="Probability")
abline(v=bw*7, lty=3)
text((c(bw[-1], max(bw)+1)*7+bw*7)/2, par()$usr[4], bw, xpd=TRUE, pos=3)

# estat を処理したデータ
dat <- read.csv("birth_estat.csv", header=FALSE, row=1)
dat <- t(dat)
dat <- dat[!apply(dat == "-", 1, all),]
dat[dat == "-"] <- NA
dat <- apply(dat, 2, as.numeric)
days <- c(31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31)
md <- unlist(mapply(rep, 1:12, days))
dd <- unlist(mapply(seq, days))
ymd <- t(outer(colnames(dat), paste(md, dd, sep="-"), paste, sep="-"))

# うるう年の処理
ymdate <- mapply(function(z) as.Date(ymd[!is.na(dat[,z]), z]), 1:ncol(ymd))
datlist <- apply(dat, 2, na.omit)

cols <- rainbow(ncol(dat))
datmedian <- sweep(dat, 2, colSums(dat, na.rm=TRUE), "/")
matplot(datmedian, type="l", xaxt="n", lty=1, lwd=0.6, col=cols, xlab="Month", ylab="Median birth", xlim=c(1, nrow(datmedian)+50))
lines(apply(datmedian, 1, median, na.rm=TRUE), lwd=3)
month <- c(1, cumsum(days))
abline(v=month, h=1/365, lty=3)
text((head(month, -1)+tail(month, -1))/2, par()$usr[3], 1:12, xpd=TRUE, pos=1)
legend("right", legend=colnames(dat), col=cols, lty=1, lwd=3, bg="white", bty="o", box.col=NA)

for(i in seq(ymdate)){
  plotdat <- data.frame(date=ymdate[[i]], Birth=datlist[[i]]/sum(datlist[[i]]))
  calendarPlot(plotdat, pollutant="Birth", year=names(datlist)[i], cols="jet")
}

# list を vector にする
alldate <- as.Date(unlist(ymdate), origin="1970-1-1")
allbirth <- unlist(datlist)
years <- as.numeric(unique(format(alldate, "%Y")))

# 妊娠日数からランダムサンプリングしてナニの日を推定するシミュレーションを1回行う関数
sim <- function(){
  res <- vector("list", length(alldate))
  for(i in length(alldate):360){
    diffday <- sample(bd-2*7, size=allbirth[i], prob=pbd, replace=TRUE)
    res[[i]] <- alldate[i] - diffday
  }
  res1 <- as.Date(unlist(res), origin="1970-1-1")
  tab <- table(res1)
  d2 <- as.Date(names(tab))
  tab <- as.numeric(tab)
  names(tab) <- NULL
  dat2 <- data.frame(date=d2, birth=tab)
  years <- as.numeric(unique(format(alldate, "%Y")))
  # カレンダープロット用に年度ごとにまとめる
  dat3 <- mapply(function(y){
    idx <- d2 >= as.Date(paste0(y, "-1-1")) & d2 <= as.Date(paste0(y, "-12-31"))
    dat22 <- dat2[idx,]
  }, years[2:(length(years)-1)], SIMPLIFY=FALSE)
  return(dat3)
}

# 1回あたりにかなり時間がかかる
iter <- 50
result <- vector("list", iter)
for(i in seq(iter)){
  result[[i]] <- sim()
}
result_date <- mapply(function(z) z$date, result[[1]])
result_birth <- mapply(function(i) mapply(function(z) z[[i]]$birth/sum(z[[i]]$birth), result), 1:(length(years)-2), SIMPLIFY=FALSE)

# うるう年の処理
tmp <- lapply(result_birth, apply, 1, median)
for(i in seq(tmp)){
  if(length(tmp[[i]]) < 366){
    tmp[[i]] <- c(tmp[[i]][1:59], NA, tmp[[i]][60:length(tmp[[i]])])
  }
}

for(i in seq(tmp)){
  if(length(tmp[[i]]) < 366){
    tmp[[i]] <- c(tmp[[i]][1:59], NA, tmp[[i]][60:length(tmp[[i]])])
  }
}
tmp <- do.call(cbind, tmp)

cols <- rainbow(ncol(tmp))
matplot(tmp, type="l", xaxt="n", lty=1, col=cols, xlab="Month", ylab="No. estimated fertilization", xlim=c(1, nrow(tmp)+50))
month <- c(1, cumsum(days))
abline(v=month, h=1/365, lty=3)
text((head(month, -1)+tail(month, -1))/2, par()$usr[3], 1:12, xpd=TRUE, pos=1)
legend("right", legend=years[2:(length(years)-1)], col=cols, lty=1, lwd=3, bg="white", bty="o", box.col=NA)

for(i in 1:ncol(tmp)){
  plotdat <- data.frame(date=result_date[[i]], Fertilization=apply(result_birth[[i]], 1, median))
  calendarPlot(plotdat, pollutant="Fertilization", year=years[i+1], cols="jet")
}

2017-01-18

指数分布族

変数x, パラメータ¥theta があるとき、

¥begin{align}f(x|¥theta)&=h(x)g(¥theta)¥exp¥{¥eta(¥theta)T(x)¥}¥¥&=h(x)¥exp¥{¥eta(¥theta)T(x)-A(¥theta)¥}¥¥&=¥exp¥{¥eta(¥theta)T(x)-A(¥theta)+B(x)¥}¥end{align}

と表すことができる場合は、指数分布族という。ここで、exp やlog で無理やりカッコ内に指数法則を使って変換してやれば、各々の式は同じことを意味している。

ここで、

T(x):十分統計量 sufficient statistics

¥eta(¥theta):natural parameter

指数分布族のwiki にはどんな分布が指数分布族で、統計量がどういうものかが書いてある。

 

ここでは、ベルヌーイ分布が指数分布族である例を出していて、

¥begin{align}p(x|¥theta)&=¥theta^x (1-¥theta)^{1-x}¥¥&=¥exp¥{x¥log¥theta+(1-¥theta)¥log(1-¥theta)¥}¥¥&=(1-¥theta)¥exp¥{x¥log¥frac{¥theta}{1-¥theta}¥}¥¥&=¥exp¥{x¥log¥frac{¥theta}{1-¥theta}-(-¥log(1-¥theta))¥}¥end{align}

となるので、

T(x)=x

¥eta(¥theta)=¥log¥frac{¥theta}{1-¥theta}

A(¥theta)=-¥log(1-¥theta)

である。

十分統計T(x) の期待値や分散を求めることができて、こちらこちら(PDF)を使って例えばベルヌーイ分布の平均は

¥begin{align}E(T(x))&=¥frac{A^{’}(¥theta)}{¥eta^{’}(¥theta)}¥¥&=¥frac{¥frac{1}{1-¥theta}}{¥frac{1}{¥theta}+¥frac{1}{1-¥theta}}¥¥&=¥theta¥end{align}

となって¥theta であることが確認できる。' は微分である。

 

情報幾何の先生といろいろやっていて、分布をlog で変換した時にたぶんこれが有用ようなという雰囲気。

2017-01-14

MikuHatsune2017-01-14

Bias-Variance decomposition

機械学習などで予測モデルを立てたときに絶対に突っ込まれるのが「それ他のデータセットでも言えんの?」ということで、モデルの性能を評価しないといけない。一般的には手持ちのデータからモデルを作って、それとは別のデータに対してそのモデルが有用かという汎化性能をみないといけない。しかしながら、モデルだけ作って丸投げとか、モデルを作ったデータに再度モデルを適応して性能を評価したとか、そういうのがほとんど散見される。

 

Bias-Variance decomposition (日本語)によって、あるモデルに寄る誤差は

¥begin{align}E¥[(Y-¥hat{f}(x))^2¥]&=¥sigma^2 + (E_T¥[¥hat{f}(x)¥]-f(x))^2 + E_T¥[(¥hat{f}(x)-E_T¥[¥hat{f}(x)¥])^2¥]¥¥&=¥textrm{irreducible}~¥textrm{error} + ¥textrm{Bias}^2 + ¥textrm{Variance}¥end{align}

とかける。このとき、bias というのはモデルと訓練データから計算できる、モデルの説明具合で、variance というのはも出るとテストデータから計算できる、期待値まわりの分散である。

パラメータ1個のような単純なモデルだと、そもそも手持ちの訓練データを表現することは(たいていの場合)不可能であるが、新しいデータが入ってきても「まあそんなもんだよね」という程度の性能で新しいデータを表現できる。すなわち、bias が大きくてvarianve が小さい、という状況で、underfitting というらしい。

一方で、パラメータが超多いようなモデルでは、手持ちの訓練データをほぼ完璧に表現できる(n個の未知変数はn個の連立方程式が必要というアレ)一方で、手持ちの訓練データに(のみ)ガチガチに当てはまっているから、新しいデータが入力されたときにものすごい外れた値が出力される。すなわち、bias が小さくてvariance が大きい、という状況で、これが巷でよくきくoverfitting である。

 

Bias-Variance decomposition や、それを図示するvalidation curve というのは、機械学習に強い(とよく言われる)Python では実装されているようで、よく見かける。しかし、R では素人でも一発でできるような実装はざっと調べたかぎりではないようで、R-blogger でも2016年の記事があった程度なので、写経してやっておく。

 

いま、多項式で適当に発生させたデータがあるとする。真のモデルは3次式(切片含めて自由度4)で、ランダムエラー¥epsilon を入れて

y=X¥beta+¥epsilon

でデータが発生している。ここで、X は1次元から30次元までの多項式で、x は[0,100] で適当に発生させたものである。¥beta多項式の係数。

f:id:MikuHatsune:20170114111759p:image

The Elements of Statistical Learning の図7.1 では、bias とvariance のシュミレーションがあるので、これを再現してみる。原著はPDF があるのでやる気のある人は参照してほしい。

 

シミュレーションでは50個のデータ点(訓練データ)でモデルを作り、10000個の新規データ(テストデータ)での性能を100回検証している。

灰色は訓練データへのbias である。横軸は多項式の次数の多さで、みぎに行くほど複雑なモデルになる。そのため、訓練データをほとんど完璧に記述できる(エラーが小さくなる)ようになる。これだけみると、モデルの次元は大きく複雑なほうがいいように思える。

赤色はテストデータへの当てはまりである。単純なモデルから次数を増やしてくと、あるところで最小値をとり、その後エラーが増大するようになる。オレンジの点はテストエラーが最小になるときの次数であり、このとき5だったようである。

 

このデータがそもそもどのように出来たかは、我々は(Rコンソール内では)神なので、4次元だと知っている。4次元には近いがちょっと違う。ただし、それは我々が神だから知っているだけの話なので、現実的には最小となればいいだろうと思って採用する。

f:id:MikuHatsune:20170114111801p:image

 

R-blogger のほうではクロスバリデーションを使ったhyperparameter の選択とか書いてある。

# Generate the training and test samples
seed <- 1809
set.seed(seed)
gen_data <- function(n, beta, sigma_eps) {
    eps <- rnorm(n, 0, sigma_eps)
    x <- sort(runif(n, 0, 100))
    X <- cbind(1, poly(x, degree = (length(beta) - 1), raw = TRUE))
    y <- as.numeric(X %*% beta + eps)
    return(data.frame(x = x, y = y))
}

# Fit the models
require(splines)
n_rep <- 100
n_df <- 30
df <- 1:n_df
beta <- c(5, -0.1, 0.004, -3e-05)
n_train <- 50
n_test <- 10000
sigma_eps <- 0.5

xy <- res <- list()
xy_test <- gen_data(n_test, beta, sigma_eps)
for (i in 1:n_rep) {
    xy[[i]] <- gen_data(n_train, beta, sigma_eps)
    x <- xy[[i]][, "x"]
    y <- xy[[i]][, "y"]
    res[[i]] <- apply(t(df), 2, function(degf) lm(y ~ ns(x, df = degf)))
}

# Plot the data
x <- xy[[1]]$x
X <- cbind(1, poly(x, degree = (length(beta) - 1), raw = TRUE))
y <- xy[[1]]$y
plot(y ~ x, col = "gray", lwd = 2, pch=16)
lines(x, X %*% beta, lwd = 3, col = "black")
lines(x, fitted(res[[1]][[1]]), lwd = 3, col = "palegreen3")
lines(x, fitted(res[[1]][[4]]), lwd = 3, col = "darkorange")
lines(x, fitted(res[[1]][[25]]), lwd = 3, col = "steelblue")
legend(x = "topleft", legend = c("True function", "Linear fit (df = 1)", "Best model (df = 4)", 
    "Overfitted model (df = 25)"), lwd = rep(3, 4), col = c("black", "palegreen3", 
    "darkorange", "steelblue"), text.width = 32, cex = 0.85)

# Compute the training and test errors for each model
pred <- list()
mse <- te <- matrix(NA, nrow = n_df, ncol = n_rep)
for (i in 1:n_rep) {
    mse[, i] <- sapply(res[[i]], function(obj) deviance(obj)/nobs(obj))
    pred[[i]] <- mapply(function(obj, degf) predict(obj, data.frame(x = xy_test$x)), res[[i]], df)
    te[, i] <- sapply(as.list(data.frame(pred[[i]])), function(y_hat) mean((xy_test$y-y_hat)^2))
}

# Compute the average training and test errors
av_mse <- rowMeans(mse)
av_te <- rowMeans(te)

# Plot the errors
plot(df, av_mse, type = "l", lwd = 2, col = gray(0.4), ylab = "Prediction error", 
    xlab = "Flexibilty (spline's degrees of freedom [log scaled])", ylim = c(0, 1), log = "x")
abline(h = sigma_eps, lty = 2, lwd = 0.5)
for (i in 1:n_rep) {
  lines(df, te[, i], col = "lightpink")
}
for (i in 1:n_rep) {
  lines(df, mse[, i], col = gray(0.8))
}
lines(df, av_mse, lwd = 2, col = gray(0.4))
lines(df, av_te, lwd = 2, col = "darkred")
points(df[1], av_mse[1], col = "palegreen3", pch = 17, cex = 1.5)
points(df[1], av_te[1], col = "palegreen3", pch = 17, cex = 1.5)
points(df[which.min(av_te)], av_mse[which.min(av_te)], col = "darkorange", pch = 16, cex = 1.5)
points(df[which.min(av_te)], av_te[which.min(av_te)], col = "darkorange", pch = 16, cex = 1.5)
points(df[25], av_mse[25], col = "steelblue", pch = 15, cex = 1.5)
points(df[25], av_te[25], col = "steelblue", pch = 15, cex = 1.5)
legend(x = "top", legend = c("Training error", "Test error"), lwd = rep(2, 2), 
    col = c(gray(0.4), "darkred"), text.width = 0.3, cex = 0.85)

ここで、deviance 関数というのがあるが、これはbias を計算していて、fit という回帰後のオブジェクトがあったとすれば

sum(fit$residuals^2)

もしくは

sum((y - fit$fitted.values)^2)

で求められる。

 

Python ではおなじみだが、入力数が増えると判別性能がどう改善されていくかというlearning curve というのがR ではおなじみではなさそう。caret パッケージでできるようだ。

library(caret)
library(kernlab)
data(spam)
spam0 <- spam
lda_data <- learing_curve_dat(dat = spam, verbose=FALSE,
                              outcome = "type",
                              test_prop = 1/5, 
                              ## `train` arguments:
                              method = "lda", 
                              metric = "ROC",
                              trControl = trainControl(classProbs = TRUE, 
                                                       summaryFunction = twoClassSummary))

v <- split(lda_data, lda_data$Data)
yl <- c(0.9, 1)
plot(0, type="n", xlim=c(1, max(lda_data$Training_Size)), ylim=yl, xlab="Training size", ylab="ROC", las=1)
lines(v$Training$Training_Size, v$Training$ROC, type="o", pch=16, col=2, lwd=3)
lines(v$Testing$Training_Size, v$Testing$ROC, type="o", pch=16, col=4, lwd=3)
legend("bottomright", legend=c("Training", "Testing"), col=c(2, 4), lty=1, pch=16, lwd=3)

f:id:MikuHatsune:20170114111802p:image