Building an encoder-decoder transformer architecture for sequence-to-sequence language tasks like text translation and summarization

 

  • Encoder-Decoder Connection: The encoder connects to the decoder through cross-attention, allowing the decoder to use the encoder's final hidden states to generate the target sequence.
  • Cross-Attention Mechanism: This mechanism lets the decoder "look back" at the input sequence when generating the next word of the target sequence. For example, when translating "I really like to travel" into Spanish, "travel" receives the highest attention.
  • Decoder Layer: The forward() method of the decoder layer requires two masks: the causal mask for the first (self-attention) stage and the cross-attention mask for the second stage, as illustrated in the sketch after this list.
  • Training vs. Inference: During training, the decoder receives the actual target sequences as inputs (teacher forcing). During inference, it generates the target sequence token by token, starting with an empty output embedding; a greedy decoding sketch follows the code example below.
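
To make the two-mask forward() signature concrete, here is a minimal sketch of a single decoder layer. It is an assumption based on the standard transformer design rather than the exact code from this series: the class name DecoderLayer, the attribute names, and the use of PyTorch's built-in nn.MultiheadAttention (in place of a hand-written attention module) are illustrative. Masks follow the convention of the code below, where 1 marks a blocked position.

import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, feed-forward (illustrative sketch)."""
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, causal_mask, encoder_output, cross_mask):
        # Stage 1: causally masked self-attention over the (partial) target sequence
        self_out, _ = self.self_attn(x, x, x, attn_mask=causal_mask.bool())
        x = self.norm1(x + self.dropout(self_out))
        # Stage 2: cross-attention - queries come from the decoder, keys/values from the encoder's final hidden states
        cross_out, _ = self.cross_attn(x, encoder_output, encoder_output, attn_mask=cross_mask.bool())
        x = self.norm2(x + self.dropout(cross_out))
        # Stage 3: position-wise feed-forward network
        return self.norm3(x + self.dropout(self.ff(x)))

A full TransformerDecoder would stack num_layers of these layers on top of token embeddings and positional encodings, then project the final hidden states onto the vocabulary.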

 

import torch

# Example hyperparameters (illustrative values; adjust to your setup)
vocab_size, batch_size, sequence_length = 10000, 8, 64
d_model, num_layers, num_heads, d_ff, dropout = 512, 6, 8, 2048, 0.1

# Create a batch of random input sequences
input_sequence = torch.randint(0, vocab_size, (batch_size, sequence_length))

# Random padding mask (for demonstration) and upper-triangular causal mask (1 marks a blocked position)
padding_mask = torch.randint(0, 2, (sequence_length, sequence_length))
causal_mask = torch.triu(torch.ones(sequence_length, sequence_length), diagonal=1)

# Instantiate the two transformer bodies (classes defined earlier in this series)
encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length=sequence_length)
decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length=sequence_length)

# Pass the necessary masks as arguments to the encoder and the decoder
encoder_output = encoder(input_sequence, padding_mask)
decoder_output = decoder(input_sequence, causal_mask, encoder_output, padding_mask)
print("Batch's output shape: ", decoder_output.shape)