Keras BatchNormalization работает только для постоянного затемнения партии, когда ось = 0?

В следующем коде показан один способ, который работает, а другой — нет.

BatchNorm на оси = 0 не должен зависеть от размера партии, а если и зависит, то это должно быть явно указано в документах.

In [118]: tf.__version__
Out[118]: '2.0.0-beta1'



class M(tf.keras.models.Model):
import numpy as np
import tensorflow as tf

class M(tf.keras.Model):

    def __init__(self, axis):
        super().__init__()
        self.layer = tf.keras.layers.BatchNormalization(axis=axis, scale=False, center=True, input_shape=(6,))

    def call(self, x):
        out = self.layer(x)
        return out

def fails():
    m = M(axis=0)
    x = np.random.randn(3, 6).astype(np.float32)
    print(m(x))
    x = np.random.randn(2, 6).astype(np.float32)
    print(m(x))

def ok():
    m = M(axis=1)
    x = np.random.randn(3, 6).astype(np.float32)
    print(m(x))
    x = np.random.randn(2, 6).astype(np.float32)
    print(m(x))

РЕДАКТИРОВАТЬ:

Ось в аргументах — это не та ось, о которой вы думаете.


person mathtick    schedule 25.07.2019    source источник
comment
stackoverflow.com/questions/47538391/ Возможно, ось в аргументах неправильная ось. Может быть, это не ось интеграции?   -  person mathtick    schedule 25.07.2019
comment
Прочитав ваш вопрос, я вообще не понимаю, в чем проблема, пожалуйста, будьте очень конкретны.   -  person Dr. Snoopy    schedule 25.07.2019
comment
Аргумент axis должен быть осью признаков. Во многих случаях это будет -1. Если у вас есть свертка «сначала каналы», это будет 1. И т. Д. Предоставление оси пакета в качестве аргумента здесь имеет очень мало смысла.   -  person xdurch0    schedule 25.07.2019
comment
Обратите внимание, что batchnorm выделяет параметры для каждого измерения axis, поэтому размер должен быть постоянным; это полностью ожидаемое поведение.   -  person xdurch0    schedule 25.07.2019
comment
В основном кажется, что ось - это не то, что думают естественно. Ось — это то, что вы не нормализуете.   -  person mathtick    schedule 25.07.2019


Ответы (1)


Как было указано в этом ответе и Keras doc, аргумент axis указывает на ось функции. Это полностью имеет смысл, потому что мы хотим выполнить нормализацию по функциям, то есть нормализовать каждую функцию по всему входному пакету (это в соответствии с нормализацией по функциям, которую мы можем выполнять для изображений, например, вычитая «средний пиксель» из всех изображений. набора данных).

Теперь написанный вами метод fails() терпит неудачу в этой строке:

x = np.random.randn(2, 6).astype(np.float32)
print(m(x))

Это потому, что вы установили ось объекта как 0, то есть первую ось, при построении модели и, следовательно, когда следующие строки выполняются перед приведенным выше кодом:

x = np.random.randn(3, 6).astype(np.float32)
print(m(x))

вес слоя будет построен на основе 3 объектов (не забывайте, что вы указали ось объекта как 0, поэтому во входных данных формы (3,6) будет 3 объекта). Поэтому, когда вы даете ему входной тензор формы (2,6), он правильно выдает ошибку, потому что в этом тензоре есть 2 функции, и поэтому нормализация не может быть выполнена из-за этого несоответствия.

С другой стороны, метод ok() работает, потому что ось признаков является последней осью, и поэтому оба входных тензора имеют одинаковое количество признаков, то есть 6. Таким образом, нормализация может быть выполнена в обоих случаях для всех признаков.

person today    schedule 25.07.2019