Tidymodels: медленная настройка гиперпараметров с данными с несколькими предикторами

В настоящее время я пытаюсь подогнать случайную модель леса с настройкой гиперпараметров, используя структуру tidymodels на фреймворке данных с 101064 строками и 64 столбцами. У меня есть сочетание категориальных и непрерывных предикторов, а моя конечная переменная - это категориальная переменная с 3 категориями, поэтому у меня проблема с многоклассовой классификацией.

Проблема, с которой я столкнулся, заключается в том, что этот процесс, даже с параллельной обработкой, занимает на моей машине примерно от 6 до 8 часов. Поскольку 101064 - это не огромный объем данных, я подозреваю, что я что-то делаю неправильно или эффективно (или и то, и другое!). К сожалению, я не могу поделиться точным набором данных из-за конфиденциальности, но код, который я поделился ниже, предлагает очень близкую копию исходного набора данных от количества уровней в каждой категориальной переменной до количества NA в каждом столбце.

У меня есть несколько замечаний по поводу приведенного ниже кода, которые могут дать представление о том, почему я сделал то, что сделал. Во-первых, я разделяю обучающие и тестовые наборы по идентификатору группы, а не по строкам. Набор данных вложен, где есть несколько строк, которые соответствуют одному и тому же идентификатору группы. В идеале мне нужна модель, которая может изучать закономерности по идентификаторам групп. Следовательно, не должно быть общих идентификаторов групп между складками обучения и тестирования, а также общих идентификаторов групп между складками анализа и оценки в складках перекрестной проверки.

Во-вторых, я включил step_unknown, потому что Random Forest не любит значения NA. Я включил step_novel в качестве меры предосторожности на случай, если будущие данные будут иметь категориальные уровни, которых не видели текущие данные. Я не уверен, когда использовать step_unknown против step_novel, и я не уверен, разумно ли использовать их вместе, поэтому любые разъяснения будут очень благодарны. Я использовал step_other и step_dummy для One Hot Encode категориальных предикторов. step_impute_median был включен, чтобы в данных не было НП, чтобы не было жалоб от Random Forest. step_downsample использовался для устранения дисбаланса классов в переменной результата, я использовал понижающую дискретизацию, чтобы получить меньше наблюдений на этапе построения модели, но, похоже, это не привело к сокращению времени обучения.

Мои вопросы:

  1. Есть ли причина, по которой настройка модели занимает около 6 часов, и можно ли ее оптимизировать в дальнейшем? Я открыт для использования уменьшения размерности и был бы признателен за некоторые учебные пособия для этого в рамках контролируемого конвейера машинного обучения с использованием фреймворка tidymodels.

  2. Правильно ли я указал и использовал рецепты? Я не совсем уверен в этом. Я упомянул выше, что я думаю, что делаю, но действительно ли это то, что я делаю, и это лучший способ сделать это? Я готов переформулировать этап рецептов.

Любая помощь по этому поводу будет очень признательна. Этот набор данных невелик, поэтому резкое сокращение времени на обучение модели позволило бы мне запустить его в производство.

Я запускаю этот код на своем локальном компьютере, который представляет собой MacBook Pro с 8-ядерным процессором 2,4 ГГц и 32 ГБ памяти.

library(tidyverse)
library(tidymodels)
library(themis)
library(finetune)
library(doParallel)
library(parallel)
library(ranger)
library(future)
library(doFuture)


# Create Synthetic data that closely mimics actual dataset ----
## Categorical predictors
categorical_predictor1 <- rep(c("cat1", "cat2", "cat3", "cat4", "cat5"), times = c(43281, 29088, 9881, 8874, 9940))
categorical_predictor2 <- rep(c("cat1", "cat2", "cat3", "cat4", "cat5"), times = c(2522, 21302, 20955, 36859, 19426))
categorical_predictor3 <- rep(c("cat1", "cat2"), times = c(15950, 85114))
categorical_predictor4 <- rep(c("cat1", "cat2", "cat3", "cat4", "cat5", "cat6", "cat7"), times = c(52023, 16666, 13662, 7045, 2644, 1798, 7226))
categorical_predictor5 <- rep(c("cat1", "cat2", "cat3"), times = c(52613, 14903, 33548))
categorical_predictor6 <- rep(c("cat1", "cat2", "cat3", "cat4"), times = c(13662, 16666, 18713, 52023))
categorical_predictor7 <- rep(c("cat1", "cat2", "cat3", "cat4", "cat5", "cat6", NA), times = c(44210, 11062, 8846, 4638, 1778, 4595, 25935))
categorical_predictor8 <- rep(c("cat1", "cat2", "cat3", "cat4", NA), times = c(11062, 8846, 11011, 44210, 25935))
categorical_predictor9 <- rep(c("cat1", "cat2", "cat3", "cat4", "cat5", "cat6", NA), times = c(11649, 10215, 9783, 7580, 5649, 30253, 25935))
categorical_predictor10 <- rep(c("cat1", "cat2", "cat3", "cat4", "cat5", "cat6", NA), times = c(12563, 11649, 10215, 9783, 7580, 23339, 25935))
categorical_predictor11 <- rep(c("cat1", "cat2", NA), times = c(14037, 61092, 25935))
categorical_predictor12 <- rep(c("cat1", "cat2", "cat3", NA), times = c(15042, 35676, 23861, 26485))


