Tensorflow реализация word2vec

Учебное пособие по Tensorflow здесь относится к их базовой реализации, которую вы можете найти на github здесь, где авторы Tensorflow реализуют обучение/оценку встраивания вектора word2vec с помощью модели Skipgram.

Мой вопрос касается фактического создания пар (цель, контекст) в функции generate_batch().

В этой строке Авторы Tensorflow случайным образом выбирают близлежащие целевые индексы из «центрального» индекса слов в скользящем окне слов.

Однако они также хранят данные структура targets_to_avoid, к которой они добавляют сначала «центральное» контекстное слово (которое, конечно, мы не хотим сэмплировать), но ТАКЖЕ другие слова после их добавления.

Мои вопросы заключаются в следующем:

  1. Зачем сэмплировать из этого скользящего окна вокруг слова, почему бы просто не использовать цикл и не использовать их все вместо сэмплирования? Кажется странным, что они беспокоятся о производительности/памяти в word2vec_basic.py (их "базовой" реализации).
  2. Каким бы ни был ответ на вопрос 1), почему они оба производят выборку и отслеживают то, что они выбрали с помощью targets_to_avoid? Если бы им нужен был действительно случайный выбор, они бы использовали выборку с заменой, а если бы они хотели убедиться, что получили все варианты, они должны были просто использовать цикл и получить их все в первую очередь!
  3. Использует ли встроенный tf.models.embedding.gen_word2vec тоже так работает? Если да, то где я могу найти исходный код? (не удалось найти файл .py в репозитории Github)

Спасибо!


person lollercoaster    schedule 29.06.2016    source источник
comment
ты нашел ответ? Если да, не могли бы вы добавить в качестве ответа?   -  person Aerin    schedule 09.07.2017


Ответы (2)


Я попробовал предложенный вами способ создания пакетов - с циклом и использованием всего окна пропуска. Результаты:

<сильный>1. Более быстрое создание пакетов

Для размера пакета 128 и окна пропуска 5

  • создание пакетов путем циклического перебора данных один за другим занимает 0,73 с на 10 000 пакетов
  • создание пакетов с помощью обучающего кода, и num_skips=2 занимает 3,59 с на 10 000 пакетов.

<сильный>2. Повышенная опасность переобучения

Оставив остальную часть учебного кода как есть, я обучил модель обоим способам и записал средние потери каждые 2000 шагов:

введите здесь описание изображения

Эта закономерность повторялась неоднократно. Он показывает, что использование 10 образцов на слово вместо 2 может привести к переобучению.

Вот код, который я использовал для создания пакетов. Он заменяет функцию generate_batch из учебника.

data_index = 0

def generate_batch(batch_size, skip_window):
    global data_index
    batch = np.ndarray(shape=(batch_size), dtype=np.int32)  # Row
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)  # Column

    # For each word in the data, add the context to the batch and the word to the labels
    batch_index = 0
    while batch_index < batch_size:
        context = data[get_context_indices(data_index, skip_window)]

        # Add the context to the remaining batch space
        remaining_space = min(batch_size - batch_index, len(context))
        batch[batch_index:batch_index + remaining_space] = context[0:remaining_space]
        labels[batch_index:batch_index + remaining_space] = data[data_index]

        # Update the data_index and the batch_index
        batch_index += remaining_space
        data_index = (data_index + 1) % len(data)

    return batch, labels

Редактировать: get_context_indices — это простая функция, которая возвращает срез индекса в skip_window вокруг data_index. Дополнительные сведения см. в документации по функции slice().

person Kilian Batzner    schedule 13.12.2016
comment
что находится в get_context_indices - person Shan Khan; 20.06.2017
comment
Это не отвечает на вопрос. - person Aerin; 09.07.2017
comment
@ToussaintLouverture Это ответ на вопрос номер 1. :) Какой вопрос вы имеете в виду? - person zwep; 30.08.2017

Существует параметр с именем num_skips, который обозначает количество пар (вход, вывод), сгенерированных из одного окна: [skip_window target skip_window]. Таким образом, num_skips ограничьте количество слов контекста, которые мы будем использовать в качестве выходных слов. Именно поэтому функция generate_batch assert num_skips <= 2*skip_window. Код просто случайным образом выбирает num_skip слова контекста для создания обучающих пар с целью. Но я не знаю, как num_skips влияет на производительность.

person user1903382    schedule 23.07.2016