Изменение критерия факела

Я хочу создать пользовательскую функцию потерь в Torch, которая является модификацией ClassNLLCriterion. Конкретно, потеря ClassNLLCriterion:

loss(x, class) = -x[class]

Я хочу изменить это так:

loss(x, class) = -x[class]*K

где K является функцией входа сети, а НЕ весов сети или выхода сети. Таким образом, K можно рассматривать как константу.

Как проще всего реализовать этот пользовательский критерий? Функция updateOutput() кажется простой, но как изменить функцию updateGradInput()?


person braindead    schedule 02.06.2017    source источник


Ответы (1)


В основном ваша функция потерь L является функцией входа и цели. Так что у тебя есть

loss(input, target) = ClassNLLCriterion(input, target) * K

если я правильно понимаю вашу новую потерю. Затем вы хотите реализовать updateGradInput, который возвращает производную вашей функции потерь по отношению к входу, который равен

updateGradInput[ClassNLLCriterion](input, target) * K + ClassNLLCriterion(input, target) * dK/dinput

Поэтому вам нужно только вычислить производную от K по входным данным функции потерь (вы не дали нам формулу для вычисления K) и вставить ее в предыдущую строку. Поскольку ваша новая функция потерь зависит от ClassNLLCriterion, вы можете использовать updateGradInput и updateOutput этой функции потерь для расчета своей.

person fonfonx    schedule 02.06.2017
comment
Так что, по сути, мне не нужно писать собственный критерий. В моем обучающем коде я могу просто сделать: loss = ClassNLLCriterion:forward()*K, а затем grad = ClassNLLCriterion:backward()*K+loss*(dK/dinput) Это правильно? - person braindead; 02.06.2017
comment
Да это тоже возможно - person fonfonx; 02.06.2017
comment
Потрясающий. Спасибо! Еще один вопрос: если K — просто константа (не зависящая от параметров сети, ввода или вывода), как изменится ваш ответ в этом случае? - person braindead; 02.06.2017
comment
если K - просто константа, я не вижу смысла ее использовать ... Это просто умножит все значения потерь на одну и ту же константу. Таким образом, вы можете просто использовать ClassNLLCriterion - person fonfonx; 02.06.2017