# Outcome variable
outcome_variable <- rep(c("cat1", "cat2", "cat3"), times = c(21375, 49824, 29865))

## Continuous Predictors: Values are not normalized
continuous_predictor1 <- runif(n = 101064, min = 0, max = 90)
continuous_predictor2 <- runif(n = 101064, min = 0, max = 95.4)
continuous_predictor3 <- runif(n = 101064, min = 0, max = 14.1515)
continuous_predictor4 <- runif(n = 101064, min = 0, max = 85)
continuous_predictor5 <- runif(n = 101064, min = 0, max = 71)
continuous_predictor6 <- runif(n = 101064, min = -236, max = 97)
continuous_predictor7 <- runif(n = 101064, min = -40, max = 84)
continuous_predictor8 <- runif(n = 101064, min = 2015, max = 2019)
continuous_predictor9 <- runif(n = 101064, min = 0, max = 6)
continuous_predictor10 <- runif(n = 101064, min = 2, max = 26)
continuous_predictor11 <- runif(n = 101064, min = 0, max = 26)
continuous_predictor12 <- runif(n = 101064, min = 0.1365, max = 0.4352)
continuous_predictor13 <- runif(n = 101064, min = 0.1282, max = 0.4860)
continuous_predictor14 <- runif(n = 101064, min = 0.1232, max = 0.4643)
continuous_predictor15 <- runif(n = 101064, min = 0.1365, max = 0.4885)
continuous_predictor16 <- runif(n = 101064, min = 107, max = 218.6)
continuous_predictor17 <- runif(n = 101064, min = 0.6667, max = 16.333)
continuous_predictor18 <- runif(n = 101064, min = 3.479, max = 7.177)
continuous_predictor19 <- runif(n = 101064, min = 0.8292, max = 3.3100)
continuous_predictor20 <- runif(n = 101064, min = 49.33, max = 101.70)
continuous_predictor21 <- runif(n = 101064, min = 0.07333, max = 0.42534)
continuous_predictor22 <- runif(n = 101064, min = 0.08727, max = 0.41762)
continuous_predictor23 <- runif(n = 101064, min = 0.1241, max = 0.4673)
continuous_predictor24 <- runif(n = 101064, min = 0.07483, max = 0.41192)
continuous_predictor25 <- runif(n = 101064, min = 446.1, max = 561.0)
continuous_predictor26 <- runif(n = 101064, min = 2.333, max = 24)
continuous_predictor27 <- runif(n = 101064, min = 14.52, max = 18.23)
continuous_predictor28 <- runif(n = 101064, min = 0.5463, max = 3.488)
continuous_predictor29 <- runif(n = 101064, min = 150.7, max = 251.9)
continuous_predictor30 <- runif(n = 101064, min = 0.1120, max = 0.4603)
continuous_predictor31 <- runif(n = 101064, min = 0.1231, max = 0.4766)
continuous_predictor32 <- runif(n = 101064, min = 0.1271, max = 0.4857)
continuous_predictor33 <- runif(n = 101064, min = 0.1152, max = 0.4613)
continuous_predictor34 <- runif(n = 101064, min = 238.6, max = 329.4)
continuous_predictor35 <- runif(n = 101064, min = 5.333, max = 19.667)
continuous_predictor36 <- runif(n = 101064, min = 7.815, max = 10.929)
continuous_predictor37 <- runif(n = 101064, min = 0.8323, max = 2.8035)
continuous_predictor38 <- runif(n = 101064, min = 140.9, max = 195.5)
continuous_predictor39 <- runif(n = 101064, min = 0.1098, max = 0.4581)
continuous_predictor40 <- runif(n = 101064, min = 0.08825, max = 0.41360)
continuous_predictor41 <- runif(n = 101064, min = 0.1209, max = 0.4510)
continuous_predictor42 <- runif(n = 101064, min = 0.1048, max = 0.4498)
continuous_predictor43 <- runif(n = 101064, min = 312.2, max = 382.2)
continuous_predictor44 <- runif(n = 101064, min = 2.667, max = 18)
continuous_predictor45 <- runif(n = 101064, min = 10.22, max = 12.49)
continuous_predictor46 <- runif(n = 101064, min = 1.077, max = 2.968)
continuous_predictor47 <- runif(n = 101064, min = 72.18, max = 155.71)

