В этом посте мы увидим, как разработать полную систему машинного перевода с нуля с помощью пакета MXNet R. Мы набираем балл BLEU более 28 за задание с английского на французский с помощью одной модели, обученной в течение дня на одном графическом процессоре, не полагаясь на какие-либо внешние ресурсы, такие как предварительно обученные векторы слов или токенизаторы. Полный код для воспроизведения модели можно найти в репозитории translatR.

Получение данных

Первый шаг к построению модели перевода - это сбор и подготовка данных. К счастью, WMT сделал доступным большой параллельный корпус.

В этой демонстрации будет использоваться корпус Europarl v7 и Common Crawl, обеспечивающий более 4 миллионов пар предложений.

download.file(url = "http://www.statmt.org/europarl/v7/fr-en.tgz", destfile = "./data/europarl_fr-en.tgz")
untar(tarfile = "./data/europarl_fr-en.tgz”, exdir = "./data/")
euro_en <- read_lines("data/europarl-v7.fr-en.en")
euro_fr <- read_lines("data/europarl-v7.fr-en.fr")

Вышеупомянутый импорт приводит к двум большим векторам символов (для английского и французского языков) одинакового размера, где каждый элемент является предложением для Europarl. То же самое повторяется для обычного сканирования, и результаты объединяются.

Предварительная обработка

Модель будет обучаться на последовательностях слов. Поскольку обучение будет выполняться путем подачи в модель массивов фиксированного размера, целью предварительной обработки будет выполнение следующего преобразования:

Вышеупомянутое включает в себя несколько шагов. Во-первых, это токенизация, при которой строка, содержащая всю последовательность, разбивается на вектор токенов.

Чтобы быть легкими и вводить как можно меньшее количество предварительных знаний языка, последовательности просто разделяются на пустые места. Чтобы модель могла обрабатывать знаки препинания, нужно быстро взломать их, вставив вокруг них пробелы. Наконец, маркеры-заполнители <BOS> и <EOS> объединяются, чтобы предоставить модели начальные и конечные сигналы.

source <- "I'd love to learn French!"
source <- gsub("([[:punct:]])", " \\1 ", source)
source <- paste("<BOS>", source, "<EOS>")
> strsplit(source, "\\s+")
[[1]]
 [1] "<BOS>"  "I"      "'"      "d"      "love"   "to"     "learn"  "French" "!"      "<EOS>"

Вышеупомянутый вектор токенов затем преобразуется в таблицу data.table вместе с последовательностью и идентификатором слова:

source_dt <- data.table(word = unlist(source_word_vec_list), 
                        seq_id = rep(1:length(source_seq_length),
                                     times = source_seq_length),
                        seq_word_id = seq_word_id_source)
> source_dt
      word seq_id seq_word_id
 1:  <BOS>      1           1
 2:      I      1           2
 3:      '      1           3
 4:      d      1           4
 5:   love      1           5
 6:     to      1           6
 7:  learn      1           7
 8: French      1           8
 9:      !      1           9
10:  <EOS>      1          10

Формат data.table эффективен для создания словаря для сопоставления каждого токена с индексом. Жетоны подсчитываются, а редкие игнорируются, чтобы ограничить размер словарного запаса в диапазоне от 20 до 50 тысяч.

source_word_count <- source_dt[, .N, by = word]
source_dic <- source_word_count[N >= word_count_min,,][order(-N)]

Также введены два других специальных токена: <PAD> и <UNKNOWN>. Первый используется для заполнения последовательностей короче, чем матрица данных, а второй используется по умолчанию для токенов, отсутствующих в словаре.

После того, как словарь построен, оставшийся шаг - проиндексировать его по указанной выше таблице data.table и преобразовать ее в таблицу размера [number of sequence, max sequence length] с помощью dcast, установив общую длину последовательности:

source_dt <- source_dic[source_dt][order(seq_id, seq_word_id)]
source_dt <- dcast(data = source_dt, seq_word_id ~ seq_id, value.var = "word_id", fill = 0)
source <- as.matrix(source[ , c("seq_word_id") := NULL])
> source_dt
         1    2
 [1,]    5    5
 [2,]   22   22
 [3,]   21   21
 [4,]  550  550
 [5,] 1258 1258
 [6,]    8    8
 [7,] 1161 1161
 [8,]  424  424
 [9,]   86   10
[10,]    6 1645
[11,]    0   86
[12,]    0    6

Исходный текст теперь преобразован в матрицу размеров [максимальная длина последовательности, количество последовательностей], готовая для использования в модели перевода. Первый столбец соответствует индексу, показанному на начальном рисунке предварительной обработки.

Архитектура

Базовая модель от последовательности к последовательности может быть представлена ​​как:

В этой конструкции кодировщику необходимо предоставить единый вектор функций, который несет всю информацию о последовательности, что привело к шутке, когда из этого единственного вектора можно было извлечь неограниченное количество клоунов. Интуиция о сверхъестественных способностях действительно имеет силу: более эффективные разработки позволяют избежать ограничений, накладываемых сжатием кодирования фиксированной длины, полагаясь на хитрый трюк: внимание.

Больше никаких клоунов! Благодаря механизму внимания данные, поступающие в декодер, теперь представляют собой всю закодированную последовательность, устраняя информационное узкое место предыдущей архитектуры.

Внимание может быть реализовано во многих вариантах. Для обеспечения гибкости моя реализация использовала абстрактную нотацию Query-Key-Value, которая описана в статье Внимание - это все, что вам нужно. Преимуществом такого подхода является модульность системы перевода. Кодер предоставляет как матрицу значений, так и матрицу ключей, причем последняя является преобразованием первой, а декодер предоставляет запрос, который является вектором характеристик декодированного токена.

Модуль внимания вернет средневзвешенное значение матрицы значений (вектора внимания), которое будет использоваться для улучшения вектора функций во время декодирования.

Возможны несколько вариантов парадигмы запроса-ключа-значения. Bilinear и MLP были реализованы в дополнение к Dot Внимание, показанному ниже:

attn <- attn_dot(value=value, query_key_size=num_hidden, scale=T)
init <- attn$init()
attend <- attn$attend
attention <- attend(query=query, key=init$key, value=init$value, attn_init=init)

Выше приведен фактический график MXNet для точки внимания, где размер пакета равен 128 (последнее измерение). Для каждого токена, который должен быть декодирован, query является перепроецированием 512-длинного представления этого токена. value - это полная кодировка исходной последовательности. Он сам перепроецируется для формирования key, к которому применяется скалярное произведение в запросе, чтобы получить схему взвешивания, которая будет применяться к матрице value. Результирующий вектор длиной 512 называется вектором контекста, который затем добавляется к исходной кодировке токена для вычисления оценки, связанной с каждым словом целевого словаря.

Последний компонент модели - функция потерь softmax. Он нормализует вышеуказанные оценки в распределение вероятностей и использует функцию потерь кросс-энтропии для получения градиента напора для распространения.

Обучение

Благодаря модульной конструкции кодер-внимание-декодер полная модель может быть легко построена. Еще одна тонкость заключается в том, что во время обучения декодер использует учителя. То есть на каждом шаге подаётся истинный предыдущий токен, а не предсказанный. Такая информация недоступна при выполнении вывода. Поэтому создается второй декодер, который использует наиболее вероятное слово при выводе (argmax по предсказаниям), а не истинную метку.

Гиперпараметры для обучения были довольно ванильными: оптимизатор Adam с уменьшающейся скоростью обучения:

initializer <- mx.init.Xavier(rnd_type = "uniform", factor_type = "in", magnitude = 2.5)
lr_scheduler <- mx.lr_scheduler.FactorScheduler(step = 5000,     factor_val = 0.9, stop_factor_lr = 5e-5)
optimizer <- mx.opt.create("adam", learning.rate = 5e-4, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8, wd = 1e-8, clip_gradient = 1, rescale.grad = 1, lr_scheduler = lr_scheduler)

Затем модель обучалась в течение 8 эпох, что занимало около суток на графическом процессоре V100.

model <- mx.model.buckets(symbol = decode_teacher,
                          train.data = iter_train, 
                          eval.data = iter_eval,
                          num.round = 12, ctx = ctx, verbose = TRUE,
                          metric = mx.metric.Perplexity, 
                          optimizer = optimizer,  
                          initializer = initializer,
                          batch.end.callback = batch.end.callback, 
                          epoch.end.callback = epoch.end.callback)
mx.model.save(model=model, prefix="models/en_fr_cnn_rnn_teacher", iteration = 8)
mx.symbol.save(symbol=decode_argmax, filename="models/en_fr_cnn_rnn_argmax.json")

Недоумение используется как метрика оценки для отслеживания прогресса обучения:

Вывод

Чтобы получить сопоставимую оценку качества перевода, модель можно сравнить с официальным набором тестов WMT. Для этого пригодится библиотека sacreBLEU:

sacrebleu --test-set wmt15 --language-pair en-fr --echo src > wmt15-en-fr.src

При выполнении логического вывода для нового набора данных крайне важно применять ту же предварительную обработку, что и для обучающих данных. К счастью, в нашем сценарии было применено очень мало преобразований, что сделало этот шаг легко воспроизводимым на wmt15-en-fr.src данных. Очевидно, что должен применяться тот же словарь, поэтому тот, который был разработан для обучения, будет использоваться на этапе предварительной обработки, а не создавать новый «на лету».

Модель вывода получается путем объединения структуры argmax с весами, полученными во время обучения с teacher.

model <- mx.model.load(prefix = "models/model_wmt15_en_fr_cnn_rnn_teacher_v2", iteration = 12)
sym_infer <- mx.symbol.load(file.name = "models/model_wmt15_en_fr_cnn_rnn_argmax_v2.json")
model_infer <- list(symbol = sym_infer, arg.params = model$arg.params, aux.params = model$aux.params)
model_infer <- structure(model_infer, class="MXFeedForwardModel")

Затем вывод может быть применен к тестовым данным, хранящимся в виде текстового файла, готовым для оценки с помощью sacreBLEU:

cat wmt15_en_fr_cnn_rnn.txt | sacrebleu -t wmt15 -l en-fr

Итоговая сводка по производительности должна выглядеть примерно так:

BLEU+case.mixed+lang.en-fr+numrefs.1+smooth.exp+test.wmt15+tok.13a+version.1.2.1                                                            2 = 28.2 61.0/36.2/23.8/16.1 (BP = 0.930 ratio = 0.933 hyp_len = 26090 ref_len =                                                             27975)

Это означает, что мы достигли 28,2 балла по шкале BLEU.

Тестовые предложения также могут быть отправлены на перевод для проверки правильности модели:

> infer_helper(infer_seq = "I'd love to learn French!",
               model = model_infer, 
               source_dic = source_dic, 
               target_dic = target_dic, 
               seq_len = seq_len)
[1] "J'aimerais apprendre le français!"

Улучшение

Чтобы достичь высочайшего уровня производительности, можно рассмотреть несколько дополнительных приемов.

  • Токенизация: наиболее эффективные системы обычно используют более сложные схемы токенизации, особенно BPE, который создает разделение на подслова.
  • Позиционное встраивание: в дополнение к идентификаторам токенов, которые используются в качестве входных данных, могут быть добавлены функции, которые представляют положение токена в предложении. Это может быть либо одиночный индикатор положения (абсолютный или относительный), либо более сложный набор синусоидальных / косинусных волн, как используется в модели трансформатора.
  • Объединение моделей: усредните прогнозы нескольких моделей.
  • Поиск луча: вместо использования одного наилучшего транслированного токена генерируются N лучших кандидатов и связанный с ними следующий лучший шаг, а токен, связанный с путем максимального правдоподобия, сохраняется. Это частично обходит ограничения жадного декодирования argmax.

Многие из этих функций интегрированы в обширную библиотеку Sockeye, построенную на основе MXNet.