Как я могу бесконечно читать из Tensorflow tf.data.Dataset?

Я переключаю свой старый уровень данных (с использованием очередей) на «новый» и рекомендуемый API набора данных. Я использую его впервые, поэтому привожу примеры кода на случай, если у меня что-то в корне не так.

Я создаю свой набор данных из генератора (который будет читать файл и предоставлять n образцов). Это небольшой набор данных и n_iterations >> n_samples, поэтому я просто хочу читать этот набор данных снова и снова, в идеале в случайном порядке.

sample_set = tf.data.Dataset.from_generator( data_generator(filename),  
    (tf.uint8, tf.uint8), (tf.TensorShape([256,256,4]), tf.TensorShape([256,256,1]))
)

с генератором данных:

class data_generator:
    def __init__(self, filename):
        self.filename= filename

    def __call__(self):
        with filename.open() as f:
           for idx in f: yield img[idx], label[idx]

Чтобы использовать данные, я понял, что мне нужно определить Iterator

sample = sample_set.make_one_shot_iterator().get_next()

а затем мы настроены на чтение данных

while True:
    try: my_sample = sess.run(sample)
    except tf.errors.OutOfRangeError: break   # this happens after dset is read once

Но все доступные итераторы кажутся «конечными» в том смысле, что они читают набор данных только один раз.

Есть ли простой способ сделать чтение из набора данных бесконечным?


person Honeybear    schedule 19.03.2018    source источник


Ответы (3)


В наборах данных есть repeat и _ 2_.

BUF_SIZE = 100 # choose it depending on your data
sample_set = tf.data.Dataset.from_generator( data_generator(filename),  
    (tf.uint8, tf.uint8), (tf.TensorShape([256,256,4]), 
    tf.TensorShape([256,256,1]))
).repeat().shuffle(BUF_SIZE)
person dm0_    schedule 19.03.2018
comment
Идеально, именно то, что я имел в виду! Я выберу этот, так как перемешивание тоже действительно полезно. - person Honeybear; 19.03.2018

Преобразование Dataset.repeat() будет бесконечно повторять набор данных, если вы этого не сделаете. t передать ему явный count:

sample_set = tf.data.Dataset.from_generator(
    data_generator(filename), (tf.uint8, tf.uint8),
    (tf.TensorShape([256,256,4]), tf.TensorShape([256,256,1])))

# Repeats `sample_set` endlessly.
sample_set = sample_set.repeat()

sample = sample_set.make_one_shot_iterator().get_next()
person mrry    schedule 19.03.2018
comment
реализована ли функция dataset.repeat() таким образом, как мой ответ (т.е. примерно так: try: my_sample = sess.run(sample) except tf.errors.OutOfRangeError: sess.run(sample_set_init_op) # re-initialize on same dataset)? Потому что в моем журнале все еще отображается Out of range: ...-Errors, но он продолжает работать. - person Honeybear; 20.03.2018
comment
Он может печатать это каждый раз, когда исходный data_generator() достигает конца одного повторения. (Я считаю, что это дополнительное ведение журнала было удалено в более поздней версии TensorFlow.) - person mrry; 20.03.2018

Повторно инициализируемый итератор будет работать с повторной инициализацией одного и того же набора данных, поэтому этот код будет читать один и тот же набор данных снова и снова:

sample = tf.data.Iterator.from_structure(sample_set.output_types,
                                         sample_set.output_shapes).get_next()

sample_it.make_initializer(sample_set)     # create initialize op

with tf.Session(config=config) as sess:
    sess.run(sample_set_init_op)           # initialize in the beginning

    while True:
        try: 
             my_sample = sess.run(sample)
        except tf.errors.OutOfRangeError:
             sess.run(sample_set_init_op)  # re-initialize on same dataset
person Honeybear    schedule 19.03.2018
comment
Это может сработать для других, НО, помимо того, что вы чувствуете себя грязным, чтобы производить и ловить подобные ошибки, для этого требуется перезапуск графика. В этом примере с игрушкой граф является тривиальным, но обычно ввод данных является частью большого обучающего графа. Для моего основного цикла с управлением контрольными точками и множеством других вещей, включая output = sess.run(), я использую оболочку, предоставленную кем-то другим, поэтому в основном у меня нет доступа к sess.run(), и мне бы хотелось, чтобы модель обрабатывала бесконечные считываемые данные (а не требуя от меня перезапустить ошибочную модель). - person Honeybear; 19.03.2018