## Continuous Predictors: Values have NAs
continuous_predictor_withNA1 <- c(runif(n = 101064 - 26485, min = 1, max = 3), rep(NA, times = 26485))
continuous_predictor_withNA2 <- c(runif(n = 101064 - 26485, min = 1, max = 3), rep(NA, times = 26485))

## Group ID
set.seed(123)
group_id <- sample(c(1,2,3,4,5,6,7,9,10,11,13,14,16,17,18,19,20,21,22,24,25,26,27,28,29,30,31,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,107,109,110,111,112,125,126,161,162,163,164,165,178,179,180,184,185,186,187,188,189,197,198,199,209,210,211,212,213,214,231,232,233,234,239,240,250,251,252,255,256,257,258,259,260,261,508,509,510,602,721,730),
                   size = 101064,
                   replace = TRUE,
                   prob = c(0.010300404,0.003661047,0.005758727,0.002849679,0.005976411,0.006738304,0.004957255,0.008727143,0.007757461,0.00530357,0.00867767,0.003839151,0.007836618,0.004531782,0.007678303,0.013150083,0.003364205,0.005194728,0.002750732,0.005778517,0.009825457,0.010488403,0.009399984,0.006105042,0.011101876,0.006490936,0.008459986,0.003918309,0.009083353,0.001583155,0.005382728,0.013832819,0.004828623,0.004670308,0.007213251,0.006570094,0.006035779,0.007322093,0.006570094,0.002077891,0.000979577,0.006926304,0.007124199,0.005521254,0.007618935,0.00335431,0.002968416,0.005442096,0.016069026,0.005174939,0.001820629,0.008578722,0.00213726,0.00142484,0.014644186,0.006688831,0.003799573,0.008430302,0.004581255,0.002552838,0.012833452,0.00620399,0.003799573,0.004729676,0.005639991,0.010824824,0.010735771,0.004343782,0.008934932,0.005679569,0.004096414,0.011141455,0.011853875,0.00354231,0.006312832,0.001553471,0.009162511,0.006550305,0.007688198,0.002354943,0.002730943,0.005085886,0.004808834,0.013634924,0.006233674,0.007124199,0.007915776,0.006431568,0.003957888,0.005422307,0.002394522,0.00865788,0.008093881,0.002592417,0.001157682,0.005758727,0.004897887,0.002364838,0.004749466,0.005194728,0.009795773,0.007054936,0.003601678,0.006362305,0.00848967,0.011448191,0.003364205,0.006431568,0.005224412,0.007282514,0.007242935,0.008074092,0.009686931,0.00670862,0.003571994,0.008717249,0.007806934,0.004135993,0.006253463,0.006302937,0.007846513,0.003680836,0.006095148,0.00264189,0.004581255,0.004838518,0.001454524,0.004571361,0.005926937,0.002236207,0.007361672,0.006332621,0.011952822,0.013852608,0.009775984,0.007124199,0.013733872,0.007143988,0.006827357,0.00425473,0.007094514,0.005085886,0.013308399,0.007480409,0.007737671,0.004551571,0.00744083,0.012576189,0.008796406,0.010884192,0.0063722,0.01006293))


## Join to make a dataframe
df <- tibble(group_id, 
             categorical_predictor1,
             categorical_predictor2,
             categorical_predictor3,
             categorical_predictor4,
             categorical_predictor5,
             categorical_predictor6,
             categorical_predictor7,
             categorical_predictor8,
             categorical_predictor9,
             categorical_predictor10,
             categorical_predictor11,
             categorical_predictor12,
             continuous_predictor1,
             continuous_predictor2,
             continuous_predictor3,
             continuous_predictor4,
             continuous_predictor5,
             continuous_predictor6,
             continuous_predictor7,
             continuous_predictor8,
             continuous_predictor9,
             continuous_predictor10,
             continuous_predictor11,
             continuous_predictor12,
             continuous_predictor13,
             continuous_predictor14,
             continuous_predictor15,
             continuous_predictor16,
             continuous_predictor17,
             continuous_predictor18,
             continuous_predictor19,
             continuous_predictor20,
             continuous_predictor21,
             continuous_predictor22,
             continuous_predictor23,
             continuous_predictor24,
             continuous_predictor25,
             continuous_predictor26,
             continuous_predictor27,
             continuous_predictor28,
             continuous_predictor29,
             continuous_predictor30,
             continuous_predictor31,
             continuous_predictor32,
             continuous_predictor33,
             continuous_predictor34,
             continuous_predictor35,
             continuous_predictor36,
             continuous_predictor37,
             continuous_predictor38,
             continuous_predictor39,
             continuous_predictor40,
             continuous_predictor41,
             continuous_predictor42,
             continuous_predictor43,
             continuous_predictor44,
             continuous_predictor45,
             continuous_predictor46,
             continuous_predictor47,
             continuous_predictor_withNA1,
             continuous_predictor_withNA2,
             outcome_variable)

