Как обучить сеть только на одном выходе, когда их несколько?

Я использую модель с несколькими выходами в Keras

model1 = Model(input=x, output=[y2, y3])

model1.compile((optimizer='sgd', loss=cutom_loss_function)

моя custom_loss функция

def custom_loss(y_true, y_pred):
   y2_pred = y_pred[0]
   y2_true = y_true[0]

   loss = K.mean(K.square(y2_true - y2_pred), axis=-1)
   return loss

Я только хочу натренировать сеть на выходе y2.

Какова форма / структура аргументов y_pred и y_true в функции потерь при использовании нескольких выходов? Могу ли я получить к ним доступ, как указано выше? Это y_pred[0] или y_pred[:,0]?


person shaaa    schedule 25.05.2017    source источник


Ответы (3)


Я только хочу обучить сеть на выходе y2.

На основе руководства по функциональному API Keras вы можете добиться этого с помощью

model1 = Model(input=x, output=[y2,y3])   
model1.compile(optimizer='sgd', loss=custom_loss_function,
                  loss_weights=[1., 0.0])

Какова форма / структура аргументов y_pred и y_true в функции потерь при использовании нескольких выходов? Могу ли я получить к ним доступ, как указано выше? Это y_pred [0] или y_pred [:, 0]

В моделях keras с несколькими выходами функция потерь применяется для каждого выхода отдельно. В псевдокоде:

loss = sum( [ loss_function( output_true, output_pred ) for ( output_true, output_pred ) in zip( outputs_data, outputs_model ) ] )

Мне кажется, что функция потерь на нескольких выходах недоступна. Вероятно, этого можно было бы достичь, включив функцию потерь как слой сети.

person Sharapolas    schedule 09.06.2017
comment
In keras multi-output models loss function is applied for each output separately. У меня аналогичная проблема, и мне отдельно нужны значения y_true и y_pred для двух отдельных выходных данных. как я могу это решить? - person Eka; 21.01.2018
comment
Если структура не изменилась недавно, самое простое решение - объединить выходные данные в одну функцию потерь, а затем обработать их там. - person Sharapolas; 24.01.2018
comment
@Sharapolas У вас есть практический пример этого утверждения the easiest solution is to concatenate the outputs into a single loss function and then to handle them there? - person ihavenoidea; 19.11.2019

Ответ Шараполаса правильный.

Однако есть способ лучше, чем использование слоя для построения пользовательских функций потерь со сложной взаимозависимостью нескольких выходных данных модели.

Я знаю, что на практике используется метод никогда не вызывать model.compile, а только model._make_predict_function(). С этого момента вы можете создать собственный метод оптимизатора, вызвав там model.output. Это даст вам все выходные данные, в вашем случае [y2, y3]. Когда творите с ним свою магию, получите keras.optimizer и используйте его метод get_update, используя свой model.trainable_weights и свои потери. Наконец, верните keras.function со списком необходимых входных данных (в вашем случае только model.input) и обновлений, которые вы только что получили в результате вызова optimizer.get_update. Эта функция теперь заменяет model.fit.

Вышеупомянутое часто используется в алгоритмах PolicyGradient, таких как A3C или PPO. Вот пример того, что я пытался объяснить: https://github.com/Hyeokreal/Actor-Critic-Continuous-Keras/blob/master/a2c_continuous.py Посмотрите на методы build_model и crit_optimizer и прочтите документацию по kreas.backend.function, чтобы понять, что происходит.

Я обнаружил, что у этого способа часто возникают проблемы с управлением сеансом, и в настоящее время он, похоже, вообще не работает в tf-2.0 keras. Следовательно, если кто-нибудь знает метод, дайте мне знать. Я пришел сюда в поисках одного :)

person Nric    schedule 28.04.2019

Принятый ответ не будет работать в целом, если пользовательские потери не могут быть применены к выходным данным, которые вы пытаетесь игнорировать, например если у них неправильная форма. В этом случае вы можете назначить этим выходам фиктивную функцию потерь:

labels = [labels_for_relevant_output, dummy_labels_for_ignored_output]

def dummy_loss(y_true, y_pred):
    return 0.0

model.compile(loss = [custom_loss_function, dummy_loss])
model.fit(x, labels)
person Elan    schedule 10.02.2021