Быстрое обучение - это захватывающая область машинного обучения, цель которой - сократить разрыв между машиной и человеком в сложной задаче обучения на нескольких примерах. В моем предыдущем посте я представил высокоуровневое резюме трех передовых работ по изучению с несколькими выстрелами - я предполагаю, что вы либо читали это, либо уже знакомы с этими статьями, либо находитесь в процессе их воспроизведения.

В этом посте я расскажу вам о своем опыте воспроизведения результатов этих работ по наборам данных Omniglot и miniImageNet, включая некоторые подводные камни и препятствия на этом пути. В каждой статье есть свой раздел, в котором я представляю Github gist с кодом PyTorch для выполнения обновления одного параметра в модели, описанной в статье. Чтобы обучить модель, просто нужно поместить эту функцию в цикл над обучающими данными. Менее интересные детали, такие как обработка наборов данных, для краткости опущены.

Воспроизводимость очень важна, это основа любой области, претендующей на звание научной. Это заставляет меня поверить, что распространенность совместного использования кода и открытого исходного кода в машинном обучении действительно достойна восхищения. Хотя публикация кода сама по себе не является воспроизводимостью (поскольку могут быть ошибки реализации), она открывает методы исследователей для общественного контроля и, что более важно, ускоряет исследования других в этой области. В свете этого я хотел бы поблагодарить авторов этих документов за то, что они поделились своим кодом, а также всех других, кто предоставил открытый исходный код для своих реализаций.

Полную реализацию см. в моем репозитории Github по адресу https://github.com/oscarknagg/few-shot

Наборы данных

Есть два набора данных изображений, на которых оцениваются алгоритмы обучения с несколькими выстрелами. Первый - это набор данных Omniglot, который содержит 20 изображений, каждое примерно по 1600 символов из 50 алфавитов. Эти изображения обычно имеют оттенки серого 28x28, что является одной из причин, по которой этот набор данных часто называют транспонированием MNIST.

Второй - это набор данных miniImageNet, подмножество ImageNet, призванное стать более сложным тестом, но не таким громоздким, как полный набор данных ImageNet. miniImageNet состоит из 60 000 изображений RGB 84x84, по 600 изображений на класс.

В обоих случаях классы в обучающем и проверочном наборах не пересекаются. Я не использовал те же обучающие и проверочные разбиения, что и в исходных статьях, поскольку моя цель - не воспроизвести их до мельчайших деталей.

Соответствующие сети

В Matching Networks Виньялс и др. Вводят идею полностью дифференцируемого классификатора ближайших соседей, который одновременно обучается и тестируется на задачах с малым числом выстрелов.

Алгоритм Matching Networks можно резюмировать следующим образом:

  1. Сначала вставьте все образцы (запрос и набор поддержки), используя сеть кодировщика (в данном случае 4-уровневую CNN). Это выполняется model.encode () (строка 41).
  2. Необязательно вычислить встраивание полного контекста (FCE). LSTM принимает исходные вложения как входные и выходные модифицированные вложения с учетом набора поддержки. Это выполняется с помощью model.f () и model.g () (строки 62 и 67).
  3. Вычислить попарные расстояния между выборками запросов и вспомогательными наборами и нормализовать с помощью softmax (строки с 69 по 77)
  4. Вычислите прогнозы, взяв средневзвешенное значение меток набора опор с нормализованным расстоянием (строки 83–89).

Несколько замечаний:

  • В этом примере тензор x содержит сначала образцы набора поддержки, а затем запрос. Для Омниглота он будет иметь форму (n_support + n_query, 1, 28, 28)
  • Математические расчеты в предыдущем сообщении относятся к одному образцу запроса, но сети сопоставления фактически обучаются с помощью пакета образцов запроса размером q_queries * k_way