df <- df %>% 
  mutate_if(is.character, as.factor) %>% 
  mutate(.row = row_number())

# Split Data ----
## Split the data while keeping group ids separate, groups will not be split up across training and testing sets
set.seed(123)
holdout_group_id <- sample(unique(df$group_id), size = 5)

indices <- list(
  analysis = df %>% filter(!(group_id %in% holdout_group_id)) %>% pull(.row),
  assessment = df %>% filter(group_id %in% holdout_group_id) %>% pull(.row)
)

## Remove row column - no longer required
df <- df %>% 
  select(-.row)

split <- make_splits(indices, df)
df_train <- training(split)
df_test <- testing(split)

## Create Cross Validation Folds
set.seed(123)
folds <- group_vfold_cv(df_train, group = "group_id", v = 5)

# Create Recipe ----
## Define a recipe to be applied to the data
df_recipe <- recipe(outcome_variable ~ ., data = df_train) %>% 
  update_role(group_id, new_role = "ID") %>% 
  step_unknown(all_nominal_predictors()) %>% 
  step_novel(all_nominal_predictors()) %>% 
  step_other(all_nominal_predictors(), threshold = 0.1, other = "other_category") %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_impute_median(continuous_predictor_withNA1, continuous_predictor_withNA2) %>% 
  themis::step_downsample(all_outcomes(), skip = TRUE) 


# Define Model ----
## Initialise model with tuneable hyperparameters
rf_spec <- rand_forest(trees = tune(), mtry = tune()  ) %>% 
  set_engine("ranger", importance = "permutation") %>% 
  set_mode("classification")

# Define Workflow to connect Recipe and Model ----
rf_workflow <- workflow() %>% 
  add_recipe(df_recipe) %>% 
  add_model(rf_spec)

# Train and Tune Model ----
## Define a random grid for hyperparameters to vary over
set.seed(123)
rf_grid <- grid_latin_hypercube(
  trees(),
  mtry() %>% finalize(df_train %>% dplyr::select(-group_id, -outcome_variable)),
  size = 20)

## Tune Model using Parallel Processing
all_cores <- parallel::detectCores(logical=FALSE) - 1
registerDoFuture() # Register backend
cl <- makeCluster(all_cores, setup_strategy = "sequential")

set.seed(123)
rf_tuned <-rf_workflow %>% 
    tune_race_win_loss(resamples = folds,
                       grid = rf_grid,
                       control = control_race(save_pred = TRUE),
                       metrics = metric_set(roc_auc, accuracy)) 

person Junaid Butt    schedule 21.06.2021    source источник


Ответы (1)


У меня есть пара мыслей, которые могут помочь.

  • Я рекомендую начинать без настройки, чтобы иметь хорошее представление о том, сколько времени займет работа, и базовые показатели, которые вы получите с ненастроенным случайным лесом. Возможно, вы уже это сделали, но зачастую от настройки случайного леса вы не добьетесь значительных улучшений. Используйте fit(rf_workflow, df_train), чтобы знать, с чем вы работаете, и можно было настроить.
  • На самом деле вам не нужно использовать step_dummy() со случайным лесом. Вероятно, это не слишком сильно замедляет вас, но нет причин добавлять его.
  • Вы почти наверняка не захотите устанавливать importance = "permutation" во время передискретизации или настройки. В любом случае вы не сохраняете эти модели для прогнозов, и для вычисления оценок важности требуется гораздо больше времени, чем просто подгонка.

Если я уберу step_dummy() и оценку важности, я смогу подогнать рабочий процесс вашей модели к данным этого примера менее чем за минуту. Вы умножите это на 5 для вашего folds и на 20 для вашей сетки, так что ~ 100 минут или около того, без параллельной обработки или гоночных методов (которые, конечно, очень помогут). Я полагаю, что оценка важности является большой проблемой, но вы должны иметь возможность немного изучить это и выяснить.

person Julia Silge    schedule 22.06.2021
comment
Привет, Джулия, я попробовал дать тебе совет, чтобы увидеть, как он влияет на время тренировки. Я записал результаты в этом посте: community.rstudio.com/t/ Короче говоря, отказ от использования step_dummy() и importance = "permutation" определенно помог, особенно последнего. Я хотел спросить, если я использую XGBoost, мне потребуется step_dummy()? Это было бы полезно знать при составлении рецептов для сравнения типов моделей Random Forest и XGBoost с использованием workflowsets. - person Junaid Butt; 24.06.2021
comment
Вам действительно нужно step_dummy(), когда вы подходите к модели xgboost, да, потому что такая модель требует всех числовых предикторов. - person Julia Silge; 25.06.2021