Список индекса вне диапазона при сохранении точно настроенной модели Tensorflow

Я пытаюсь настроить предварительно обученную модель BERT из Huggingface с помощью Tensorflow. Все работает гладко, модель строится и обучается без ошибок. Но когда я пытаюсь сохранить модель, она останавливается с ошибкой IndexError: list index out of range. Я использую Google Colab с TPU.

Любая помощь приветствуется!

Код:

import tensorflow as tf
from tensorflow.keras import activations, optimizers, losses
from transformers import TFBertModel

def create_model(max_sequence, model_name, num_labels):
    bert_model = TFBertModel.from_pretrained(model_name)
    input_ids = tf.keras.layers.Input(shape=(max_sequence,), dtype=tf.int32, name='input_ids')
    attention_mask = tf.keras.layers.Input((max_sequence,), dtype=tf.int32, name='attention_mask')
    output = bert_model([input_ids, attention_mask])[0]
    output = output[:, 0, :]
    output = tf.keras.layers.Dense(num_labels, activation='sigmoid')(output)
    model = tf.keras.models.Model(inputs=[input_ids, attention_mask], outputs=output)
    return model

with strategy.scope():
  model = create_model(20, 'bert-base-uncased', 1)
  opt = optimizers.Adam(learning_rate=3e-5)
  loss = 'binary_crossentropy'
  model.compile(optimizer=opt, loss=loss, metrics=['accuracy'])

model.fit(tfdataset_train, batch_size=32, epochs=2)
SAVE_PATH = 'path/to/save/location'
model.save(SAVE_PATH)

Ошибка:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-22-255116b49022> in <module>()
      1 SAVE_PATH = 'path/to/save/location'
----> 2 model.save(SAVE_PATH,save_format='tf')

50 frames
/usr/local/lib/python3.7/dist-packages/transformers/modeling_tf_utils.py in input_processing(func, config, input_ids, **kwargs)
    372                     output[tensor_name] = input
    373                 else:
--> 374                     output[parameter_names[i]] = input
    375             elif isinstance(input, allowed_types) or input is None:
    376                 output[parameter_names[i]] = input

IndexError: list index out of range

Модель с фигурами: Модель Tensorflow


comment
это решает проблему, но тогда модель не обучается правильно.   -  person user_007    schedule 18.06.2021


Ответы (2)


Задача решена!

Удаление одного из двух входных слоев (т.е. маски_ внимания) решило проблему.

Рабочий код - ›

import tensorflow as tf
from tensorflow.keras import activations, optimizers, losses
from transformers import TFBertModel

def create_model(max_sequence, model_name, num_labels):
    bert_model = TFBertModel.from_pretrained(model_name)
    input_ids = tf.keras.layers.Input(shape=(max_sequence,), dtype=tf.int32, name='input_ids')
    #attention_mask = tf.keras.layers.Input((max_sequence,), dtype=tf.int32, name='attention_mask')
    #output = bert_model([input_ids, attention_mask])[0]
    output = bert_model([input_ids])[0]
    output = output[:, 0, :]
    output = tf.keras.layers.Dense(num_labels, activation='sigmoid')(output)
    #model = tf.keras.models.Model(inputs=[input_ids, attention_mask], outputs=output)
    model = tf.keras.models.Model(inputs=[input_ids], outputs=output)
    return model

with strategy.scope():
  model = create_model(20, 'bert-base-uncased', 1)
  opt = optimizers.Adam(learning_rate=3e-5)
  loss = 'binary_crossentropy'
  model.compile(optimizer=opt, loss=loss, metrics=['accuracy'])

model.fit(tfdataset_train, batch_size=32, epochs=2)
SAVE_PATH = 'path/to/save/location'
model.save(SAVE_PATH)
person Haag    schedule 10.03.2021
comment
Спасибо! Просто чтобы выделить что-то: эта ошибка появилась у меня при использовании Tf == 2.3 и трансформаторов == 4.3. - person user_007; 17.06.2021

Решение - изменить:

output = bert_model([input_ids, attention_mask])[0]

to

output = bert_model.bert([input_ids, attention_mask])[0]

Ссылка: https://github.com/huggingface/transformers/issues/3627

Я поддержал опубликованное вами решение, но затем обнаружил проблему с моделью во время обучения. Не очень хорошо сходится.

person user_007    schedule 18.06.2021