Правильный способ остановить набор данных TensorFlow «from_generator»?

Я хотел бы использовать набор данных TensorFlow, созданный с помощью from_generator, для доступа к отформатированному файлу. Почти все работает, за исключением того, что я не знаю, как остановить итератор набора данных, когда у генератора заканчиваются данные (генератор просто возвращает пустые списки навсегда, когда вы выходите за пределы диапазона).

Мой фактический код очень сложен, но я могу смоделировать ситуацию с помощью этой короткой программы:

import tensorflow as tf

def make_batch_generator_fn(batch_size=10, dset_size=100):
    feats, targs = range(dset_size), range(1, dset_size + 1)

    def batch_generator_fn():
        start_idx, stop_idx = 0, batch_size
        while True:
            # if stop_idx > dset_size: --- stop action?
            yield feats[start_idx: stop_idx], targs[start_idx: stop_idx]
            start_idx, stop_idx = start_idx + batch_size, stop_idx + batch_size

    return batch_generator_fn

def test(batch_size=10):
    dgen = make_batch_generator_fn(batch_size)
    features_shape, targets_shape = [None], [None]
    ds = tf.data.Dataset.from_generator(
        dgen, (tf.int32, tf.int32),
        (tf.TensorShape(features_shape), tf.TensorShape(targets_shape))
    )
    feats, targs = ds.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        counter = 0
        try:
            while True:
                f, t = sess.run([feats, targs])
                print(f, t)
                counter += 1
                if counter > 15:
                    break
        except tf.errors.OutOfRangeError:
            print('end of dataset at counter = {}'.format(counter))

if __name__ == '__main__':
    test()

Если я заранее знаю количество записей, я могу настроить количество пакетов, но я не всегда это знаю. Я попытался поместить некоторый код во фрагмент выше, где у меня есть строка комментария, например stop action?. В частности, я пытался поднять IndexError, но TensorFlow это не нравится, даже если я явно catch в своем коде выполнения. Я также пытался создать tf.errors.OutOfRangeError, но я не уверен, как его создать: конструктор требует три аргумента — «node_def», «op» и «message», и я не совсем уверен, что использовать для «node_def». ' и вообще 'оп'.

Буду признателен за любые мысли или комментарии по этому вопросу. Спасибо!


person Gabriel Perdue    schedule 10.05.2018    source источник


Ответы (2)


Вернитесь, когда вы соответствуете критериям остановки:

def make_batch_generator_fn(batch_size=10, dset_size=100):
    feats, targs = range(dset_size), range(1, dset_size + 1)

    def batch_generator_fn():
        start_idx, stop_idx = 0, batch_size
        while True:
            if stop_idx > dset_size:
                return
            else:
                yield feats[start_idx: stop_idx], targs[start_idx: stop_idx]
                start_idx, stop_idx = start_idx + batch_size, stop_idx + batch_size

    return batch_generator_fn

Это соответствует поведению, указанному в документации Python 3:

В функции-генераторе оператор return указывает, что работа генератора завершена, и вызовет вызов StopIteration. Возвращенное значение (если есть) используется в качестве аргумента для построения StopIteration и становится атрибутом StopIteration.value.

person Nathan    schedule 10.05.2018

Он работает со следующими строками:

dataset_size = your dataset size
batch_size = your batch size
dataset = your tf.data.Dataset
steps_per_epoch = dataset_size // batch_size

for data, _ in zip(dataset, range(steps_per_epoch)):
    # your train_step

Итерация остановится, когда она будет завершена.

person Pengcheng Fan    schedule 20.11.2020
comment
@ M-Chen-3, привет, я обновил свой код, думаю, он сам себя объясняет :) - person Pengcheng Fan; 22.11.2020