Мне не удалось воспроизвести результаты этой статьи с использованием косинусного расстояния, но я добился успеха с использованием расстояния l2. Я считаю, что это связано с тем, что косинусное расстояние ограничено между -1 и 1, что затем ограничивает количество, которое функция внимания (a (x ^, x_i) ниже) может указывать на конкретный образец в наборе поддержки. . Поскольку косинусное расстояние ограничено, a (x ^, x_i) никогда не будет близко к 1! В случае 5-позиционной классификации максимально возможное значение a (x ^, x_i) равно exp (1) / (exp (1) + 4 * exp (-1)) ≈ 0,65. Это привело к очень медленной сходимости при использовании косинусного расстояния.

Я думаю, что можно воспроизвести результаты с использованием косинусного расстояния либо с более длительным временем обучения, либо с лучшими гиперпараметрами, либо с помощью эвристики, такой как умножение косинусного расстояния на постоянный коэффициент. Поскольку выбор расстояния не является ключевым для статьи, а результаты очень хороши с использованием расстояния l2, я решил избавить себя от этих усилий по отладке.

Прототипные сети

В Prototypical Networks Снелл и др. Используют убедительное индуктивное смещение, мотивированное теорией расходимостей Брегмана, для достижения впечатляющей производительности с несколькими выстрелами.

Алгоритм прототипической сети можно резюмировать следующим образом:

  1. Встроить все образцы запросов и поддержки (строка 36)
  2. Вычислить прототипы классов, взяв среднее значение вложений каждого класса (строка 48).
  3. Прогнозы - это softmax на расстояниях между выборками запросов и прототипами классов (строка 63).

Мне показалось, что эту статью легко воспроизвести, поскольку авторы предоставили полный набор гиперпараметров. Следовательно, я легко смог достичь заявленной производительности с точностью до ~ 0,2% в тесте Omniglot и в пределах нескольких% в тесте miniImageNet без необходимости выполнять какие-либо собственные настройки.

Метаагностическое метаобучение (MAML)

В MAML Финн и др. Представляют мощный и широко применимый алгоритм метаобучения для изучения инициализации сети, которая может быстро адаптироваться к новым задачам. Эта статья была самой сложной, но наиболее полезной для воспроизведения из трех, представленных в этой статье.

Алгоритм MAML можно резюмировать следующим образом:

  1. Для каждой задачи n-shot в метапакете задач создайте новую модель, используя веса базовой модели AKA meta-Learner (строка 79).
  2. Обновите веса новой модели, используя потери из выборок в задаче путем стохастического градиентного спуска (строки 81–92).
  3. Рассчитайте потери обновленной модели на еще нескольких данных из той же задачи (строки 94–97).
  4. При выполнении MAML 1-го порядка обновите веса мета-учащихся с градиентом потерь из части 3. При выполнении MAML 2-го порядка вычислите производную этой потери по исходным весам (строки 110+)

Самая большая привлекательность PyTorch - это его система автоградации. Это часть магии кода, которая записывает операции, действующие на объекты torch.Tensor, и динамически строит направленный ациклический граф этих операций под капотом. Обратное распространение так же просто, как вызов .backwards () для конечного результата. Мне пришлось немного больше узнать об этой системе, чтобы вычислить и применить обновления параметров к мета -ученик, которым я сейчас с вами поделюсь.

MAML 1-го порядка - замена градиента

Обычно при обучении модели в PyTorch вы создаете объект оптимизатора, привязанный к параметрам конкретной модели.

from torch.optim import Adam
opt = Adam(model.parameters(), lr=0.001)

Когда вызывается opt.step (), оптимизатор считывает градиенты параметров модели и вычисляет обновление этих параметров. Однако в MAML 1-го порядка мы собираемся вычислить градиенты, используя одну модель (быстрые веса), и применить обновление к другой модели, то есть метаобучающемуся.

