tf.reduce_max несовместим с сравнением

Учитывая список / тензор элементов, я хочу проверить, совпадает ли максимальный элемент всего списка с максимальным элементом в определенной части списка:

import tensorflow as tf
a = tf.get_variable('a', (10,100))
b = tf.unstack(a,axis=1)
c = tf.reduce_max(b[0])
d = tf.reduce_max(b[0])
if c == d:
  c = tf.ones((1,100))

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([c,d])

В приведенном выше примере c и d должны быть одинаковыми, однако, когда вы запускаете модель, она не удовлетворяет условию повторного преобразования переменной c как одного вектора. Это просто пример, показывающий, что подобные сравнения кажутся неправильными. Есть идеи, как правильно это сделать?


person Yaser Kenesh    schedule 17.05.2018    source источник


Ответы (1)


в приведенном выше примере c и d должны быть одинаковыми

Нет, здесь тебе следует быть осторожным. c и d - разные операции в вычислительном графе. Сравнивать их с == бессмысленно, это всегда разные объекты. На самом деле вам нужно tf.equal, чтобы определить, соответствуют ли значения тензора тот же элементарный и tf.cond для организации оператора if в вычислительном графе . Это выглядело бы примерно так:

result = tf.cond(tf.equal(c, d), lambda: tf.ones((1, 100)), lambda: tf.zeros((1, 100)))

Также обратите внимание, что переназначение переменной python, которая указывает на операцию (c в данном случае) не изменяет операцию на графике.

person Maxim    schedule 17.05.2018
comment
Чтобы следить за Максимом. [Здесь] [1] - поток stackoverflow о том, как создавать условия в тензорном потоке. [1]: stackoverflow.com/questions/35833011/ - person Herman Wilén; 17.05.2018