TransformerDecoder
- class torch.nn.TransformerDecoder(decoder_layer, num_layers, norm=None)
TransformerDecoder is a stack of N decoder layers.
Note
See this tutorial for an in-depth discussion of the performant building blocks PyTorch offers for building your own transformer layers.
Warning
All layers in the TransformerDecoder are initialized with the same parameters. It is recommended to manually initialize the layers after creating the TransformerDecoder instance.
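The warning above can be checked directly. The following is a minimal sketch, assuming small illustrative sizes and assuming Xavier-uniform initialization as the re-initialization scheme; that scheme is one common choice, not something this page prescribes. The helper `same_weights` is hypothetical and exists only for this example.

```python
import torch
from torch import nn

# Small, illustrative sizes (assumptions, not recommended values).
layer = nn.TransformerDecoderLayer(d_model=32, nhead=4, dim_feedforward=64)
decoder = nn.TransformerDecoder(layer, num_layers=3)


def same_weights(a: nn.Module, b: nn.Module) -> bool:
    """Hypothetical helper: True if two modules hold identical parameter values."""
    return all(torch.equal(p, q) for p, q in zip(a.parameters(), b.parameters()))


# Right after construction, the stacked layers are copies of one another.
identical_before = same_weights(decoder.layers[0], decoder.layers[1])

# Re-initialize every weight matrix independently (Xavier-uniform is an assumption).
for p in decoder.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

identical_after = same_weights(decoder.layers[0], decoder.layers[1])
print(identical_before, identical_after)
```

Only parameters with more than one dimension are re-drawn here, so biases and normalization weights keep their default values.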
- Parameters
decoder_layer (TransformerDecoderLayer) – an instance of the TransformerDecoderLayer() class (required).
num_layers (int) – the number of sub-decoder-layers in the decoder (required).
norm (Optional[Module]) – the layer normalization component (optional).
Examples
>>> import torch
>>> from torch import nn
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = transformer_decoder(tgt, memory)
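The optional `norm` argument is not exercised by the example above. As a hedged sketch, assuming `nn.LayerNorm` as the normalization module and small illustrative sizes, a decoder with a final normalization step could be built like this:

```python
import torch
from torch import nn

d_model = 32  # illustrative size (assumption)

decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=4, dim_feedforward=64)
# The module passed as `norm` is applied once, to the output of the last layer.
decoder = nn.TransformerDecoder(decoder_layer, num_layers=2, norm=nn.LayerNorm(d_model))
decoder.eval()  # disable dropout for a deterministic forward pass

tgt = torch.rand(6, 3, d_model)     # (target length, batch, embedding)
memory = torch.rand(4, 3, d_model)  # (source length, batch, embedding)

with torch.no_grad():
    out = decoder(tgt, memory)

print(out.shape)
```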
- forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=None, memory_is_causal=False)
Pass the inputs (and masks) through each decoder layer in turn.
- Parameters
tgt (Tensor) – the sequence to the decoder (required).
memory (Tensor) – the sequence from the last layer of the encoder (required).
tgt_mask (Optional[Tensor]) – the mask for the tgt sequence (optional).
memory_mask (Optional[Tensor]) – the mask for the memory sequence (optional).
tgt_key_padding_mask (Optional[Tensor]) – the mask for the tgt keys per batch (optional).
memory_key_padding_mask (Optional[Tensor]) – the mask for the memory keys per batch (optional).
tgt_is_causal (Optional[bool]) – If specified, applies a causal mask as tgt mask. Default: None; try to detect a causal mask. Warning: tgt_is_causal provides a hint that tgt_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
memory_is_causal (bool) – If specified, applies a causal mask as memory mask. Default: False. Warning: memory_is_causal provides a hint that memory_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
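To make the mask arguments above concrete, here is a minimal sketch that passes a causal `tgt_mask` and a `memory_key_padding_mask`. The sizes, the hand-built float mask, and the choice of which memory steps count as padding are all assumptions made for illustration. `tgt_is_causal` is left at its default so that the decoder tries to detect the causal mask on its own.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Target length, memory length, batch size, embedding size (illustrative).
T, S, N, E = 5, 7, 2, 32

decoder_layer = nn.TransformerDecoderLayer(d_model=E, nhead=4, dim_feedforward=64)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=2)
decoder.eval()  # disable dropout so the two calls below are comparable

tgt = torch.rand(T, N, E)
memory = torch.rand(S, N, E)

# Float causal mask of shape (T, T): -inf above the diagonal blocks
# attention from each position to later positions.
tgt_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

# Boolean padding mask of shape (N, S): True marks memory positions to ignore.
memory_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
memory_key_padding_mask[1, -3:] = True  # pretend sample 1 ends with 3 padded steps

with torch.no_grad():
    out = decoder(tgt, memory, tgt_mask=tgt_mask,
                  memory_key_padding_mask=memory_key_padding_mask)

    # Change only the last target step; with a causal mask the outputs
    # at earlier positions should not be affected.
    tgt_changed = tgt.clone()
    tgt_changed[-1] = torch.rand(N, E)
    out_changed = decoder(tgt_changed, memory, tgt_mask=tgt_mask,
                          memory_key_padding_mask=memory_key_padding_mask)

print(out.shape)
print(torch.allclose(out[:-1], out_changed[:-1], atol=1e-5))
```

Comparing the two outputs is a quick sanity check that the mask really is causal before relying on the `tgt_is_causal` hint.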
- Return type
Tensor
- Shape:
see the docs in Transformer.
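As a quick illustration of the shapes involved, the sketch below runs the decoder in both layouts. It assumes the default sequence-first layout `(length, batch, embedding)` and the `batch_first=True` option of `nn.TransformerDecoderLayer`, which is not described in this section; all sizes are illustrative.

```python
import torch
from torch import nn

# Target length, memory length, batch size, embedding size (illustrative).
T, S, N, E = 20, 10, 4, 32

# Default layout: tgt is (T, N, E), memory is (S, N, E).
seq_first = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model=E, nhead=4, dim_feedforward=64),
    num_layers=2,
)
out_seq_first = seq_first(torch.rand(T, N, E), torch.rand(S, N, E))

# batch_first=True layout: tgt is (N, T, E), memory is (N, S, E).
batch_first = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model=E, nhead=4, dim_feedforward=64, batch_first=True),
    num_layers=2,
)
out_batch_first = batch_first(torch.rand(N, T, E), torch.rand(N, S, E))

# In both cases the output has the same shape as tgt.
print(out_seq_first.shape, out_batch_first.shape)
```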