Как определить пользовательскую функцию потерь keras с помощью простой математической операции

Я определяю пользовательскую функцию my_sigmoid следующим образом:

import math
def my_sigmoid(x):
    a =  1/  ( 1+math.exp( -(x-300)/30 ) )
    return a

А затем определите пользовательскую функцию потерь с именем my_cross_entropy:

import keras.backend as K

def my_cross_entropy(y_true, y_pred):
    diff = abs(y_true-y_pred)
    y_pred_transform = my_sigmoid(diff)
    return K.categorical_crossentropy(0, y_pred_transform)

Мой бэкэнд keras использует тензорный поток. И ошибка показывает

TypeError: должно быть действительное число, а не тензор

Я не знаком с тензорным потоком и не знаю, как использовать настраиваемую потерю.

Ниже приведены структура моей модели и сообщение об ошибке:

import keras.backend as K
from keras.models import Sequential
from keras.layers import Conv2D, Dropout, Flatten, Dense

model=Sequential()
model.add(Conv2D(512,(5,X_train.shape[2]),input_shape=X_train.shape[1:4],activation="relu"))
model.add(Flatten())
model.add(Dropout(0.1))
model.add(Dense(100,activation="relu"))
model.add(Dense(100,activation="relu"))
model.add(Dense(50,activation="relu"))
model.add(Dense(10,activation="relu"))
model.add(Dense(1,activation="relu"))
model.compile(optimizer='adam', loss=my_cross_entropy)
model.fit(X_train,Y_train,batch_size = 10,epochs=200,validation_data=(X_test,Y_test))

введите описание изображения здесь

Форма X_train и Y_train: (120, 30, 80, 1) и (120,)


person Jim Chen    schedule 09.10.2018    source источник
comment
Можете ли вы опубликовать полный небольшой пример?   -  person George    schedule 09.10.2018
comment
Я не знаю, как размещать свои данные на SO ...........   -  person Jim Chen    schedule 09.10.2018
comment
Не ваши данные, опубликуйте код, который вы используете, и точку, где ошибка.   -  person George    schedule 09.10.2018


Ответы (1)


Изменять

diff = abs(y_true-y_pred)

в

diff = K.abs(y_true-y_pred)

то же самое для

math.exp()

измените это на

K.exp()

abs и Math.exp - это функции, которые не могут обрабатывать тензоры. Если у вас все еще есть проблемы, обратитесь к: Пользовательская функция потерь Keras Tensorflow

person Mete Han Kahraman    schedule 09.10.2018