盆暗の学習記録

データサイエンス ,エンジニアリング,ビジネスについて日々学んだことの備忘録としていく予定です。初心者であり独学なので内容には誤りが含まれる可能性が大いにあります。

[R]交差項や2乗項を作る

どうやるのかちょっと悩んだのでメモ。

交差項

交差項等を作りたいときはmodel.matrix()を使うといいようです。

object引数にformulaを指定して,変数名を*でかけ合わせて交差項を作ります。

library(tidyverse)
model.matrix(object = ~ Sepal.Length * Sepal.Width, data = iris) %>% head()
##   (Intercept) Sepal.Length Sepal.Width Sepal.Length:Sepal.Width
## 1           1          5.1         3.5                    17.85
## 2           1          4.9         3.0                    14.70
## 3           1          4.7         3.2                    15.04
## 4           1          4.6         3.1                    14.26
## 5           1          5.0         3.6                    18.00
## 6           1          5.4         3.9                    21.06

切片を除く

-1を入れれば切片は作られません

model.matrix(object = ~  -1 + Sepal.Length * Sepal.Width, data = iris) %>% head()
##   Sepal.Length Sepal.Width Sepal.Length:Sepal.Width
## 1          5.1         3.5                    17.85
## 2          4.9         3.0                    14.70
## 3          4.7         3.2                    15.04
## 4          4.6         3.1                    14.26
## 5          5.0         3.6                    18.00
## 6          5.4         3.9                    21.06

もっと多数の場合

「手持ちのデータセットのうち,2つの変数の交差項だけを,すべての変数について組み合わせる」というようなことをやりたい場合は次のようにやればよさそうです。

# 2つの変数名の全ての組み合わせ
colnames_df <- expand.grid(colnames(iris), colnames(iris))

# formulaにする
fmla <- as.formula(str_c("~ -1 + ", str_c(colnames_df[,1], " * ", colnames_df[,2], collapse = " + ")))
> fmla
## ~-1 + Sepal.Length * Sepal.Length + Sepal.Width * Sepal.Length + 
##     Petal.Length * Sepal.Length + Petal.Width * Sepal.Length + 
##     Species * Sepal.Length + Sepal.Length * Sepal.Width + Sepal.Width * 
##     Sepal.Width + Petal.Length * Sepal.Width + Petal.Width * 
##     Sepal.Width + Species * Sepal.Width + Sepal.Length * Petal.Length + 
##     Sepal.Width * Petal.Length + Petal.Length * Petal.Length + 
##     Petal.Width * Petal.Length + Species * Petal.Length + Sepal.Length * 
##     Petal.Width + Sepal.Width * Petal.Width + Petal.Length * 
##     Petal.Width + Petal.Width * Petal.Width + Species * Petal.Width + 
##     Sepal.Length * Species + Sepal.Width * Species + Petal.Length * 
##     Species + Petal.Width * Species + Species * Species

このformulaをmodel.matrix()に入れます

# これをmodel.matrix()に入れると全組み合わせができる
origin_dumy_interact <- model.matrix(object = fmla, data = iris)

# (1)元の変数,(2)カテゴリカル変数をダミー変数にしたもの,(3)交差項の3つが入っている
origin_dumy_interact %>% as_data_frame() %>% glimpse()
## Observations: 150
## Variables: 21
## $ Sepal.Length                     <dbl> 5.1, 4.9, 4.7, 4.6, 5.0, 5.4,...
## $ Sepal.Width                      <dbl> 3.5, 3.0, 3.2, 3.1, 3.6, 3.9,...
## $ Petal.Length                     <dbl> 1.4, 1.4, 1.3, 1.5, 1.4, 1.7,...
## $ Petal.Width                      <dbl> 0.2, 0.2, 0.2, 0.2, 0.2, 0.4,...
## $ Speciessetosa                    <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...
## $ Speciesversicolor                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
## $ Speciesvirginica                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
## $ `Sepal.Length:Sepal.Width`       <dbl> 17.85, 14.70, 15.04, 14.26, 1...
## $ `Sepal.Length:Petal.Length`      <dbl> 7.14, 6.86, 6.11, 6.90, 7.00,...
## $ `Sepal.Length:Petal.Width`       <dbl> 1.02, 0.98, 0.94, 0.92, 1.00,...
## $ `Sepal.Length:Speciesversicolor` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
## $ `Sepal.Length:Speciesvirginica`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
## $ `Sepal.Width:Petal.Length`       <dbl> 4.90, 4.20, 4.16, 4.65, 5.04,...
## $ `Sepal.Width:Petal.Width`        <dbl> 0.70, 0.60, 0.64, 0.62, 0.72,...
## $ `Sepal.Width:Speciesversicolor`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
## $ `Sepal.Width:Speciesvirginica`   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
## $ `Petal.Length:Petal.Width`       <dbl> 0.28, 0.28, 0.26, 0.30, 0.28,...
## $ `Petal.Length:Speciesversicolor` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
## $ `Petal.Length:Speciesvirginica`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
## $ `Petal.Width:Speciesversicolor`  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
## $ `Petal.Width:Speciesvirginica`   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...

できました。交差項だけでなく,交差項に使った元の変数も(factor,characterはダミー変数化されて)入っていますね。

交差項だけ取り出したい場合は":"でselectすればいいし,このまま使っても良い‥という感じですね。

留意点

ただ,留意点はmodel.matrix()がna.omit()のようなNAのリストワイズ除去も行ってしまう点ですね…。

一応,出力されるmatrixのrownamesを使えば何行目が除去されたかがわかるので,この情報を使えば他のデータと行を合わせることは容易にできます。

2乗項

ダミー変数や数値でない変数を避けつつ2乗項をつくります。

model.matrix()のように一発でやってくれる関数を知らないため,自作関数をselect_if()mutate_all()rename_all()に通していくことで処理してみます。

# 関数を定義 ----------------------------------------
# 値が2水準より多いかどうかを判定する関数(select_if()用)
is_more_than_2_level <- function(x){
  is_more_than_2_level = unique(x) %>% na.omit() %>% length() > 2
  return(is_more_than_2_level)
}
# 2乗する関数(mutate_all()用)
as_quadratic_term <- function(x){
  x^2
}
# 2乗項であることを示す名前にする(rename_all()用)
name_add_sqr <- function(name) {
  newname <- str_c("sqr_",name)
  return(newname)
}

# 処理の実行 -----------------------------------------
quadratic_terms <- iris %>% 
  # ダミー変数でないものを取り出し,数値のものを取り出す
  select_if(is_more_than_2_level) %>% select_if(is.numeric) %>% 
  # 2乗する
  mutate_all(as_quadratic_term) %>% 
  # 2乗項であることを示す名前にする
  rename_all(name_add_sqr)
> head(quadratic_terms)
##   sqr_Sepal.Length sqr_Sepal.Width sqr_Petal.Length sqr_Petal.Width
## 1            26.01           12.25             1.96            0.04
## 2            24.01            9.00             1.96            0.04
## 3            22.09           10.24             1.69            0.04
## 4            21.16            9.61             2.25            0.04
## 5            25.00           12.96             1.96            0.04
## 6            29.16           15.21             2.89            0.16

できました。