मैं एक प्रतिगमन मॉडल बनाना चाहता हूं जो कई ओएलएस मॉडल का औसत है, प्रत्येक पूर्ण डेटा के सबसेट पर आधारित है। इसके पीछे का विचार इस कागज पर आधारित है । मैं कश्मीर सिलवटों का निर्माण करता हूं और कश्मीर ओएलएस मॉडल का निर्माण करता हूं, प्रत्येक में बिना किसी फोल्ड के डेटा। मैं फिर अंतिम मॉडल प्राप्त करने के लिए प्रतिगमन गुणांक औसत करता हूं।
यह मुझे कुछ बेतरतीब वन प्रतिगमन की तरह ही हमला करता है, जिसमें कई प्रतिगमन पेड़ बनाए जाते हैं और औसत होते हैं। हालाँकि, औसतन ओएलएस मॉडल का प्रदर्शन पूरे डेटा पर केवल एक ओएलएस मॉडल बनाने से भी बदतर लगता है। मेरा सवाल है: क्या एक सैद्धांतिक कारण है कि कई ओएलएस मॉडल का औसत गलत या अवांछनीय है? क्या हम ओवरफिटिंग को कम करने के लिए कई ओएलएस मॉडल की औसत उम्मीद कर सकते हैं? नीचे एक आर उदाहरण है।
#Load and prepare data
library(MASS)
data(Boston)
trn <- Boston[1:400,]
tst <- Boston[401:nrow(Boston),]
#Create function to build k averaging OLS model
lmave <- function(formula, data, k, ...){
lmall <- lm(formula, data, ...)
folds <- cut(seq(1, nrow(data)), breaks=k, labels=FALSE)
for(i in 1:k){
tstIdx <- which(folds==i, arr.ind = TRUE)
tst <- data[tstIdx, ]
trn <- data[-tstIdx, ]
assign(paste0('lm', i), lm(formula, data = trn, ...))
}
coefs <- data.frame(lm1=numeric(length(lm1$coefficients)))
for(i in 1:k){
coefs[, paste0('lm', i)] <- get(paste0('lm', i))$coefficients
}
lmnames <- names(lmall$coefficients)
lmall$coefficients <- rowMeans(coefs)
names(lmall$coefficients) <- lmnames
lmall$fitted.values <- predict(lmall, data)
target <- trimws(gsub('~.*$', '', formula))
lmall$residuals <- data[, target] - lmall$fitted.values
return(lmall)
}
#Build OLS model on all trn data
olsfit <- lm(medv ~ ., data=trn)
#Build model averaging five OLS
olsavefit <- lmave('medv ~ .', data=trn, k=5)
#Build random forest model
library(randomForest)
set.seed(10)
rffit <- randomForest(medv ~ ., data=trn)
#Get RMSE of predicted fits on tst
library(Metrics)
rmse(tst$medv, predict(olsfit, tst))
[1] 6.155792
rmse(tst$medv, predict(olsavefit, tst))
[1] 7.661 ##Performs worse than olsfit and rffit
rmse(tst$medv, predict(rffit, tst))
[1] 4.259403