Сопоставьте каждое значение тензора с ближайшим значением в списке

У меня есть тензор A размера [batchSize,2,2,2], где batchSize — заполнитель. В пользовательском слое я хотел бы сопоставить каждое значение этого тензора с ближайшим значением в списке c с длиной n. Список - это моя кодовая книга, и я хотел бы квантовать каждое значение в тензоре на основе этой кодовой книги; то есть найдите ближайшее значение к каждому значению тензора в списке и замените значение тензора этим. Я не мог придумать «чистую» тензорную операцию, которая быстро сделает это. Я не могу перебрать batchSize. Есть ли способ сделать это в Tensorflow?


person deepsy    schedule 26.05.2019    source источник
comment
Можете ли вы привести пример вашей кодовой книги? Возможно, вы можете использовать https://www.tensorflow.org/api_docs/python/tf/quantization/quantize, если это стандартное минимальное-максимальное квантование. Или, если у вас есть несколько пар ключ-значение, вы можете сначала выполнить некоторую нормализацию, а затем выполнить поиск ключ/значение с помощью tf.contrib.lookup.HashTable.   -  person greeness    schedule 27.05.2019
comment
@greeness Спасибо за ваш ответ. tf.quantization.quantize у меня не работает, так как мои значения квантования неравномерны. Я думаю, что хеш-таблица мне тоже не подходит, так как я случайным образом беру значения тензора A из распределения Гаусса. Вектор кодовой книги c включает в себя неравномерно квантованные значения из распределения Гаусса с длиной 100. В результате я сопоставляю значения, взятые случайным образом из непрерывного распределения, с квантованными значениями.   -  person deepsy    schedule 27.05.2019


Ответы (1)


Если я правильно понимаю, это выполнимо с tf.HashTable. В качестве иллюстрации я использовал нормальное распределение с mean=0, stddev=4.

a = tf.random.normal(
  shape = [batch, 2, 2, 2],
  mean=0.0,
  stddev=4
)

И я использовал квантование всего с 5 сегментами (см. рисунок, отмеченный цифрами 0, 1, 2, 3, 4). Это расширяется до любой длины n. Примечание. Я намеренно сделал ведра переменной длины.

введите здесь описание изображения

Поэтому моя кодовая книга:

a <= -2          -> bucket 4
-2 < a < -0.5    -> bucket 3
-0.5 <= a < 0.5  -> bucket 0
0.5 <= a < 2.5   -> bucket 1
a >= 2.5         -> bucket 2

Идея состоит в том, чтобы предварительно создать сопоставление ключ/значение от масштабированного a до номера корзины. (количество пар <key,value> зависит от необходимой детализации ввода. Здесь я масштабировался на 10). Ниже приведен код для инициализации таблицы сопоставления и созданного сопоставления (ввод масштабируется на 10).

# The boundary is chosen based on that we clip by min=-4, max=4. 
# after scaling, the boundary becomes -40 and 40. 
keys = range(-40, 41)
values  = []
for k in keys:
  if k <= -20:
    values.append(4)
  elif k < -5:
    values.append(3)
  elif k < 5:
    values.append(0)
  elif k < 25:
    values.append(1)
  else:
    values.append(2)
for (k, v) in zip(keys, values):
  print ("%2d -> %2d" % (k, v))


-40 ->  4
-39 ->  4
...
-22 ->  4
-21 ->  4
-20 ->  4
-19 ->  3
-18 ->  3
...
-7 ->  3
-6 ->  3
-5 ->  0
-4 ->  0
...
 3 ->  0
 4 ->  0
 5 ->  1
 6 ->  1
 ...
23 ->  1
24 ->  1
25 ->  2
26 ->  2
...
40 ->  2
batch = 3
a = tf.random.normal(
    shape = [batch, 2, 2, 2],
    mean=0.0,
    stddev=4,
    dtype=tf.dtypes.float32
)
clip_a = tf.clip_by_value(a, clip_value_min=-4, clip_value_max=4)
SCALE = 10
scaled_clip_a = tf.cast(clip_a * SCALE, tf.int32)

table = tf.contrib.lookup.HashTable(
    tf.contrib.lookup.KeyValueTensorInitializer(keys, values), -1)
quantized_a = tf.reshape(
    table.lookup(tf.reshape(scaled_clip_a, [-1])), 
    [batch, 2, 2, 2])

with tf.Session() as sess:
  table.init.run()
  a, clip_a, scaled_clip_a, quantized_a = sess.run([a, clip_a, scaled_clip_a, quantized_a])
  print ('a\n%s' % a)
  print ('clip_a\n%s' % clip_a)
  print ('scaled_clip_a\n%s' % scaled_clip_a)
  print ('quantized_a\n%s' % quantized_a)

Результат:

a
[[[[-0.26980758 -5.56331968]
   [ 5.04240322 -7.18292665]]

  [[-7.11545467 -3.24369478]
   [ 1.01861215 -0.04510783]]]


 [[[-0.28768024  0.2472897 ]
   [ 2.17780781 -5.79106379]]

  [[ 8.45582008  4.53902292]
   [ 0.138162   -6.19155598]]]


 [[[-7.5134449   4.56302166]
   [-0.30592337 -0.60313278]]

  [[-0.06204566  3.42917275]
   [-1.14547718  3.31167102]]]]
clip_a
[[[[-0.26980758 -4.        ]
   [ 4.         -4.        ]]

  [[-4.         -3.24369478]
   [ 1.01861215 -0.04510783]]]


 [[[-0.28768024  0.2472897 ]
   [ 2.17780781 -4.        ]]

  [[ 4.          4.        ]
   [ 0.138162   -4.        ]]]


 [[[-4.          4.        ]
   [-0.30592337 -0.60313278]]

  [[-0.06204566  3.42917275]
   [-1.14547718  3.31167102]]]]
scaled_clip_a
[[[[ -2 -40]
   [ 40 -40]]

  [[-40 -32]
   [ 10   0]]]


 [[[ -2   2]
   [ 21 -40]]

  [[ 40  40]
   [  1 -40]]]


 [[[-40  40]
   [ -3  -6]]

  [[  0  34]
   [-11  33]]]]
quantized_a
[[[[0 4]
   [2 4]]

  [[4 4]
   [1 0]]]


 [[[0 0]
   [1 4]]

  [[2 2]
   [0 4]]]


 [[[4 2]
   [0 3]]

  [[0 2]
   [3 2]]]]
person greeness    schedule 28.05.2019