- Загрузка модели ResNet и добавление L2 Regularization:
resnet_base = ResNet50(weights='imagenet', include_top=False, input_shape=(224,224,3))
alpha = 1e-5
for layer in resnet_base.layers:
if isinstance(layer, keras.layers.Conv2D) or isinstance(layer, keras.layers.Dense):
layer.add_loss(keras.regularizers.l2(alpha)(layer.kernel))
if hasattr(layer, 'bias_regularizer') and layer.use_bias:
layer.add_loss(keras.regularizers.l2(alpha)(layer.bias))
- Добавляем слой FC поверх базы ResNet:
model = models.Sequential()
model.add(resnet_base)
model.add(layers.AveragePooling2D())
model.add(layers.Flatten())
model.add(layers.Dense(128, activation= 'relu', kernel_regularizer=keras.regularizers.l2(alpha)))
model.add(layers.Dropout(0.6))
model.add(layers.Dense(8, activation = 'softmax', kernel_regularizer=keras.regularizers.l2(alpha)))
model.summary()
- Замораживание слоев ResNet для разогрева слоя FC:
for layer in resnet_base.layers[:]:
layer.trainable = False
model.compile(optimizer = SGD(learning_rate = 0.0001, momentum = 0.9, nesterov = False), loss='categorical_crossentropy', metrics=['accuracy'])
batch_size = 32
history = model.fit(train_generator,
steps_per_epoch=14206//batch_size, #14206 - training samples
epochs=5,
validation_data=validation_generator,
validation_steps=3546//batch_size) #3546 - validation samples
4. Разморозка некоторых слоев базы ResNet и повторное обучение модели:
for layer in resnet_base.layers[:165]:
layer.trainable = False
for layer in resnet_base.layers[165:]:
layer.trainable = True
model.compile(optimizer = SGD(learning_rate = 0.0001, momentum = 0.9, nesterov = False),
loss='categorical_crossentropy',
metrics=['accuracy'])
nepochs=150
history = model.fit(train_generator,
steps_per_epoch=14206//batch_size, #14206 - training samples
epochs=nepochs,
validation_data=validation_generator,
validation_steps=3546//batch_size) #3546 - validation samples
При всем этом я не могу решить проблему переобучения. Я увеличил обучающие данные и использовал функцию preprocess_input как для обучения, так и для набора данных проверки. Я выполнил инструкции, представленные здесь: https://jricheimer.github.io/keras/2019/02/06/keras-hack-1/ для реализации регуляризации L2.