Pytorch: обратное распространение от суммы матричных элементов к листовой переменной

Я пытаюсь немного лучше понять обратное распространение в pytorch. У меня есть фрагмент кода, который успешно выполняет обратное распространение от выхода d к листовой переменной a, но затем, если я добавлю шаг изменения формы, обратное распространение больше не даст входному градиенту.

Я знаю, что изменение формы неуместно, но я все еще не уверен, как контекстуализировать это.

Есть предположения?

Спасибо.

#Works
a = torch.tensor([1.])
a.requires_grad = True
b = torch.tensor([1.])
c = torch.cat([a,b])
d = torch.sum(c)
d.backward()

print('a gradient is')
print(a.grad) #=> Tensor([1.])

#Doesn't work
a = torch.tensor([1.])
a.requires_grad = True
a = a.reshape(a.shape)
b = torch.tensor([1.])
c = torch.cat([a,b])
d = torch.sum(c)
d.backward()

print('a gradient is')
print(a.grad) #=> None

person user49404    schedule 01.05.2019    source источник
comment
что вы имеете в виду под контекстуализацией?   -  person Wasi Ahmad    schedule 02.05.2019
comment
Просто хотел получить более широкую картину того, что здесь происходит не так, как предоставил Сергей. Благодарю.   -  person user49404    schedule 03.05.2019


Ответы (1)


Редактировать:

Вот подробное объяснение того, что происходит («это не ошибка как таковая, но определенно вызывает путаницу»): https://github.com/pytorch/pytorch/issues/19778

Таким образом, одно из решений - специально попросить сохранить grad сейчас без листа a:

a = torch.tensor([1.])
a.requires_grad = True
a = a.reshape(a.shape)
a.retain_grad()
b = torch.tensor([1.])
c = torch.cat([a,b])
d = torch.sum(c)
d.backward()

Старый ответ:

Если вы переместите a.requires_grad = True после изменения формы, это сработает:

a = torch.tensor([1.])
a = a.reshape(a.shape)
a.requires_grad = True
b = torch.tensor([1.])
c = torch.cat([a,b])
d = torch.sum(c)
d.backward()

Похоже на ошибку в PyTorch, потому что после этого a.requires_grad все еще верно.

a = torch.tensor([1.])
a.requires_grad = True
a = a.reshape(a.shape)

Похоже, это связано с тем, что a больше не является листом в вашем примере «Не работает», но все еще листом в других случаях (выведите a.is_leaf, чтобы проверить).

person Sergii Dymchenko    schedule 01.05.2019