Как сообщить Керасу о прекращении тренировок на основе величины потерь?

В настоящее время я использую следующий код:

callbacks = [
    EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

Он говорит Керасу прекратить тренировку, если потери не улучшились в течение 2 эпох. Но я хочу прекратить тренировку после того, как потеря стала меньше некоторого постоянного «THR»:

if val_loss < THR:
    break

Я видел в документации, что есть возможность сделать свой собственный обратный вызов: http://keras.io/callbacks/ Но ничего не нашел, как остановить тренировочный процесс. Мне нужен совет.


person ZFTurbo    schedule 18.05.2016    source источник


Ответы (7)


Я нашел ответ. Я заглянул в исходники Keras и нашел код для EarlyStopping. Я сделал свой обратный вызов, основываясь на нем:

class EarlyStoppingByLossVal(Callback):
    def __init__(self, monitor='val_loss', value=0.00001, verbose=0):
        super(Callback, self).__init__()
        self.monitor = monitor
        self.value = value
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning)

        if current < self.value:
            if self.verbose > 0:
                print("Epoch %05d: early stopping THR" % epoch)
            self.model.stop_training = True

И использование:

callbacks = [
    EarlyStoppingByLossVal(monitor='val_loss', value=0.00001, verbose=1),
    # EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)
person ZFTurbo    schedule 18.05.2016
comment
На всякий случай, если это будет кому-то полезно - в моем случае я использовал monitor = 'loss', это сработало. - person QtRoS; 18.02.2017
comment
Кажется, Керас обновился. В функцию обратного вызова EarlyStopping теперь встроена min_delta. Больше не нужно взламывать исходный код, ура! stackoverflow.com/a/41459368/3345375 - person jkdev; 21.06.2017
comment
Перечитав вопрос и ответы, мне нужно исправить себя: min_delta означает «Остановить раньше», если не хватает улучшений за эпоху (или за несколько эпох). Тем не менее, OP спросил, как остановить досрочно, когда убыток становится ниже определенного уровня. - person jkdev; 22.06.2017
comment
NameError: имя Callback не определено ... Как исправить? - person alyssaeliyah; 26.11.2018
comment
Элия, попробуй это: from keras.callbacks import Callback - person ZFTurbo; 28.11.2018
comment
одно исправление должно быть elif elif current ‹self.value: - person Cathy; 29.03.2021

Обратный вызов keras.callbacks.EarlyStopping имеет аргумент min_delta. Из документации Keras:

min_delta: минимальное изменение отслеживаемого количества, которое квалифицируется как улучшение, то есть абсолютное изменение меньше min_delta не будет считаться улучшением.

person devin    schedule 04.01.2017
comment
Для справки, вот документы для более ранней версии Keras (1.1.0), в которой аргумент min_delta еще не был включен: faroit.github.io/keras-docs/1.1.0/callbacks/#earlystopping - person jkdev; 21.06.2017
comment
как я могу сделать так, чтобы это не прекратилось, пока min_delta не будет сохраняться в течение нескольких эпох? - person zyxue; 18.04.2018
comment
есть еще один параметр EarlyStopping, называемый терпением: количество эпох без улучшений, после которых обучение будет остановлено. - person devin; 19.04.2018

Одно из решений - вызвать model.fit(nb_epoch=1, ...) внутри цикла for, затем вы можете поместить оператор break внутри цикла for и выполнить любой другой настраиваемый поток управления, который вы хотите.

person 1''    schedule 18.05.2016
comment
Было бы неплохо, если бы они сделали обратный вызов, который принимает единственную функцию, которая может это сделать. - person Honesty; 26.08.2016

Я решил ту же проблему, используя собственный обратный вызов.

В следующем пользовательском коде обратного вызова присвойте THR значение, при котором вы хотите остановить обучение, и добавьте обратный вызов в свою модель.

from keras.callbacks import Callback

class stopAtLossValue(Callback):

        def on_batch_end(self, batch, logs={}):
            THR = 0.03 #Assign THR with the value at which you want to stop training.
            if logs.get('loss') <= THR:
                 self.model.stop_training = True
person Rushin Tilva    schedule 02.03.2019

Пока я проходил специализацию на практике TensorFlow, я изучил очень элегантную технику. Просто немного изменен из принятого ответа.

Давайте рассмотрим пример с нашими любимыми данными MNIST.

import tensorflow as tf

class new_callback(tf.keras.callbacks.Callback):
    def epoch_end(self, epoch, logs={}): 
        if(logs.get('accuracy')> 0.90): # select the accuracy
            print("\n !!! 90% accuracy, no further training !!!")
            self.model.stop_training = True

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 #normalize

callbacks = new_callback()

# model = tf.keras.models.Sequential([# define your model here])

model.compile(optimizer=tf.optimizers.Adam(),
          loss='sparse_categorical_crossentropy',
          metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])

