Пользовательская функция потерь: выполнить model.predict для данных в y_pred

Я обучаю сеть шумоподавлению изображений, для этого я использую набор данных CIFAR10. Я пытаюсь создать настраиваемую функцию потерь, чтобы потеря была mse / classification_accuracy. Учитывая, что моя сеть получает на входе 32x32 (зашумленные) изображения и предсказывает 32x32 (с шумоподавлением) изображения, я предполагаю, что y_pred и Y_true будут массивами 32x32 изображений. Таким образом, мои пользовательские функции потерь выглядят так:

def custom_loss():
    def joint_optimized_loss(y_true, y_pred):
        mse =  K.mean(K.square(y_pred - y_true), axis=-1)
        preds = classif_model.predict(y_pred)
        correctPreds = 0
        totPreds = 0
        for pred in preds:
            predictedClass = pred.index(max(pred))
            totPreds += 1
            if predictedClass == currentClass: 
                correctPreds += 1
        classifAccuracy = correctPreds / totPreds
        loss = mse / classifAccuracy
        return loss
    return joint_optimized_loss
myModel.compile(optimizer='adadelta', loss=custom_loss())

classif_model - это предварительно обученная модель, которая классифицирует изображения CIFAR10 в один из 10 классов. Он получает массив изображений 32x32.

Однако, когда я запускаю свой код, я получаю следующую ошибку:

Отслеживание (последний вызов последний):

Файл "myCode.py", строка 94, в

myModel.compile (optimizer = 'adadelta', loss = custom_loss ()) Файл "/home/rvidalma/anaconda2/envs/tensorUpdated/lib/python2.7/site-packages/keras/engine/training.py", строка 850 , в компиляции

sample_weight, mask) Файл "/home/rvidalma/anaconda2/envs/tensorUpdated/lib/python2.7/site-packages/keras/engine/training.py", строка 450, во взвешенном

score_array = fn (y_true, y_pred) Файл "myCode.py", строка 57, в Joint_optimized_loss

preds = classif_model.predict (y_pred) Файл "/home/rvidalma/anaconda2/envs/tensorUpdated/lib/python2.7/site-packages/keras/models.py", строка 913, в прогнозе

return self.model.predict (x, batch_size = batch_size, verbose = verbose) Файл "/home/rvidalma/anaconda2/envs/tensorUpdated/lib/python2.7/site-packages/keras/engine/training.py", строка 1713 г., в прогнозе

verbose = verbose, steps = steps) Файл "/home/rvidalma/anaconda2/envs/tensorUpdated/lib/python2.7/site-packages/keras/engine/training.py", строка 1260, в _predict_loop

batches = _make_batches (num_samples, batch_size) Файл "/home/rvidalma/anaconda2/envs/tensorUpdated/lib/python2.7/site-packages/keras/engine/training.py", строка 374, в _make_batches

num_batches = int (np.ceil (size / float (batch_size)))
AttributeError: объект 'Dimension' не имеет атрибута 'ceil'

Я думаю, это как-то связано с тем фактом, что y_true и y_pred являются тензорами, которые перед обучением пусты, поэтому classif_model.predict не работает, поскольку ожидает массив. Однако я не уверен, как это исправить ...

Я попытался получить вместо этого значение y_pred с помощью K.get_value(y_pred), но это привело к следующей ошибке:

tenorflow.python.framework.errors_impl.InvalidArgumentError: фигура [-1,32,32,3] имеет отрицательные размеры [[Node: input_1 = Placeholderdtype = DT_FLOAT, shape = [?, 32,32,3], _device = "/ задание: localhost / реплика: 0 / задача: 0 / процессор: 0 "]]


person rovim    schedule 30.12.2017    source источник


Ответы (2)


Вы не можете использовать точность как функцию потерь, поскольку она не дифференцируема. Вот почему вместо этого используются верхние границы точности, такие как кросс-энтропия.

Кроме того, способ, которым вы реализовали точность, также не является символическим, вам следовало использовать только функции в keras.backend, чтобы реализовать потерю, чтобы она работала правильно.

person Dr. Snoopy    schedule 30.12.2017

У меня была почти такая же проблема, и я попробовал это, и у меня это сработало.

Вместо того:

preds = classif_model.predict (y_pred)

пытаться:

preds = classif_model (y_pred)

Я не уверен в причине, но это потому, что, когда мы используем model.predict (y), ему нужен batch_size, а во время компиляции у нас его нет, поэтому мы не можем использовать model.predict (y). Пожалуйста, поправьте меня, если это не так.

person Narayan Kothari    schedule 25.05.2018