Why do we do batch matrix-matrix product?

In the seq2seq model, the encoder encodes the input sequences in mini-batches. Say, for example, the input is B x S x d, where B is the batch size, S is the maximum sequence length, and d is the word embedding dimension. Then the encoder's output is B x S x h, where h is the hidden state size of the encoder (which is an RNN).

Now, while decoding (during training), the input sequences are given one time step at a time, so the input is B x 1 x d and the decoder produces a tensor of shape B x 1 x h. To compute the context vector, we need to compare this decoder hidden state with the encoder's encoded states.
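
As a minimal sketch of these shapes (the sizes are made up and the GRUs are assumed to use batch_first=True; the tutorial's default layout is discussed in the last answer below):

    import torch
    import torch.nn as nn

    # Hypothetical sizes: B=4 sentences, S=10 tokens, d=32 embedding dim, h=64 hidden dim
    B, S, d, h = 4, 10, 32, 64

    encoder_rnn = nn.GRU(input_size=d, hidden_size=h, batch_first=True)
    decoder_rnn = nn.GRU(input_size=d, hidden_size=h, batch_first=True)

    src_embedded = torch.randn(B, S, d)            # embedded source batch
    encoder_outputs, enc_hidden = encoder_rnn(src_embedded)
    print(encoder_outputs.shape)                   # torch.Size([4, 10, 64]) -> B x S x h

    tgt_step = torch.randn(B, 1, d)                # one decoding time step
    decoder_output, _ = decoder_rnn(tgt_step, enc_hidden)
    print(decoder_output.shape)                    # torch.Size([4, 1, 64])  -> B x 1 x h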

So, say you have two tensors: T1 of shape B x S x h and T2 of shape B x 1 x h. You can then do a batch matrix multiplication as follows:

    out = torch.bmm(T1, T2.transpose(1, 2))

Essentially, you are multiplying a tensor of shape B x S x h with a tensor of shape B x h x 1, which results in a tensor of shape B x S x 1 holding the attention weights for each batch element.
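
For example, with random tensors standing in for the encoder and decoder outputs (the values are arbitrary, only the shapes matter):

    import torch

    B, S, h = 4, 10, 64                      # hypothetical sizes
    T1 = torch.randn(B, S, h)                # encoder outputs: B x S x h
    T2 = torch.randn(B, 1, h)                # decoder state for one step: B x 1 x h

    out = torch.bmm(T1, T2.transpose(1, 2))  # (B x S x h) @ (B x h x 1)
    print(out.shape)                         # torch.Size([4, 10, 1]) -> B x S x 1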

Here, the attention weights (B x S x 1) represent similarity scores between the decoder's current hidden state and all of the encoder's hidden states. Now you can multiply the encoder's hidden states (B x S x h, transposed first to B x h x S) with the attention weights, which results in a tensor of shape B x h x 1. If you then squeeze at dim=2, you get a tensor of shape B x h, which is your context vector.
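
As a rough sketch of that step, again with stand-in tensors:

    import torch

    B, S, h = 4, 10, 64
    encoder_outputs = torch.randn(B, S, h)       # B x S x h
    attn_weights = torch.randn(B, S, 1)          # B x S x 1 (scores from the bmm above)

    # (B x h x S) @ (B x S x 1) -> B x h x 1, then drop the trailing dim
    context = torch.bmm(encoder_outputs.transpose(1, 2), attn_weights).squeeze(2)
    print(context.shape)                         # torch.Size([4, 64]) -> B x h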

This context vector (B x h) is usually concatenated with the decoder's hidden state (B x 1 x h, squeezed at dim=1) to predict the next token.
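
That final concatenation might look like this sketch (the names are illustrative):

    import torch

    B, h = 4, 64
    context = torch.randn(B, h)                  # B x h context vector
    decoder_hidden = torch.randn(B, 1, h)        # B x 1 x h decoder state

    combined = torch.cat((context, decoder_hidden.squeeze(1)), dim=1)
    print(combined.shape)                        # torch.Size([4, 128]) -> B x 2h, typically fed to the output layer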


The operations depicted in the figure above happen on the decoder side of the seq2seq model. This means that the encoder outputs are already batched (with mini-batch-size samples), and consequently the attn_weights tensor should also be in batch mode.

In essence, the first dimension (the zeroth axis in NumPy terminology) of the attn_weights and encoder_outputs tensors is the mini-batch size, which is why we need torch.bmm to multiply these two tensors batch by batch.
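
A quick way to see that torch.bmm treats the zeroth axis as independent samples (toy tensors, purely for illustration):

    import torch

    attn_weights = torch.randn(4, 1, 10)       # B x 1 x S  (toy values)
    encoder_outputs = torch.randn(4, 10, 64)   # B x S x h  (toy values)

    out = torch.bmm(attn_weights, encoder_outputs)   # B x 1 x h
    # bmm performs one matrix product per sample along the zeroth (batch) axis:
    for i in range(attn_weights.size(0)):
        assert torch.allclose(out[i], attn_weights[i] @ encoder_outputs[i])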


While @wasiahmad is right about the general implementation of seq2seq, in the mentioned tutorial there is no batching (B = 1), so the bmm is just over-engineering and can be safely replaced with matmul with exactly the same model quality and performance. See for yourself: replace this:

        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))
        output = torch.cat((embedded[0], attn_applied[0]), 1)

with this:

        attn_applied = torch.matmul(attn_weights,
                                    encoder_outputs)
        output = torch.cat((embedded[0], attn_applied), 1)

and run the notebook.
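
If you want to convince yourself the two are equivalent when there is no batching, here is a quick check with tensors shaped like the tutorial's (1 x max_length attention weights and max_length x hidden_size encoder outputs; the sizes are made up):

    import torch

    attn_weights = torch.randn(1, 10)           # 1 x max_length
    encoder_outputs = torch.randn(10, 256)      # max_length x hidden_size

    via_bmm = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))[0]
    via_matmul = torch.matmul(attn_weights, encoder_outputs)

    print(torch.allclose(via_bmm, via_matmul))  # True: same result, no fake batch dim needed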


Also, note that while @wasiahmad describes the encoder input as B x S x d, in PyTorch 1.7.0 the GRU, which is the main engine of the encoder, expects an input of shape (seq_len, batch, input_size) by default. If you want to work with @wasiahmad's format, pass the batch_first=True flag.
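
For example (the sizes here are arbitrary):

    import torch
    import torch.nn as nn

    B, S, d, h = 4, 10, 32, 64    # arbitrary sizes

    # Default layout: input is (seq_len, batch, input_size)
    gru_default = nn.GRU(input_size=d, hidden_size=h)
    out_default, _ = gru_default(torch.randn(S, B, d))
    print(out_default.shape)      # torch.Size([10, 4, 64]) -> S x B x h

    # batch_first=True matches the B x S x d convention used above
    gru_bf = nn.GRU(input_size=d, hidden_size=h, batch_first=True)
    out_bf, _ = gru_bf(torch.randn(B, S, d))
    print(out_bf.shape)           # torch.Size([4, 10, 64]) -> B x S x h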