Решением этой проблемы является использование недостаточно используемой части функциональных возможностей PyTorch в виде torch.Tensor.register_hook (hook). Зарегистрируйте функцию-перехватчик в тензоре, и эта функция-перехватчик будет вызываться всякий раз, когда вычисляется градиент по отношению к этому тензору. Для каждого параметра Tensor в метаученике я регистрирую ловушку, которая просто заменяет градиент на соответствующий градиент на быстрых весах (строки 111–129 в сущности). Это означает, что при вызове opt.step () градиенты быстрой модели будут использоваться для обновления весов мета-учащихся по желанию.

MAML 2-го порядка - вопросы автограда

При первой попытке реализации MAML я создал экземпляр нового объекта модели (подкласс torch.nn.Module) и установил значения его весов, равных весам метаученика. Однако это делает невозможным выполнение MAML 2-го порядка, так как веса быстрой модели не связаны с весами метаученика в глазах torch.autograd. Это означает, что когда я вызываю optimiser.step () (строка 140 в сущности), граф автоградации для весов мета-учащихся пуст и обновление не выполняется.

# This didn't work, meta_learner weights remain unchanged
meta_learner = ModelClass()
opt = Adam(meta_learner.parameters(), lr=0.001)
task_losses = []
for x, y in meta_batch:
    fast_model = ModelClass()
    # torch.autograd loses reference here!
    copy_weights(from=meta_learner, to=fast_model)
    task_losses.append(update_with_batch(x, y))
meta_batch_loss = torch.stack(task_losses).mean()
meta_batch_loss.backward()
opt.step()

Решением этого является function_forward () (строка 17), который представляет собой несколько неудобный прием, который вручную выполняет те же операции (свертка, максимальное объединение и т. Д.), Что и класс модели с использованием torch. нн. функциональный. Это также означает, что мне нужно вручную выполнить обновление параметров быстрой модели. Следствием этого является то, что torch.autograd знает, как распространять градиенты в обратном направлении до исходных весов мета-учащегося. Это приводит к очень большому графику автограда.

Однако MAML 2-го порядка - более хитрый зверь, чем это. Когда я впервые написал свою реализацию MAML 2-го порядка, я думал, что все заработало чудесным образом с первой попытки. По крайней мере, не было никаких исключений, верно? Только после выполнения полного набора экспериментов Omniglot и miniImageNet я начал сомневаться в своей работе - результаты были слишком похожи на MAML 1-го порядка. Это типично для неудачной породы тихих ошибок машинного обучения, которые не вызывают исключений, а становятся видимыми только в окончательной работе модели.

Поэтому я решил взять себя в руки и написать модульный тест, который подтвердил бы, что я действительно выполняю обновление 2-го порядка. Отказ от ответственности: в духе настоящей разработки, основанной на тестировании, я должен был написать этот тест перед проведением каких-либо экспериментов. Оказывается, я не идеален 😛.

Тест, который я выбрал, заключался в том, чтобы запустить функцию meta_gradient_step на фиктивной модели и вручную проанализировать график автограда, подсчитав количество двойных операций в обратном направлении. Таким образом, я могу быть абсолютно уверен, что при желании выполняю обновление 2-го порядка. И наоборот, я смог проверить, что моя реализация MAML 1-го порядка выполняет только обновление 1-го порядка без двойных операций в обратном направлении.

Наконец, я обнаружил ошибку, не применяющую параметр create_graph во внутреннем цикле обучения (строка 86). Тем не менее, я сохранил автографический график потерь в выборках запросов (строка 97), но этого было недостаточно для выполнения обновления 2-го порядка, поскольку развернутый обучающий граф не был создан.

Время обучения было довольно долгим (более 24 часов для эксперимента miniImageNet с 5 кадрами и 5 кадрами), но в конце концов я добился довольно хороших результатов в воспроизведении результатов.

Я надеюсь, что вы узнали что-то полезное из этого глубокого технического погружения. Если у вас есть вопросы, дайте мне знать в комментариях.

Если вы хотите вникнуть в подробности, ознакомьтесь с полной реализацией на https://github.com/oscarknagg/few-shot