API набора данных Tensorflow: распараллеливание tf.data.Dataset.from_generator с parallel_interleave

В производственной среде у меня есть данные, поступающие от N производителей, которые должны пройти через сеть. Я нашел этот комментарий к распараллеливанию tf.data.Dataset.from_generator, который действительно описывает что я хочу.

def generator(n):
  # returns n-th generator function

def dataset(n):
  return tf.data.Dataset.from_generator(generator(n))

ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_lenght=N))

# where N is the number of generators you use

Однако как должна выглядеть функция генератора (n). Потому что, когда я запускаю этот пример с

 def generator(n):
        """Returns the n-th generator function (for consumer n)
        """
        consumer = self.consumers[n]

        def gen():
            for item in consumer:
                yield item

        return gen

с self.consumers списком Python я получу ошибку:

TypeError: индексы списка должны быть целыми числами или срезами, а не тензором


person Derk    schedule 11.05.2018    source источник
comment
Что такое n? Судя по коду, который вы разместили, это похоже на tf.Tensor, который нельзя использовать для индексации в списке.   -  person xdurch0    schedule 11.05.2018
comment
Да, именно поэтому я и задал этот вопрос :)   -  person Derk    schedule 11.05.2018


Ответы (1)


Реализация почти правильная, но вы получаете сообщение об ошибке, поскольку аргумент n в dataset(n) является "символическим" tf.Tensor, а не фактическим значением, которое можно использовать для поиска потребителя в self.consumers.

К счастью, есть обходной путь, который включает передачу n через необязательный аргумент args в tf.data.Dataset.from_generator():

def dataset(n):
  return tf.data.Dataset.from_generator(generator, args=(n,))

Под прикрытием from_generator() вставляет некоторый код для преобразования n в целое число Python перед каждым вызовом generator.

person mrry    schedule 23.04.2019