Hatena::ブログ(Diary)

驚異のアニヲタ社会復帰への道

Prima Project

2018-05-22

線形回帰の最小二乗法をベクトル偏微分で解く

||y-Ab||^2 の最小化を偏微分で求めるが、ベクトル演算ベクトル微分が「これは知ってて当然でしょ」という感じでさっくり飛ばされることが多いのでしつこいくらいにひとつずつやる。

下準備

微分されるベクトル変数x, 係数のベクトルa とする。基本的にx またはa とかくと、列ベクトルつまり縦に長いベクトルである。つまり、

x=¥begin{pmatrix}x_1¥¥x_2¥¥¥vdots¥¥x_n¥end{pmatrix}, a=¥begin{pmatrix}a_1¥¥a_2¥¥¥vdots¥¥a_n¥end{pmatrix} である。縦に書くと長いので、x もしくは転置T して

x=(x_1,x_2,¥dots,x_n)^T, a=(a_1,a_2,¥dots,a_n)^T と書く。

ベクトル演算として,

a^Tx=x^Ta

である。

転置は(XY)^T=Y^TX^T である。

偏微分¥frac{¥partial}{¥partial x}a^Tx=a=¥frac{¥partial}{¥partial x}x^Ta である。

二次形式という表現x^TAx というものがある。これは、ただ単純にノルムを考えるとa^Ta (普通の二乗和)になるが,ベクトル表現のときに正方行列X を間にかませると、x^2, y^2, xy, x, y の組み合わせを得ることができる。

二次形式の微分は、A のひだりとみぎでそれぞれ微分するから、

¥frac{¥partial}{¥partial x} x^TAx=¥frac{¥partial}{¥partial x} x^T¥cdot Ax+x^A¥frac{¥partial}{¥partial x} x

Ax=M, x^TA=N とおけば、¥frac{¥partial}{¥partial x} x^TAx=¥frac{¥partial}{¥partial x}x^TM+¥frac{¥partial}{¥partial x}Nx=M+N^T

M+N^T=Ax+(x^TA)^T=Ax+A^Tx=(A+A^T)x となる。


というわけで、E=||y-Ab||^2 を展開して、偏微分で0 になるときのb を求めにかかるが

E=(y-Ab)^T(y-Ab)

普通に二乗の形にするが、ベクトル(というか行列表示)ではこうする

E=(y^T-(Ab)^T)(y-Ab)

転置を中にいれた

E=(y^T-b^TA^T)(y-Ab)

転置の公式

E=||y||^2-y^TAb-b^TA^Ty-b^TA^TAb

展開する

E=||y||^2-(Ab)^Ty-b^TA^Ty-b^TA^TAb

a^x=x^a を思い出すと、ここで、y^TAb をかたまりとみて転置する。

E=||y||^2-b^TA^Ty-b^TA^Ty-b^TA^TAb

転置の公式を使って(Ab)^T=b^TA^T とした

E=||y||^2-2b^TA^Ty-b^TA^TAb

まとめた

 

さてこれをb偏微分する。||y||^2b に関係ないので0 になるので

¥frac{¥partial}{¥partial b}b^TA^Ty=A^Ty

b^T(A^Ty) と思えば、¥frac{¥partial}{¥partial x}x^Ta=a である

¥frac{¥partial}{¥partial b}b^TA^TAb=(A^TA+(A^TA)^T)b

A^TA=B とみなせば¥frac{¥partial}{¥partial b}b^TBb=(B+B^T)b である

転置の公式(A^TA)^T=A^TA であり、

¥frac{¥partial}{¥partial b}b^TA^TAb=(A^TA+(A^TA)^T)b=2A^TAb

というわけで、

¥frac{¥partial}{¥partial b}E=-2A^Ty-2A^TAb

偏微分で0 となるときに求めるb だから

b=(A^TA)^{-1}A^Ty

が答え。


nr <- 20

nc <- 7

y <- runif(nr)

X <- matrix(runif(nr*nc), nr, nc)

lm(y ~ X - 1)

Call:
lm(formula = y ~ X - 1)

Coefficients:
      X1        X2        X3        X4        X5        X6        X7  
 0.01803   0.43939   0.28782   0.26241  -0.13173   0.02898   0.35624  
solve(t(X)%*%X)%*%t(X)%*%y
            [,1]
[1,]  0.01803270
[2,]  0.43938587
[3,]  0.28782293
[4,]  0.26241033
[5,] -0.13172939
[6,]  0.02897629
[7,]  0.35623832

スパム対策のためのダミーです。もし見えても何も入力しないでください
ゲスト


画像認証

トラックバック - http://d.hatena.ne.jp/MikuHatsune/20180522/1526979559