R 리지 회귀분석

1 개요[ | ]

R ridge regression
R 리지 회귀분석

2 전체 데이터 사용[ | ]

2.1 lambda 생략[ | ]

library(glmnet)  # glmnet()
library(Metrics) # mse()

x <- as.matrix(mtcars[, -1])
y <- mtcars[,1]

model <- glmnet(x, y, alpha=0)
length(model$lambda) # 100개의 lambda값 입력 → 모델 100개
coef(model)[,  1] #   1번 모델의 회귀계수
coef(model)[,100] # 100번 모델의 회귀계수
pred <- predict(model, newx=x)
mse(y, pred[,  1]) #   1번 모델의 MSE
mse(y, pred[,100]) # 100번 모델의 MSE

2.2 lambda 지정[ | ]

library(glmnet)  # glmnet()
library(Metrics) # mse()

x <- as.matrix(mtcars[, -1])
y <- mtcars[,1]

# lambda=1일 때
model1 <- glmnet(x, y, alpha=0, lambda=1)
coef(model1)
pred1 <- predict(model1, x)
mse(y, pred1) ## 5.00346049801398
# lambda=2일 때
model2 <- glmnet(x, y, alpha=0, lambda=2)
coef(model2)
pred2 <- predict(model2, x)
mse(y, pred2) ## 5.22249084160892

2.3 cv.glmnet()으로 lambda 찾기[ | ]

set.seed(12345)
library(glmnet) # cv.glmnet()

x <- as.matrix(mtcars[, -1])
y <- mtcars[,1]

cv <- cv.glmnet(x, y, alpha=0)
cv$lambda.min # MSE 최소화 lambda = 2.502772
cv$lambda.1se # MSE 최소화 lambda + 1 표준편차 = 10.10373
plot(cv)

3 데이터 분할[ | ]

3.1 범위 지정하여 lambda 찾기[ | ]

set.seed(12345)
library(caret)   # createDataPartition()
library(glmnet)  # glmnet()
library(Metrics) # mse()

df <- mtcars

# 데이터 분할 (7:3)
idx <- createDataPartition(df$mpg, list=F, p=0.7)
Train <- df[ idx,]
Test  <- df[-idx,]
Train.x <- as.matrix(Train[, -1])
Test.x  <- as.matrix(Test[, -1])
Train.y <- Train[,1]
Test.y  <- Test [,1]

lambda <- 10^seq(5, -20, by=-.05)
model <- glmnet(Train.x, Train.y, alpha=0, lambda=lambda)
pred <- predict(model, newx=Train.x)
mse <- c()
for( i in 1:length(lambda) )
{
  mse <- append(mse, mse(Train.y, pred[,i]))
}
length(lambda) # 확인한 모델(lambda) 수 = 501 
idx <- which.min(mse) # 훈련셋 MSE 최소 모델번호 = 405
idx
mse[idx] # 훈련셋 MSE 최소값 = 4.83314365454864
lambda[idx] # 훈련셋 MSE 최소 lambda = 6.30957344480189e-16
# 훈련셋 MSE 시각화
color = rep("#00000011", length(lambda))
color[idx] = "red"
plot(log(lambda), mse, pch = 16, col = color)

4 같이 보기[ | ]

문서 댓글 ({{ doc_comments.length }})
{{ comment.name }} {{ comment.created | snstime }}