PyTorch Backwards Automatic Diff. со взвешенным убытком

В настоящее время я пытаюсь реализовать метод, в котором я выполняю градиентный спуск с использованием взвешенных потерь, и мне было интересно, может ли кто-нибудь помочь мне с реализацией этого метода в pytorch, поскольку мой текущий метод ошибка. Псевдо-код для него приведен ниже.

введите описание изображения здесь

Моя реализация (без учета цикла for):

  model.train()
  ## Create a Dummy Model
  dummy = L2RWCNN(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, 
                DROPOUT, PAD_IDX)
  dummy.state_dict(model.state_dict())
  
  ## Get Training and Validation Data AKA Lines 2 & 3
  text_t, labels_t = next(iter(train_iterator))
  text_v, labels_v = next(iter(valid_iterator))

  y_f_hat = dummy(text_t) ## Line 4

  ## Line 5
  eps = Variable(torch.zeros(y_f_hat.size()), requires_grad=True)
  cost = F.binary_cross_entropy_with_logits(y_f_hat.squeeze(), labels_t, reduce = False)
  l_f_meta = torch.sum(eps * cost)
  dummy.zero_grad()
  #Line 6
  grads = torch.autograd.grad(l_f_meta, (dummy.params()), create_graph = True, retain_graph=True)
  #Line 7
  dummy.update_params(l_f_meta, source_params=grads)
  #Line 8
  y_g_hat = dummy(text_v)
  #Line 9
  l_g_meta = F.binary_cross_entropy_with_logits(y_g_hat.squeeze(), labels_v, reduce = True)
  #line 10
  grad_eps = torch.autograd.grad(l_g_meta, eps)[0] ## THIS ERRORS OUT

В строке 10 происходит сбой кода из-за отсутствия eps в вычислительном графе. Мне было интересно, есть ли лучший способ сделать строку 7 в псевдокоде. В настоящее время я использую для этого специальную функцию, и я думаю, что это может быть проблемой, поскольку теоретически вычислительный граф должен иметь путь eps->l_f_meta->model parameters (aka theta) -> y_g - > l_g, что означает, что eps и l_g связаны между собой, поэтому теоретически должен быть градиент.

Мой код update_params:

def update_params(self, lr_inner, first_order=False, source_params=None, detach=False):
    if source_params is not None:
        for tgt, src in zip(self.named_params(self), source_params):
            name_t, param_t = tgt
            
            # name_s, param_s = src
            # grad = param_s.grad
            # name_s, param_s = src
            grad = src
            if first_order:
                grad = to_var(grad.detach().data)
            tmp = param_t - lr_inner * grad
            self.set_param(self, name_t, tmp)
    else:

        for name, param in self.named_params(self):
            if not detach:
                grad = param.grad
                if first_order:
                    grad = to_var(grad.detach().data)
                tmp = param - lr_inner * grad
                self.set_param(self, name, tmp)
            else:
                param = param.detach_()
                self.set_param(self, name, param)

def set_param(self,curr_mod, name, param):
    if '.' in name:
        n = name.split('.')
        module_name = n[0]
        rest = '.'.join(n[1:])
        for name, mod in curr_mod.named_children():
            if module_name == name:
                self.set_param(mod, rest, param)
                break
    else:
        param = nn.Parameter(param)
        setattr(curr_mod, name, param)

Если бы кто-нибудь мог дать мне лучший способ обновить параметры модели с взвешенными потерями, просто лучший способ реализовать весь мой код или решение проблемы eps не в вычислительном графе. Это было бы очень признательно, так как я был застрял на этом надолго.


person Nakul Upadhya    schedule 08.04.2021    source источник
comment
Это также потенциально могло быть из-за проблемы с setattr   -  person Nakul Upadhya    schedule 09.04.2021