PyTorch Transformer: implementing masks for the decoder forward function

I'm trying to learn and use the PyTorch Transformer with the DeepMind mathematics dataset. I have a tokenized (character-level, not word-level) sequence that is fed into the model. The model's forward function runs the encoder once and the decoder multiple times (until every output in the batch reaches the <eos> token; that part is still a TODO). I'm struggling with the transformer masks and the decoder forward pass, because it throws this error:

    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    RuntimeError: shape '[-1, 24, 64]' is invalid for input of size 819200.

Source: N = 32, S = 50, E = 512. Target: N = 32, S = 3, E = 512. Maybe my mask implementation is wrong, or it's because the source and target lengths differ; I'm not quite sure.
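For reference, a minimal, self-contained check of the shapes nn.Transformer expects by default (sequence-first layout: src is (S, N, E), tgt is (T, N, E)); the tensors here are random placeholders, not my actual data:

    import torch
    import torch.nn as nn

    model = nn.Transformer(d_model=512, nhead=8, num_encoder_layers=3, num_decoder_layers=2)
    src = torch.rand(50, 32, 512)   # (S, N, E) = (50, 32, 512)
    tgt = torch.rand(3, 32, 512)    # (T, N, E) = (3, 32, 512)
    out = model(src, tgt)           # runs fine; out.shape == torch.Size([3, 32, 512])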

class PositionalEncoding(nn.Module):
    # positionally encode src and target sequences
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class MyTransformerModel(nn.Module):
    # should implement init and forward function
    # define separate functions for masks
    # define forward function with
    # implement:
    #  embedding layer
    #  positional encoding
    #  encoder layer
    #  decoder layer
    #  final classification layer
    # encoder -> forward once
    # decoder -> forward multiple times (for one encoder forward)
    # decoder output => concatenate to input e.g. decoder_input = torch.cat([decoder_input, decoder_output])
    # early stopping => all in batch reach <eos> token
    def __init__(self, vocab_length = 30, sequence_length = 512, num_encoder_layers = 3, num_decoder_layers = 2, num_hidden_dimension = 256, feed_forward_dimensions = 1024, attention_heads = 8, dropout = 0.1, pad_idx = 3, device = "CPU", batch_size = 32):
        super(MyTransformerModel, self).__init__()
        self.src_embedding = nn.Embedding(vocab_length, sequence_length)
        self.pos_encoder = PositionalEncoding(sequence_length, dropout)
        self.src_mask = None # attention mask
        self.memory_mask = None # attention mask
        self.pad_idx = pad_idx
        self.device = device
        self.batch_size = batch_size
        self.transformer = nn.Transformer(
            sequence_length,
            attention_heads,
            num_encoder_layers,
            num_decoder_layers,
            feed_forward_dimensions,
            dropout,
        )
    
    def src_att_mask(self, src_len):
        mask = (torch.triu(torch.ones(src_len, src_len)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def no_peak_att_mask(self, batch_size, src_len, time_step):
        mask = np.zeros((batch_size, src_len), dtype=bool)
        mask[:, time_step:] = 1 # np.NINF
        mask = torch.from_numpy(mask)
        return mask

    def make_src_key_padding_mask(self, src):
        # mask "<pad>"
        src_mask = src.transpose(0, 1) == self.pad_idx
        return src_mask.to(self.device)

    def make_trg_key_padding_mask(self, trg):
        tgt_mask = trg.transpose(0, 1) == self.pad_idx
        return tgt_mask.to(self.device)

    def forward(self, src, trg):
        src_seq_length, N = src.shape
        trg_seq_length, N = trg.shape
        embed_src = self.src_embedding(src)
        position_embed_src = self.pos_encoder(embed_src)
        embed_trg = self.src_embedding(trg)
        position_embed_trg = self.pos_encoder(embed_trg)
        src_padding_mask = self.make_src_key_padding_mask(src)
        trg_padding_mask = self.make_trg_key_padding_mask(trg)
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(self.device)
        time_step = 1
        att_mask = self.no_peak_att_mask(self.batch_size, src_seq_length, time_step).to(self.device)
        encoder_output = self.transformer.encoder.forward(position_embed_src, src_key_padding_mask = src_padding_mask)
        # TODO : implement loop for transformer decoder forward fn, implement early stopping
        # where to feed decoder_output?
        decoder_output = self.transformer.decoder.forward(position_embed_trg, encoder_output, trg_mask, att_mask, trg_padding_mask, src_padding_mask)
        return decoder_output
    

Can anyone point out where I went wrong?


— Roman Dulak, 06.01.2021
Comments:

Could you elaborate on what S means and why it is 50 for the source and 3 for the target? It strikes me as suspicious that 50 is not evenly divisible by 3, which may also be why it complains that shape '[-1, 24, 64]' is invalid for input of size 819200, since 819200 is not evenly divisible by (24 * 64). — Matthew Cox, 06.01.2021
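Working out those numbers (assuming attention_heads = 8 from the code above, so head_dim = 512 / 8 = 64): 32 × 50 × 512 = 819200 is exactly the size of the source/memory tensor, and 24 = 3 × 8 means the length-3 target dimension is being read as the batch size, which matches the dimension-order mixup fixed in the answer below:

    # illustrative arithmetic only, assuming 8 attention heads and E = 512
    head_dim = 512 // 8         # 64, matches the error message
    k_elements = 32 * 50 * 512  # 819200, the whole source/memory tensor
    bsz_x_heads = 3 * 8         # 24: the length-3 target dim is treated as the batch size
    print(k_elements % (bsz_x_heads * head_dim))  # 512, not 0 -> the view() fails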


Answers (1)


It looks like I had messed up the dimension order (since nn.Transformer has no batch-first option). The corrected code is below:

class MyTransformerModel(nn.Module):
    def __init__(self, d_model = 512, vocab_length = 30, sequence_length = 512, num_encoder_layers = 3, num_decoder_layers = 2, num_hidden_dimension = 256, feed_forward_dimensions = 1024, attention_heads = 8, dropout = 0.1, pad_idx = 3, device = "CPU", batch_size = 32):
        #, ninp, device, nhead=8, nhid=2048, nlayers=2, dropout=0.1, src_pad_idx = 1, max_len=5000, forward_expansion= 4):
        super(MyTransformerModel, self).__init__()
        self.src_embedding = nn.Embedding(vocab_length, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.vocab_length = vocab_length
        self.d_model = d_model
        self.src_mask = None # attention mask
        self.memory_mask = None # attention mask
        self.pad_idx = pad_idx
        self.device = device
        self.batch_size = batch_size
        self.transformer = nn.Transformer(
            d_model,
            attention_heads,
            num_encoder_layers,
            num_decoder_layers,
            feed_forward_dimensions,
            dropout,
        )

        self.fc = nn.Linear(d_model, vocab_length)
        # self.init_weights() <= used in tutorial

    def src_att_mask(self, src_len):
        mask = (torch.triu(torch.ones(src_len, src_len)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def no_peak_att_mask(self, batch_size, src_len, time_step):
        mask = np.zeros((batch_size, src_len), dtype=bool)
        mask[:, time_step:] = 1 # np.NINF
        mask = torch.from_numpy(mask)
        # mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def make_src_key_padding_mask(self, src):
        # mask "<pad>"
        src_mask = src.transpose(0, 1) == self.pad_idx
        # src_mask = src == self.pad_idx
        # (N, src_len)
        return src_mask.to(self.device)

    def make_trg_key_padding_mask(self, trg):
        # same as above -> expected tgt_key_padding_mask: (N, T)
        tgt_mask = trg.transpose(0, 1) == self.pad_idx
        # tgt_mask = trg == self.pad_idx
        # (N, src_len)
        return tgt_mask.to(self.device)

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.weight)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src, trg):
        N, src_seq_length = src.shape
        N, trg_seq_length = trg.shape
        #  S - source sequence length
        #  T - target sequence length
        #  N - batch size
        #  E - feature number
        #  src: (S, N, E) (sourceLen, batch, features)
        #  tgt: (T, N, E)
        #  src_mask: (S, S)
        #  tgt_mask: (T, T)
        #  memory_mask: (T, S)
        #  src_key_padding_mask: (N, S)
        #  tgt_key_padding_mask: (N, T)
        #  memory_key_padding_mask: (N, S)
        src = rearrange(src, 'n s -> s n')
        trg = rearrange(trg, 'n t -> t n')
        print("src shape {}".format(src.shape))
        print(src)
        print("trg shape {}".format(trg.shape))
        print(trg)

        embed_src = self.src_embedding(src)
        print("embed_src shape {}".format(embed_src.shape))
        print(embed_src)
        position_embed_src = self.pos_encoder(embed_src)
        print("position_embed_src shape {}".format(position_embed_src.shape))
        print(position_embed_src)
        embed_trg = self.src_embedding(trg)
        print("embed_trg shape {}".format(embed_trg.shape))
        print(embed_trg)
        position_embed_trg = self.pos_encoder(embed_trg)
        # position_embed_trg = position_embed_trg.transpose(0, 1)
        print("position_embed_trg shape {}".format(position_embed_trg.shape))
        print(position_embed_trg)
        src_padding_mask = self.make_src_key_padding_mask(src)
        print("KEY - src_padding_mask shape {}".format(src_padding_mask.shape))
        print("should be of shape: src_key_padding_mask: (N, S)")
        print(src_padding_mask)
        trg_padding_mask = self.make_trg_key_padding_mask(trg)
        print("KEY - trg_padding_mask shape {}".format(trg_padding_mask.shape))
        print("should be of shape: trg_key_padding_mask: (N, T)")
        print(trg_padding_mask)
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(self.device)
        print("trg_mask shape {}".format(trg_mask.shape))
        print("trg_mask should be of shape tgt_mask: (T, T)")
        print(trg_mask)
        # att_mask = self.src_att_mask(trg_seq_length).to(self.device)
        time_step = 1
        # error => memory_mask: expected shape! (T, S) !!! this is not a key_padding_mask!
        # att_mask = self.no_peak_att_mask(self.batch_size, src_seq_length, time_step).to(self.device)
        # print("att_mask shape {}".format(att_mask.shape))
        # print("att_mask should be of shape  memory_mask: (T, S)")
        # print(att_mask)
        att_mask = None
        # get encoder output
        # forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None)
        # forward encoder just once for a batch
        # attention forward of encoder expects => src, src_mask, src_key_padding_mask +++ possible positional encoding error !!!
        encoder_output = self.transformer.encoder.forward(position_embed_src, src_key_padding_mask = src_padding_mask)
        print("encoder_output")
        print("encoder_output shape {}".format(encoder_output.shape))
        print(encoder_output)
        # forward decoder till all in batch did not reach <eos>?
        # def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
        # memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
        # memory_key_padding_mask: Optional[Tensor] = None)
        # first forward
        decoder_output = self.transformer.decoder.forward(position_embed_trg, encoder_output, trg_mask, att_mask, trg_padding_mask, src_padding_mask)
        # TODO: target in => target out shifted by one, loop till all in batch meet stopping criteria || max len is reached
        print("decoder_output")
        print("decoder_output shape {}".format(decoder_output.shape))
        print(decoder_output)

        output = rearrange(decoder_output, 't n e -> n t e')
        output = self.fc(output)
        print("output")
        print("output shape {}".format(output.shape))
        print(output)

        predicted = F.log_softmax(output, dim=-1)
        print("predicted")
        print("predicted shape {}".format(predicted.shape))
        print(predicted)
        # top k
        top_value, top_index = torch.topk(predicted, k=1)
        top_index = torch.squeeze(top_index)
        print("top_index")
        print("top_index shape {}".format(top_index.shape))
        print(top_index)
        print("top_value")
        print("top_value shape {}".format(top_value.shape))
        print(top_value)
        return top_index
— Roman Dulak, 07.01.2021
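The autoregressive decoding loop left as a TODO above is not part of the posted answer. Below is a minimal sketch of one way to write it, assuming greedy decoding, a max_len cap, and hypothetical sos_idx / eos_idx token ids (none of these names appear in the original code); it also assumes the corrected forward() above, which returns predicted token ids with batch first:

def greedy_decode(model, src, sos_idx, eos_idx, max_len=50):
    # src: (N, S) token ids, as expected by the corrected forward() above
    N = src.shape[0]
    device = src.device
    trg = torch.full((N, 1), sos_idx, dtype=torch.long, device=device)   # start with <sos>
    finished = torch.zeros(N, dtype=torch.bool, device=device)
    for _ in range(max_len):
        out = model(src, trg)                 # predicted token ids
        out = out.reshape(N, -1)              # guard against squeeze() dropping T == 1
        next_token = out[:, -1:]              # keep only the newest time step
        trg = torch.cat([trg, next_token], dim=1)
        finished |= next_token.squeeze(1) == eos_idx
        if finished.all():                    # early stopping: whole batch reached <eos>
            break
    return trg

In practice you would wrap this in model.eval() and torch.no_grad(), and overwrite tokens generated after a sequence has finished with the pad index. Note also that newer PyTorch releases (1.9+) add a batch_first=True option to nn.Transformer, which would remove the need for the rearrange calls in forward().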