Я пытаюсь запустить (стохастический или пакетный) градиентный спуск, когда используется стандартная перекрестная энтропия (softmax loss):
при использовании в качестве модели сети Radial Basis Function (RBF) (вы можете посмотреть форму лекции caltech здесь, если хотите) при расширении до мультиклассовой классификации (легко расширяется путем простого ввода вывод сети RBF на уровень softmax. Обратите внимание, что P(y=l|x)
просто вычисляется путем передачи вывода сети RBF через уровень softmax для каждой метки l
следующим образом:
где \theta_l
индексирует параметры, отвечающие за выполнение прогнозов для метки l
.
В связи с этим я хотел оптимизировать свою модель, вычисляя производные по параметрам. Напомним, что параметрами для оптимизации в сети радиальных базисных функций являются веса c
на последнем слое и центры t
на первом слое. Я реализовал и отладил, как вычислить производную по весам c
. Код работает так, как ожидалось, потому что частные производные соответствуют числовым производным. Вы можете найти код модульного теста здесь a >.
Я также пробовал написать код, реализующий производную по отношению к центрам, но я просто не могу заставить мою реализацию производной соответствовать числовым производным. Уравнение производной потерь J
относительно центров t_k
, которое я пытаюсь реализовать, выглядит следующим образом:
где h_{\theta_l}
соответствует выходу RBF, который отвечает за прогнозирование метки l
. На самом деле h_{\theta_l}
очень просто выразить:
Моя основная проблема связана с вычислением производной J
по t_k
(уравнение выше). Для этого я реализовал следующую функцию a > который наивно вычисляет его без векторизации:
function [ dJ_dt ] = compute_dJ_dt(z,x,y,t,c)
%Computes dJ_dc
% Input:
% z = (K x 1)
% x = data point (D, 1)
% y = labels (1 x 1)
% t = centers (D x K)
% c = weights (K x L)
% Output:
% dJ_dc = (D x K)
[D,K] = size(t);
[~, L] = size(c);
dJ_dt = zeros(D, K);
for k=1:K
dJ_dt_k = zeros(D, 1);
for l=1:L
c_l = c(:,l);
dh_dt_l = compute_dh_dt(z,x,t,c_l); %(D x K)
delta = (y==l);
dJ_dt_k = dJ_dt_k + dh_dt_l(:,k) * delta;
end
dJ_dt(:,k) = -dJ_dt_k;
end
end
и он не соответствует числовому коду производных >.
Я пробовал разные вещи, чтобы проверить, работает ли это, и я все объясню здесь. Если у кого-то есть дополнительные идеи, не стесняйтесь поделиться ими, я вроде как чувствую, что у меня закончились хорошие новые идеи, чтобы попытаться отладить это.
- Во-первых, естественный вопрос: верен ли мой математический вывод производной, которую я пытаюсь реализовать? Несмотря на то, что я явно не проверял математический вывод с кем-то, я очень уверен, что он правильный, потому что вывод для частной производной по
c
иt
в модели идентичен, и вы только меняете символ\theta
на любой параметр, который у вас есть обсуждаемый. Поскольку я уже реализовал производную по отношению кc
и она проходит все мои производные тесты, я предполагаю, что производная по отношению кt
или любому параметру\theta
должна быть правильной. Мой вывод этого уравнения можно увидеть в math.stack exchange здесь. - Один из вариантов может заключаться в том, что
compute_dJ_dt
на самом деле не реализует уравнение, которого я ожидаю. Это действительно могло быть так, и чтобы убедиться, что я независимо реализовал немного более векторизованная версия этого кода, чтобы увидеть, действительно ли я реализую уравнение, которое у меня было на бумаге. Поскольку две версии уравнения выводят одни и те же производные значения, я уверен, что они вычисляют, действительно, уравнение, которое я подозреваю (также, если у кого-то есть способ дальнейшей векторизации этого уравнения, это было бы потрясающе! настолько тривиален, что не кажется таким уж интересным или большим приростом производительности, хотя он удаляет один цикл for).
Поскольку уравнение, которое у меня есть на бумаге, является (с большой вероятностью) правильным, и реализация уравнения кажется правильной, поскольку две его версии выводят одно и то же значение, это приводит меня к выводу, что, возможно, код числовой производной имеет ошибку .
- числовой производный код настолько смехотворно прост что трудно проверить, что же, черт возьми, с этим может быть не так. Единственное, что мне пришло в голову, что могло быть неправильно, это то, что моя реализация softmax cost J неверен, но я очень сомневаюсь в этом, так как ... Я уже написал для него модульный тест! Кроме того, я использую его для проверки числовых производных относительно
c
и тех, которые дляc
ВСЕГДА проходят, поэтому я не могу представить, чтоJ
ошибается. - Последняя нетривиальная вещь, которую нужно проверить, - это то, что
compute_dh_dt
вычисляется правильно. Я написал тесты модулей для dh_dt и поскольку они соответствуют своим соответствующим числовым производным при каждом запуске, я подозреваю, что код правильный.
На данный момент я не уверен на 100%, что еще попробовать. Я надеюсь, что, может быть, у кого-то есть хорошая идея или, может быть, укажет на мою глупость? Я не знаю, что думать прямо сейчас. Спасибо за помощь и время сообществу!