盆暗の学習記録

データサイエンス ,エンジニアリング,ビジネスについて日々学んだことの備忘録としていく予定です。初心者であり独学なので内容には誤りが含まれる可能性が大いにあります。

[R]ゼロから作る最小二乗法2:重回帰

「自分で数式をコードに落としていって動かす」という作業は非常に勉強になると思ったので,いろんなアルゴリズムをゼロから作っていきたいと思います。

理論の要点を整理してから実装する構成で述べていきます。

モデル

線形回帰(linear regression)は,予測の目的変数\(Y_i\)と説明変数\(X_{i1}, X_{i2}, ..., X_{ip}\)の間に次のような線形関係を仮定したモデルを置いて予測する手法です。 \[ Y_i = \beta_0 + \beta_1X_{i1} + \beta_2 X_{i2} + \cdots + \beta_p X_{ip} + u_i \]

このように説明変数が複数ある線形回帰は重回帰モデル(multiple regression model)とも呼ばれます。

行列を使って表記すると,次のように表現できます

\[ \boldsymbol{y}=\boldsymbol{X} \boldsymbol{\beta}+\boldsymbol{u} \]

ただし,

\[ \boldsymbol{y} = \begin{bmatrix} y_1 \\ y_2 \\ \vdots \\ y_n \end{bmatrix} , \ \boldsymbol{X} = \begin{bmatrix} 1 & X_{11} & \cdots & X_{1p} \\ 1 & X_{21} & \cdots & X_{2p} \\ \vdots &\vdots & &\vdots &\\ 1 & X_{n1} & \cdots & X_{np} \\ \end{bmatrix} ,\ \boldsymbol{\beta} = \begin{bmatrix} \beta_0 \\ \beta_1 \\ \vdots \\ \beta_p \end{bmatrix} , \ \boldsymbol{u} = \begin{bmatrix} u_1 \\ u_2 \\ \vdots \\ u_n \end{bmatrix} \]

誤差関数

\(\boldsymbol{\beta}\)の最小2乗推定量(ordinary least square’s estimator: OLSE)を\(\hat{\boldsymbol{\beta}}\)とすると,モデルからの\(\boldsymbol{y}\)の予測値は\(\hat{\boldsymbol{y}} = \boldsymbol{X}\hat{\boldsymbol{\beta}}\),実測値と予測値の誤差(残差)は\(\boldsymbol{e}=\boldsymbol{y}-\hat{\boldsymbol{y}}\)になります。

パラメータ\(\boldsymbol{\beta}\)を推定するために多く使われる方法は,実測値\(\boldsymbol{y}\)と予測値\(\hat{\boldsymbol{y}}\)誤差2乗和(Sum of Squared Error)1

\[ SSE = \sum_{i=1}^n (y_i - \hat{y}_i)^2 = \sum_{i=1}^n e_i^2=\boldsymbol{e}'\boldsymbol{e} \]

を最小にするパラメータを採用するという最小二乗法(least squares method)です。

パラメータの推定

最小二乗推定量は「誤差2乗和を微分してゼロになる点」という次の条件

\[ \frac{\partial \boldsymbol{e}'\boldsymbol{e}}{\partial \hat{\boldsymbol{\beta}}}=\boldsymbol{0} \]

の解となるものです。

誤差2乗和は

\[ \begin{align} \boldsymbol{e}'\boldsymbol{e} &= (\boldsymbol{y} - \hat{\boldsymbol{y}})'(\boldsymbol{y} -\hat{\boldsymbol{y}})\\ &= (\boldsymbol{y} - \boldsymbol{X}\hat{\boldsymbol{\beta}})'(\boldsymbol{y} - \boldsymbol{X}\hat{\boldsymbol{\beta}})\\ &= \boldsymbol{y}'\boldsymbol{y} - \boldsymbol{y}'\boldsymbol{X}\hat{\boldsymbol{\beta}} - \hat{\boldsymbol{\beta}}'\boldsymbol{X}'\boldsymbol{y} + \hat{\boldsymbol{\beta}}'\boldsymbol{X}'\boldsymbol{X}\hat{\boldsymbol{\beta}}\\ &= \boldsymbol{y}'\boldsymbol{y} - 2\hat{\boldsymbol{\beta}}'\boldsymbol{X}'\boldsymbol{y} + \hat{\boldsymbol{\beta}}'\boldsymbol{X}'\boldsymbol{X}\hat{\boldsymbol{\beta}} \end{align} \]

であるから23

\[ \frac{\partial \boldsymbol{e}'\boldsymbol{e}}{\partial \hat{\boldsymbol{\beta}}} =-2\boldsymbol{X}'\boldsymbol{y}+2(\boldsymbol{X}'\boldsymbol{X})\hat{\boldsymbol{\beta}} =\boldsymbol{0} \] となり,これを書き直すと,最小二乗法の正規方程式と言われる次の式を得ます \[ (\boldsymbol{X}'\boldsymbol{X})\hat{\boldsymbol{\beta}}=\boldsymbol{X}'\boldsymbol{y} \] この正規方程式を\(\hat{\boldsymbol{\beta}}\)について解けば \[ \hat{\boldsymbol{\beta}}=(\boldsymbol{X}'\boldsymbol{X})^{-1}\boldsymbol{X}'\boldsymbol{y} \] が得られます。

Rで実装

Rでの行列の掛け算は%*%という演算子で行います。

また,転置はt()逆行列solve()になります。

# 関数を定義
OLS <- function(X, y) {
  # 切片項の追加
  X = data.frame(Intercept = rep(1, nrow(X)), X)
  
  # 入力されたデータフレームをmatrixに変える
  if(is.data.frame(X)) X = as.matrix(X)
  if(is.data.frame(y)) y = as.matrix(y)
  
  # パラメータの推定: β = (X'X)^{-1} X'y
  beta = solve(t(X) %*% X) %*% t(X) %*% y
  return(beta)
}

# treesデータセットで試す
X = trees[c("Girth","Height")]
y = trees["Volume"]
OLS(X, y)
##                Volume
## Intercept -57.9876589
## Girth       4.7081605
## Height      0.3392512
# lm()と比較
lm(Volume ~ Girth + Height, trees)
## 
## Call:
## lm(formula = Volume ~ Girth + Height, data = trees)
## 
## Coefficients:
## (Intercept)        Girth       Height  
##    -57.9877       4.7082       0.3393

ちゃんと同じ値が出ました。

参考

経済分析のための統計的方法

経済分析のための統計的方法

計量経済学大全

計量経済学大全


  1. 残差平方和(Sum of Squared Residuals: SSR,あるいはResidual Sum of Squares: RSS)とも呼ばれます(統計学の分野だと残差,機械学習の分野だと誤差と呼ぶことが多いように思います)

  2. ※転置の基本公式から,\((\boldsymbol{X}\hat{\boldsymbol{\beta}})'=\hat{\boldsymbol{\beta}}'\boldsymbol{X}'\)

  3. \(\boldsymbol{y}'\boldsymbol{X}\hat{\boldsymbol{\beta}}=(\boldsymbol{X}\hat{\boldsymbol{\beta}})'\boldsymbol{y}=\hat{\boldsymbol{\beta}}'\boldsymbol{X}'\boldsymbol{y}\)