Как проверить веса после каждой эпохи в модели Keras

Я использую последовательную модель в Keras. Я хотел бы проверять вес модели после каждой эпохи. Не могли бы вы подсказать мне, как это сделать.

model = Sequential()
model.add(Embedding(max_features, 128, dropout=0.2))
model.add(LSTM(128, dropout_W=0.2, dropout_U=0.2))  
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy',optimizer='adam',metrics['accuracy'])
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=5 validation_data=(X_test, y_test))

Заранее спасибо.


person Kiran Baktha    schedule 04.02.2017    source источник


Ответы (1)


То, что вы ищете, это функция CallBack. Обратный вызов — это функция Keras, которая многократно вызывается во время обучения в ключевых точках. Это может быть после партии, эпохи или всей тренировки. См. здесь документацию и список существующих обратных вызовов.

Что вам нужно, так это настраиваемый обратный вызов, который можно создать с помощью объекта LambdaCallBack.

from keras.callbacks import LambdaCallback

model = Sequential()
model.add(Embedding(max_features, 128, dropout=0.2))
model.add(LSTM(128, dropout_W=0.2, dropout_U=0.2))  
model.add(Dense(1))
model.add(Activation('sigmoid'))

print_weights = LambdaCallback(on_epoch_end=lambda batch, logs: print(model.layers[0].get_weights()))

model.compile(loss='binary_crossentropy',optimizer='adam',metrics['accuracy'])
model.fit(X_train, 
          y_train, 
          batch_size=batch_size, 
          nb_epoch=5 validation_data=(X_test, y_test), 
          callbacks = [print_weights])

приведенный выше код должен печатать ваши веса встраивания model.layers[0].get_weights() в конце каждой эпохи. Вам решать, распечатать ли его там, где вы хотите сделать его читабельным, сбросить его в файл рассола,...

Надеюсь это поможет

person Nassim Ben    schedule 06.02.2017
comment
Спасибо за ваш ответ, но если я хочу сохранить все веса в списке, а не распечатать его, как я могу это сделать? Я пробовал logs[weights].append(model.layers[0].get_weights() но это не работает - person jimmy15923; 05.06.2017
comment
@jimmy15923 model.layers[0].get_weights() показывает только вес первого слоя, что ничего не значит, учитывая, что он предназначен для ввода. Вам нужно перебрать все слои. - person Andy Wei; 22.05.2018
comment
опечатка: вы должны печатать эпоху, а не партию: print_weights = LambdaCallback(on_epoch_end=lambda epoch, ... - person MAltakrori; 04.10.2018