Как использовать прошлое с HuggingFace Transformers GPT-2?

У меня есть:

        context = torch.tensor(context, dtype=torch.long, device=self.device)
        context = context.unsqueeze(0)
        generated = context
        with torch.no_grad():
            past_outputs = None
            for i in trange(num_words):
                print(i, num_words)
                inputs = {"input_ids": generated}

                outputs, past_outputs = self.model(
                    **inputs,
                    past=past_outputs
                )
                next_token_logits = outputs[
                    0, -1, :] / (temperature if temperature > 0 else 1.0)

                # reptition penalty from CTRL
                # (https://arxiv.org/abs/1909.05858)
                for _ in set(generated.view(-1).tolist()):
                    next_token_logits[_] /= repetition_penalty

                filtered_logits = top_k_top_p_filtering(
                    next_token_logits, top_k=top_k, top_p=top_p)
                if temperature == 0:  # greedy sampling:
                    next_token = torch.argmax(filtered_logits).unsqueeze(0)
                else:
                    next_token = torch.multinomial(
                        F.softmax(filtered_logits, dim=-1), num_samples=1)

                generated = torch.cat(
                    (generated, next_token.unsqueeze(0)), dim=1)

Это работает для первой итерации, но затем я получаю сообщение об ошибке на следующей итерации:

  File "/Users/shamoon/Sites/wordblot/packages/ml-server/generator.py", line 143, in sample_sequence
    past=past_outputs
  File "/Users/shamoon/.local/share/virtualenvs/ml-server-EdimT5-E/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/shamoon/.local/share/virtualenvs/ml-server-EdimT5-E/lib/python3.7/site-packages/transformers/modeling_gpt2.py", line 601, in forward
    output_hidden_states=output_hidden_states,
  File "/Users/shamoon/.local/share/virtualenvs/ml-server-EdimT5-E/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/shamoon/.local/share/virtualenvs/ml-server-EdimT5-E/lib/python3.7/site-packages/transformers/modeling_gpt2.py", line 470, in forward
    position_embeds = self.wpe(position_ids)
  File "/Users/shamoon/.local/share/virtualenvs/ml-server-EdimT5-E/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/shamoon/.local/share/virtualenvs/ml-server-EdimT5-E/lib/python3.7/site-packages/torch/nn/modules/sparse.py", line 114, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "/Users/shamoon/.local/share/virtualenvs/ml-server-EdimT5-E/lib/python3.7/site-packages/torch/nn/functional.py", line 1724, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self

Что-то я делаю не так?


person Shamoon    schedule 03.08.2020    source источник
comment
Какая строка вызывает исключение? Можете ли вы получить более расширенную обратную связь?   -  person roman    schedule 03.08.2020
comment
Что такое model, generated, temperature? Этот ответ объясняет использование прошлого. Выложите, пожалуйста, полную трассировку стека. Я предполагаю, что вы превышаете максимальную длину ввода 1024.   -  person cronoik    schedule 03.08.2020
comment
model - это gpt2-xl, generated обновляется в коде. temperature равно 0,5   -  person Shamoon    schedule 03.08.2020
comment
Не могли бы вы включить полную трассировку стека? Какое значение имеет "число_слов"? Каков первоначальный размер context?   -  person cronoik    schedule 03.08.2020
comment
И я обновил полную трассировку стека.   -  person Shamoon    schedule 03.08.2020
comment
Какой класс вы использовали для загрузки своей модели? gpt2lmheadmodel?   -  person cronoik    schedule 07.08.2020
comment
Ага - вот тот.   -  person Shamoon    schedule 07.08.2020


Ответы (2)


Я считаю, что проблема в том, что context содержит целые числа, превышающие размер словарного запаса. Мое предположение основано на последней строке трассировки:

return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self
person roman    schedule 04.08.2020
comment
@Shamoon, что ты имеешь в виду? - person roman; 04.08.2020
comment
Я бы также предположил, что next_token может быть вне словаря - person roman; 04.08.2020
comment
Если я не пройду past=past_outputs, то все будет нормально. - person Shamoon; 04.08.2020
comment
@Shamoon, тогда вы проверяли значение past_outputs? - person roman; 05.08.2020

Я сделал:

                outputs, past_outputs = self.models[model_name](
                    context,
                    past=past_outputs
                )
                context = next_token.unsqueeze(0)

person Shamoon    schedule 14.08.2020
comment
Не теряете ли вы таким образом исходный контекст? - person roman; 15.08.2020
comment
Ты хранишь прошлое, так что все в порядке. Думаю? - person Shamoon; 15.08.2020