Как я могу узнать, какое имя входного узла или слоя для слоя в PyTorch? Скажем, если у меня есть torch.cat, как я могу узнать имя тензора или слоя, откуда он получает входные данные?
Для этого кода из https://rosenfelder.ai/multi-input-neural-network-pytorch/
class LitClassifier(pl.LightningModule):
def __init__(
self, lr: float = 1e-3, num_workers: int = 4, batch_size: int = 32,
):
super().__init__()
self.lr = lr
self.num_workers = num_workers
self.batch_size = batch_size
self.conv1 = conv_block(3, 16)
self.conv2 = conv_block(16, 32)
self.conv3 = conv_block(32, 64)
self.ln1 = nn.Linear(64 * 26 * 26, 16)
self.relu = nn.ReLU()
self.batchnorm = nn.BatchNorm1d(16)
self.dropout = nn.Dropout2d(0.5)
self.ln2 = nn.Linear(16, 5)
self.ln4 = nn.Linear(5, 10)
self.ln5 = nn.Linear(10, 10)
self.ln6 = nn.Linear(10, 5)
self.ln7 = nn.Linear(10, 1)
def forward(self, img, tab):
img = self.conv1(img)
img = self.conv2(img)
img = self.conv3(img)
img = img.reshape(img.shape[0], -1)
img = self.ln1(img)
img = self.relu(img)
img = self.batchnorm(img)
img = self.dropout(img)
img = self.ln2(img)
img = self.relu(img)
tab = self.ln4(tab)
tab = self.relu(tab)
tab = self.ln5(tab)
tab = self.relu(tab)
tab = self.ln6(tab)
tab = self.relu(tab)
x = torch.cat((img, tab), dim=1)
x = self.relu(x)
return self.ln7(x)
Итак, если я хочу знать, с какого слоя torch.cat получает входные данные.
Для keras у нас есть model.get_layer(id=idx).input.name
, есть ли что-то подобное и для PyTorch?
img
содержит то, что ей было присвоено последним, то есть последнюю строкуimg = self.relu(img)
. Аналогично с переменнойtab
. Я не понимаю вопроса, это просто базовый Python - person Alexey Larionov   schedule 17.06.2021forward
, например, в вашем примере один слойself.relu
вызывается 6 раз с разными входными данными. Существует агрегацияnn.Sequential
слоев, которая в основном реализует передачу некоторогоx
на первый уровень, затем вывод этого уровня на второй уровень и так один для всех слоев. Чтобы узнать, где находится слой, вы можете работать с индексамиnn.Sequential
внутренних слоев. - person Alexey Larionov   schedule 17.06.2021can we extract that data from the model object as whole?
Обычно вы сохраняете некоторый результат, который хотите, как поле вашей модели, например. вforward()
сохраните некоторый промежуточный вывод, подобный этомуself.last_tab = tab
, затем получите к нему доступ какmodel.last_tab
где угодно - person Alexey Larionov   schedule 17.06.2021nn.Sequential
, то простойprint(model)
может дать то, что вы ищете, как там. - person Alexey Larionov   schedule 17.06.2021