Замена tf.placeholder и feed_dict на tf.data API

У меня есть существующая модель TensorFlow, которая использовала tf.placeholder для ввода модели и параметр feed_dict tf.Session (). Run для подачи данных. Раньше весь набор данных считывался в память и передавался таким образом.

Я хочу использовать гораздо больший набор данных и воспользоваться преимуществами повышения производительности API tf.data. Я определил tf.data.TextLineDataset и одноразовый итератор из него, но мне трудно понять, как ввести данные в модель для ее обучения.

Сначала я попытался просто определить feed_dict как словарь от заполнителя до iterator.get_next (), но это дало мне ошибку, в которой говорилось, что значение канала не может быть объектом tf.Tensor. Дальнейшие поиски привели меня к пониманию того, что это связано с тем, что объект, возвращаемый функцией iterator.get_next (), уже является частью графика, в отличие от того, что вы вводите в feed_dict, и что мне вообще не следует пытаться использовать feed_dict для причины производительности.

Итак, теперь я избавился от входного tf.placeholder и заменил его параметром конструктора класса, определяющего мою модель; при построении модели в моем обучающем коде я передаю результат работы iterator.get_next () этому параметру. Это уже кажется немного неуклюжим, потому что нарушает разделение между определением модели и наборами данных / процедурой обучения. И теперь я получаю сообщение об ошибке, в котором говорится, что тензор, представляющий (я считаю) ввод моей модели, должен быть из того же графика, что и тензор из iterator.get_next ().

На правильном ли я пути с этим подходом и просто делаю что-то не так с тем, как я настраиваю график и сеанс, или что-то в этом роде? (Наборы данных и модель инициализируются вне сеанса, и ошибка возникает до того, как я попытаюсь создать один.)

Или я совершенно не согласен с этим и мне нужно сделать что-то другое, например, использовать API-интерфейс оценщика и определить все во входной функции?

Вот код, демонстрирующий минимальный пример:

import tensorflow as tf
import numpy as np

class Network:
    def __init__(self, x_in, input_size):
        self.input_size = input_size
        # self.x_in = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size))  # Original
        self.x_in = x_in
        self.output_size = 3

        tf.reset_default_graph()  # This turned out to be the problem

        self.layer = tf.layers.dense(self.x_in, self.output_size, activation=tf.nn.relu)
        self.loss = tf.reduce_sum(tf.square(self.layer - tf.constant(0, dtype=tf.float32, shape=[self.output_size])))

data_array = np.random.standard_normal([4, 10]).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data_array).batch(2)

model = Network(x_in=dataset.make_one_shot_iterator().get_next(), input_size=dataset.output_shapes[-1])

person erobertc    schedule 10.04.2018    source источник
comment
Не могли бы вы опубликовать образец кода? Это, вероятно, поможет понять, что происходит не так.   -  person KRish    schedule 11.04.2018


Ответы (2)


Мне тоже потребовалось немного времени, чтобы сообразить. Вы на правильном пути. Полное определение набора данных - это лишь часть графика. Обычно я создаю его как класс, отличный от моего класса Model, и передаю набор данных в класс Model. Я указываю класс набора данных, который хочу загрузить, в командной строке, а затем загружаю этот класс динамически, тем самым модульно разделяя набор данных и график.

Обратите внимание, что вы можете (и должны) назвать все тензоры в наборе данных, это действительно помогает упростить понимание, когда вы передаете данные через различные преобразования, которые вам понадобятся.

Вы можете написать простые тестовые примеры, которые извлекают образцы из iterator.get_next() и отображают их, у вас будет что-то вроде sess.run(next_element_tensor), а не feed_dict, как вы правильно заметили.

Как только вы разберетесь с этим, вам, вероятно, понравится конвейер ввода набора данных. Это заставляет вас хорошо модулировать ваш код и заставляет его превращать в структуру, которая легко поддается модульному тестированию.

Обязательно прочтите руководство для разработчиков, там множество примеров:

https://www.tensorflow.org/programmers_guide/datasets

Еще я отмечу, насколько легко работать с обучающим и тестовым набором данных с помощью этого конвейера. Это важно, потому что вы часто выполняете увеличение данных в наборе обучающих данных, которое вы не выполняете в тестовом наборе данных, from_string_handle позволяет вам это сделать и четко описано в руководстве выше.

person David Parks    schedule 10.04.2018
comment
Спасибо, что сообщили мне, что я на правильном пути! Оказалось, что это просто строка tf.reset_default_graph () в конструкторе модели, которая у меня была. - person erobertc; 12.04.2018

Строка tf.reset_default_graph() в конструкторе модели из исходного кода, который мне был предоставлен, вызвала это. Удаление этого исправило.

person erobertc    schedule 11.04.2018