Итак, здесь я установил metrics=['accuracy'], и, таким образом, в классе обратного вызова условие установлено на 'accuracy'> 0.90.

Вы можете выбрать любую метрику и следить за тренировкой, как в этом примере. Самое главное, вы можете установить разные условия для разных метрик и использовать их одновременно.

Надеюсь, это поможет!

person Suvo    schedule 13.04.2020
comment
имя функции должно быть on_epoch_end - person xarion; 28.08.2020

Для меня модель остановит обучение только в том случае, если я добавлю оператор return после установки для параметра stop_training значения True, потому что я звонил после self.model.evaluate. Так что либо обязательно поставьте stop_training = True в конце функции, либо добавьте оператор возврата.

def on_epoch_end(self, batch, logs):
        self.epoch += 1
        self.stoppingCounter += 1
        print('\nstopping counter \n',self.stoppingCounter)

        #Stop training if there hasn't been any improvement in 'Patience' epochs
        if self.stoppingCounter >= self.patience:
            self.model.stop_training = True
            return

        # Test on additional set if there is one
        if self.testingOnAdditionalSet:
            evaluation = self.model.evaluate(self.val2X, self.val2Y, verbose=0)
            self.validationLoss2.append(evaluation[0])
            self.validationAcc2.append(evaluation[1])enter code here
person Juan Antonio Barragan    schedule 10.04.2020

Если вы используете настраиваемый цикл обучения, вы можете использовать collections.deque, который представляет собой скользящий список, который можно добавлять, и левые элементы выскакивают, когда список длиннее maxlen. Вот строчка:

loss_history = deque(maxlen=early_stopping + 1)

for epoch in range(epochs):
    fit(epoch)
    loss_history.append(test_loss.result().numpy())
    if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history)
            break

Вот полный пример:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.layers import Dense
from collections import deque

data, info = tfds.load('iris', split='train', as_supervised=True, with_info=True)

data = data.map(lambda x, y: (tf.cast(x, tf.int32), y))

train_dataset = data.take(120).batch(4)
test_dataset = data.skip(120).take(30).batch(4)

model = tf.keras.models.Sequential([
    Dense(8, activation='relu'),
    Dense(16, activation='relu'),
    Dense(info.features['label'].num_classes)])

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)
        loss = loss_object(labels, logits)

    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_acc(labels, logits)


@tf.function
def test_step(inputs, labels):
    logits = model(inputs, training=False)
    loss = loss_object(labels, logits)
    test_loss(loss)
    test_acc(labels, logits)


def fit(epoch):
    template = 'Epoch {:>2} Train Loss {:.3f} Test Loss {:.3f} ' \
               'Train Acc {:.2f} Test Acc {:.2f}'

    train_loss.reset_states()
    test_loss.reset_states()
    train_acc.reset_states()
    test_acc.reset_states()

    for X_train, y_train in train_dataset:
        train_step(X_train, y_train)

    for X_test, y_test in test_dataset:
        test_step(X_test, y_test)

    print(template.format(
        epoch + 1,
        train_loss.result(),
        test_loss.result(),
        train_acc.result(),
        test_acc.result()
    ))


def main(epochs=50, early_stopping=10):
    loss_history = deque(maxlen=early_stopping + 1)

    for epoch in range(epochs):
        fit(epoch)
        loss_history.append(test_loss.result().numpy())
        if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
            print(f'\nEarly stopping. No validation loss '
                  f'improvement in {early_stopping} epochs.')
            break

if __name__ == '__main__':
    main(epochs=250, early_stopping=10)
Epoch  1 Train Loss 1.730 Test Loss 1.449 Train Acc 0.33 Test Acc 0.33
Epoch  2 Train Loss 1.405 Test Loss 1.220 Train Acc 0.33 Test Acc 0.33
Epoch  3 Train Loss 1.173 Test Loss 1.054 Train Acc 0.33 Test Acc 0.33
Epoch  4 Train Loss 1.006 Test Loss 0.935 Train Acc 0.33 Test Acc 0.33
Epoch  5 Train Loss 0.885 Test Loss 0.846 Train Acc 0.33 Test Acc 0.33
...
Epoch 89 Train Loss 0.196 Test Loss 0.240 Train Acc 0.89 Test Acc 0.87
Epoch 90 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 91 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 92 Train Loss 0.194 Test Loss 0.239 Train Acc 0.90 Test Acc 0.87

Early stopping. No validation loss improvement in 10 epochs.
person Nicolas Gervais    schedule 17.08.2020