Пакетная нормализация была представлена ​​в статье Сергея Иоффе и Кристиана Сегеди 2015 года Пакетная нормализация: ускорение обучения глубокой сети за счет уменьшения внутреннего ковариатного сдвига. Идея состоит в том, что вместо того, чтобы просто нормализовать входные данные для сети, мы нормализуем входные данные для слоев сети.

Это называется пакетной нормализацией, поскольку во время обучения мы нормализуем входные данные каждого слоя, используя среднее значение и дисперсию значений в текущей партии.

Изменение распределения входных данных слоев представляет собой проблему, поскольку слоям необходимо постоянно адаптироваться к новому распределению. Когда распределение входных данных в системе обучения изменяется, говорят, что в ней происходит ковариативный сдвиг.

Пакетная нормализация в PyTorch

В этом разделе кода показан один из способов добавления пакетной нормализации в нейронную сеть, созданную в PyTorch.

Следующий код импортирует нужные нам пакеты в записную книжку и загружает набор данных MNIST для использования в наших экспериментах.

Визуализация данных

Классы нейронных сетей для тестирования

Следующий класс, NeuralNet, позволяет нам создавать идентичные нейронные сети с пакетной нормализацией и без нее для сравнения. Код тщательно документирован, но позже будет проведено дополнительное обсуждение. Вам не нужно читать все, прежде чем перейти к остальной части блокнота, но комментарии в блоках кода могут ответить на некоторые ваши вопросы.

О коде:

Мы определяем простой MLP для классификации; этот выбор дизайна был сделан для поддержки обсуждения, связанного с нормализацией пакетов, а не для достижения максимальной точности классификации.

(Важно) Детали модели

В коде довольно много комментариев, поэтому они должны ответить на большинство ваших вопросов. Тем не менее, давайте взглянем на самые важные строки.

Мы добавляем пакетную нормализацию к слоям внутри функции __init__. Вот несколько важных моментов об этом коде:

  1. Слои с пакетной нормализацией не содержат термин смещения.
  2. Мы используем функцию PyTorch BatchNorm1d для обработки математики. Это функция, которую вы используете для работы с выходными данными линейного слоя; вы будете использовать BatchNorm2d для вывода 2D, например отфильтрованных изображений из сверточных слоев.
  3. Мы добавляем слой пакетной нормализации перед вызовом функции активации.

Создайте две разные модели для тестирования

  • net_batchnorm – это модель линейной классификации, с пакетной нормализацией, применяемой к выходным данным ее скрытых слоев.
  • net_no_norm — это обычный MLP, без пакетной нормализации

Помимо слоев нормализации, в этих моделях все одинаково.

net_batchnorm = NeuralNet(use_batch_norm=True)
net_no_norm = NeuralNet(use_batch_norm=False)

print(net_batchnorm)
print()
print(net_no_norm)

Обучение

Приведенная ниже функция train будет принимать модель и некоторое количество эпох. Мы будем использовать перекрестную потерю энтропии и стохастический градиентный спуск для оптимизации. Эта функция возвращает потери, записанные после каждой эпохи, чтобы мы могли отображать и сравнивать поведение разных моделей.

.train() режим

Обратите внимание, что мы сообщаем нашей модели, должна ли она находиться в режиме обучения, model.train(). Это важный шаг, потому что нормализация партии ведет себя по-разному во время обучения партии или тестирования/оценки большого набора данных.

Сравнение моделей

В приведенных ниже ячейках мы обучаем две разные модели и сравниваем их потери при обучении с течением времени.

# batchnorm model losses
# this may take some time to train
losses_batchnorm = train(net_batchnorm)
# *no* norm model losses
# you should already start to see a difference in training losses
losses_no_norm = train(net_no_norm)

Давайте посмотрим на разницу между моделями с пакетной нормой и без нее на графике.

# compare
fig, ax = plt.subplots(figsize=(12,8))
#losses_batchnorm = np.array(losses_batchnorm)
#losses_no_norm = np.array(losses_no_norm)
plt.plot(losses_batchnorm, label='Using batchnorm', alpha=0.5)
plt.plot(losses_no_norm, label='No norm', alpha=0.5)
plt.title("Training Losses")
plt.legend()

Тестирование

Вы должны увидеть, что модель с пакетной нормализацией начинается с более низких потерь при обучении и через десять эпох обучения достигает потерь при обучении, которые заметно ниже, чем у нашей модели без нормализации.

