Почему прогноз модели R gbm не соответствует модели?

Я использую каретку, чтобы соответствовать модели gbm. Когда я вызываю trainedGBM$finalModel$fit, я получаю правильный вывод.

Но когда я вызываю predict(trainedGBM$finalModel, origData, type="response"), я получаю очень разные результаты, а predict(trainedGBM$finalModel, type="response") дает все еще разные результаты, даже если прилагается origData. На мой взгляд, эти вызовы должны давать одинаковый результат. Может ли кто-нибудь помочь мне определить проблему?

library(caret)
library(gbm)

attach(origData)
gbmGrid <- expand.grid(.n.trees = c(2000), 
                       .interaction.depth = c(14:20), 
                       .shrinkage = c(0.005))
trainedGBM <- train(y ~ ., method = "gbm", distribution = "gaussian", 
                    data = origData, tuneGrid = gbmGrid, 
                    trControl = trainControl(method = "repeatedcv", number = 10, 
                                             repeats = 3, verboseIter = FALSE, 
                                             returnResamp = "all"))
ntrees <- gbm.perf(trainedGBM$finalModel, method="OOB")
data.frame(y, 
           finalModelFit = trainedGBM$finalModel$fit, 
           predictDataSpec = predict(trainedGBM$finalModel, origData, type="response", n.trees=ntrees), 
           predictNoDataSpec = predict(trainedGBM$finalModel, type="response", n.trees=ntrees))

Приведенный выше код дает следующие частичные результаты:

   y finalModelFit predictDataSpec predictNoDataSpec
9000     6138.8920        2387.182          2645.993
5000     3850.8817        2767.990          2467.157
3000     3533.1183        2753.551          2044.578
2500     1362.9802        2672.484          1972.361
1500     5080.2112        2449.185          2000.568
 750     2284.8188        2728.829          2063.829
1500     2672.0146        2359.566          2344.451
5000     3340.5828        2435.137          2093.939
   0     1303.9898        2377.770          2041.871
 500      879.9798        2691.886          2034.307
3000     2928.4573        2327.627          1908.876

person Nostradamus    schedule 14.07.2013    source источник
comment
Верно ли мое предположение, что это в пакете caret? Действительно неразумно заставлять людей гадать на такого рода вопрос, когда все, что вам нужно было сделать, это ввести library(_whatever_package_train_came_from)   -  person IRTFM    schedule 14.07.2013
comment
Кроме того, как отдельное нытье: использование attach является распространенным источником трудных для понимания ошибок. И надо было, конечно, полнее описать origData.   -  person IRTFM    schedule 14.07.2013
comment
Какое описание данных полезно здесь? У меня есть около 7000 записей, где y соответствует 26 функциям, как факторным, так и числовым.   -  person Nostradamus    schedule 14.07.2013


Ответы (1)


Исходя из вашего gbmGrid, только глубина вашего взаимодействия будет варьироваться от 14 до 20, а усадка и количество деревьев зафиксированы на 0,005 и 2000 соответственно. TrainedGBM предназначен для поиска только оптимального уровня взаимодействия в его нынешнем виде. Ваш ntrees, рассчитанный из gbm.perf, затем спрашивает, учитывая, что оптимальный уровень взаимодействия находится где-то между 14 и 20, каково оптимальное количество деревьев на основе критериев OOB. Поскольку прогнозы зависят от количества деревьев в модели, прогнозы, основанные на обученном GBM, будут использовать ntrees = 2000, а прогнозы, основанные на gbm.perf, будут использовать оптимальное количество ntrees, оцененное по этой функции. Это будет учитывать разницу между вашими trainedGBM$finalModel$fit и predict(trainedGBM$finalModel, type="response", n.trees=ntrees).

Чтобы показать пример, основанный на наборе данных радужной оболочки с использованием gbm в качестве классификации, а не модели регрессии.

library(caret)
library(gbm)

set.seed(42)

gbmGrid <- expand.grid(.n.trees = 100, 
                   .interaction.depth = 1:4, 
                   .shrinkage = 0.05)


trainedGBM <- train(Species ~ ., method = "gbm", distribution='multinomial',
                data = iris, tuneGrid = gbmGrid, 
                trControl = trainControl(method = "repeatedcv", number = 10, 
                                         repeats = 3, verboseIter = FALSE, 
                                         returnResamp = "all"))
print(trainedGBM)        

давать

# Resampling results across tuning parameters:

#  interaction.depth  Accuracy  Kappa  Accuracy SD  Kappa SD
#   1                  0.947     0.92   0.0407       0.061   
#   2                  0.947     0.92   0.0407       0.061   
#   3                  0.944     0.917  0.0432       0.0648  
#   4                  0.944     0.917  0.0395       0.0592  

# Tuning parameter 'n.trees' was held constant at a value of 100
# Tuning parameter 'shrinkage' was held constant at a value of 0.05
# Accuracy was used to select the optimal model using  the largest value.
# The final values used for the model were interaction.depth = 1, n.trees = 100
# and shrinkage = 0.05.     

Нахождение оптимального количества деревьев при условии оптимальной глубины взаимодействия:

ntrees <-  gbm.perf(trainedGBM$finalModel, method="OOB")
# Giving ntrees = 50

Если мы обучаем модель, варьируя количество деревьев и глубину взаимодействия:

gbmGrid2 <- expand.grid(.n.trees = 1:100, 
                   .interaction.depth = 1:4, 
                   .shrinkage = 0.05)

trainedGBM2 <- train(Species ~ ., method = "gbm", 
                data = iris, tuneGrid = gbmGrid2, 
                trControl = trainControl(method = "repeatedcv", number = 10, 
                                         repeats = 3, verboseIter = FALSE, 
                                         returnResamp = "all"))

print(trainedGBM2) 

# Tuning parameter 'shrinkage' was held constant at a value of 0.05
# Accuracy was used to select the optimal model using  the largest value.
# The final values used for the model were interaction.depth = 2, n.trees = 39
# and shrinkage = 0.05. 

Обратите внимание, что оптимальное количество деревьев, когда мы варьируем как количество деревьев, так и глубину взаимодействия, довольно близко к рассчитанному для gbm.perf.

person Community    schedule 14.07.2013