Q1.
Я пытаюсь сделать свою пользовательскую функцию autograd с помощью pytorch.
Но у меня возникла проблема с аналитическим обратным распространением с y = x/sum(x, dim=0)
где размер тензора x равен (высота, ширина) (x двумерный).
Вот мой код
class MyFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
input = input / torch.sum(input, dim=0)
return input
@staticmethod
def backward(ctx, grad_output):
input = ctx.saved_tensors[0]
H, W = input.size()
sum = torch.sum(input, dim=0)
grad_input = grad_output * (1/sum - input*1/sum**2)
return grad_input
Я использовал (torch.autograd import) gradcheck для сравнения матрицы Якоби,
from torch.autograd import gradcheck
func = MyFunc.apply
input = (torch.randn(3,3,dtype=torch.double,requires_grad=True))
test = gradcheck(func, input)
и результат был
Пожалуйста, помогите мне получить правильный результат обратного распространения
Спасибо!
Q2.
Спасибо за ответы!
Благодаря вашей помощи я смог реализовать обратное распространение в случае тензора (H,W).
Однако, пока я реализовал обратное распространение в случае тензора (N,H,W), у меня возникла проблема. Я думаю, проблема будет в инициализации нового тензора.
Вот мой новый код
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
N = input.size(0)
for n in range(N):
input[n] /= torch.sum(input[n], dim=0)
return input
@staticmethod
def backward(ctx, grad_output):
input = ctx.saved_tensors[0]
N, H, W = input.size()
I = torch.eye(H).unsqueeze(-1)
sum = input.sum(1)
grad_input = torch.zeros((N,H,W), dtype = torch.double, requires_grad=True)
for n in range(N):
grad_input[n] = ((sum[n] * I - input[n]) * grad_output[n] / sum[n]**2).sum(1)
return grad_input
Код Gradcheck
from torch.autograd import gradcheck
func = MyFunc.apply
input = (torch.rand(2,2,2,dtype=torch.double,requires_grad=True))
test = gradcheck(func, input)
print(test)
и результат: введите здесь описание изображения
Я не знаю, почему возникает ошибка...
Ваша помощь будет мне очень полезна для реализации моей собственной сверточной сети.
Спасибо! Хорошего дня.