Модель seq2seq на основе Google TensorFlow дает сбой во время обучения

Я пытался использовать модель Google RNN на основе seq2seq.

Я тренировал модель для суммирования текста и ввожу текстовые данные размером примерно 1 ГБ. Модель быстро заполняет всю мою оперативную память (8 ГБ), начинает заполнять даже память подкачки (еще 8 ГБ) и вылетает после того, как мне приходится жестко выключаться.

Конфигурация моей сети LSTM выглядит следующим образом:

model: AttentionSeq2Seq
model_params:
  attention.class: seq2seq.decoders.attention.AttentionLayerDot
  attention.params:
    num_units: 128
  bridge.class: seq2seq.models.bridges.ZeroBridge
  embedding.dim: 128
  encoder.class: seq2seq.encoders.BidirectionalRNNEncoder
  encoder.params:
    rnn_cell:
      cell_class: GRUCell
      cell_params:
        num_units: 128
      dropout_input_keep_prob: 0.8
      dropout_output_keep_prob: 1.0
      num_layers: 1
  decoder.class: seq2seq.decoders.AttentionDecoder
  decoder.params:
    rnn_cell:
      cell_class: GRUCell
      cell_params:
        num_units: 128
      dropout_input_keep_prob: 0.8
      dropout_output_keep_prob: 1.0
      num_layers: 1
  optimizer.name: Adam
  optimizer.params:
    epsilon: 0.0000008
  optimizer.learning_rate: 0.0001
  source.max_seq_len: 50
  source.reverse: false
  target.max_seq_len: 50

Я попытался уменьшить размер пакета с 32 до 16, но это все равно не помогло. Какие конкретные изменения я должен внести, чтобы моя модель не занимала всю оперативную память и не давала сбой? (Например, уменьшение размера данных, уменьшение количества сложенных ячеек LSTM, дальнейшее уменьшение размера пакета и т. д.)

Моя система работает под управлением Python 2.7x, TensorFlow версии 1.1.0 и CUDA 8.0. Система имеет Nvidia Geforce GTX-1050Ti (768 ядер CUDA) с 4 ГБ памяти, 8 ГБ ОЗУ и еще 8 ГБ памяти подкачки.


person Rudresh Panchal    schedule 19.06.2017    source источник


Ответы (1)


Ваша модель выглядит довольно маленькой. Единственное, что кажется большим, это данные о поездах. Убедитесь, что ваша функция get_batch() не содержит ошибок. Вполне возможно, что в каждой партии вы фактически загружаете весь набор данных для обучения, если там есть ошибка.

Чтобы быстро доказать это, просто сократите размер тренировочных данных до чего-то очень маленького (например, 1/10 от текущего размера) и посмотрите, поможет ли это. Обратите внимание, что это не должно помочь, потому что вы используете мини-партию. Но если это решит проблему, исправьте свою функцию get_batch().

person Bo Shao    schedule 19.06.2017
comment
Эй, спасибо за ваш ответ, но я сомневаюсь, что это так, поскольку функция get_batch была абстрагирована и внутренне обрабатывается фреймворком, написанным Google. Поэтому, если в фреймворке нет ошибки, функция get_batch верна. Существуют ли какие-либо другие гиперпараметры, которые я могу изменить, чтобы решить мою проблему? - person Rudresh Panchal; 20.06.2017
comment
Я не уверена. Другим действительно сложно слепо отлаживать проблему. Однако проблема, скорее всего, вызвана ошибкой. Попробуйте уменьшить общий набор поездов примерно до 1 МБ. Если это не поможет, проверьте правильность передачи всех гиперпараметров, особенно seq_len. - person Bo Shao; 20.06.2017