Далее давайте посмотрим, как обе эти модели работают на наших тестовых данных! Ниже у нас есть функция test, которая принимает модель, и параметр train (True или False), который указывает, должна ли модель находиться в режиме обучения или в режиме оценки. Это для сравнения, позже. Эта функция будет вычислять некоторые статистические данные теста, включая общую точность теста переданной модели.

Режим обучения и оценки

Установка модели в режим оценки важна для моделей со слоями пакетной нормализации!

  • Режим обучения означает, что слои нормализации партии будут использовать статистику пакета для расчета нормы пакета.
  • Режим оценки, с другой стороны, использует предполагаемое среднее значение населения и дисперсию всего обучающего набора, что должно повысить производительность на этих тестовых данных!
# test batchnorm case, in *train* mode
test(net_batchnorm, train=True)
Test Loss: 0.086881

Test Accuracy of     0: 98% (967/980)
Test Accuracy of     1: 99% (1126/1135)
Test Accuracy of     2: 96% (999/1032)
Test Accuracy of     3: 97% (989/1010)
Test Accuracy of     4: 96% (952/982)
Test Accuracy of     5: 96% (864/892)
Test Accuracy of     6: 97% (933/958)
Test Accuracy of     7: 96% (990/1028)
Test Accuracy of     8: 96% (939/974)
Test Accuracy of     9: 95% (966/1009)

Test Accuracy (Overall): 97% (9725/10000)
# test batchnorm case, in *evaluation* mode
test(net_batchnorm, train=False)
Test Loss: 0.073484

Test Accuracy of     0: 98% (968/980)
Test Accuracy of     1: 99% (1127/1135)
Test Accuracy of     2: 97% (1005/1032)
Test Accuracy of     3: 98% (991/1010)
Test Accuracy of     4: 97% (955/982)
Test Accuracy of     5: 97% (874/892)
Test Accuracy of     6: 97% (932/958)
Test Accuracy of     7: 96% (995/1028)
Test Accuracy of     8: 96% (940/974)
Test Accuracy of     9: 97% (983/1009)

Test Accuracy (Overall): 97% (9770/10000)
# for posterity, test no norm case in eval mode
test(net_no_norm, train=False)
Test Loss: 0.207286

Test Accuracy of     0: 98% (963/980)
Test Accuracy of     1: 98% (1113/1135)
Test Accuracy of     2: 91% (943/1032)
Test Accuracy of     3: 93% (943/1010)
Test Accuracy of     4: 93% (918/982)
Test Accuracy of     5: 92% (824/892)
Test Accuracy of     6: 95% (912/958)
Test Accuracy of     7: 92% (954/1028)
Test Accuracy of     8: 91% (891/974)
Test Accuracy of     9: 93% (940/1009)

Test Accuracy (Overall): 94% (9401/10000)

Какая модель имеет наибольшую точность?

Вы должны увидеть небольшое улучшение при сравнении точности модели норм партии в режиме обучения и оценки; режим оценки должен дать небольшое улучшение!

Вы также должны увидеть, что модель, использующая слои пакетной нормы, демонстрирует заметное улучшение общей точности по сравнению с моделью без нормализации.

Рекомендации для других типов сетей

В этой записной книжке демонстрируется пакетная нормализация в стандартной нейронной сети с полностью связанными слоями. Вы также можете использовать пакетную нормализацию в других типах сетей, но есть некоторые особенности.

ConvNets

Слои свертки состоят из нескольких карт объектов. (Помните, что глубина сверточного слоя относится к количеству карт объектов.) И веса для каждой карты объектов являются общими для всех входных данных, поступающих в слой. Из-за этих различий для пакетной нормализации сверточных слоев требуется среднее значение партии/населения и дисперсия для каждой карты объектов, а не для каждого узла в слое.

Чтобы применить пакетную нормализацию к выходным данным сверточных слоев, мы используем BatchNorm2d

RNN

Пакетная нормализация также может работать с рекуррентными нейронными сетями, как показано в статье 2016 года Рекуррентная пакетная нормализация. Это немного больше работы для реализации, но в основном включает в себя расчет средних значений и дисперсий на временной шаг, а не на слой. Вы можете найти пример, когда кто-то реализовал рекуррентную пакетную нормализацию в PyTorch, в этом репозитории GitHub.