Bidirectional LSTM output question in PyTorch
Yes, when using a BiLSTM the hidden states of the two directions are simply concatenated: the first half is the forward direction, and the second half (after the middle) is the hidden state from feeding in the reversed sequence.
So splitting the vector in the middle works just fine.
Because view/reshape fills the new shape in row-major order (the last dimension varies fastest), splitting the last dimension of size 2 * hidden_size into (2, hidden_size) separates the two directions without any problems.
Here is a small example:
import torch

# so these are your original hidden states for each direction
# in this case hidden size is 5, but this works for any size
direction_one_out = torch.arange(5)
direction_two_out = torch.arange(5).flip(0)
print('Direction one:')
print(direction_one_out)
print('Direction two:')
print(direction_two_out)
# before outputting they will be concatenated
# here I add the sequence length and batch dimensions (both are 1 in this case)
hidden = torch.cat((direction_one_out, direction_two_out), dim=0).view(1, 1, -1)
print('\nYour hidden output:')
print(hidden, hidden.shape)
# trivial case, reshaping for one hidden state
hidden_reshaped = hidden.view(1, 1, 2, -1)
print('\nReshaped:')
print(hidden_reshaped, hidden_reshaped.shape)
# This works just as well for arbitrary sequence lengths, as you can see here
# I've set the sequence length to 5, but this will work for any other value as well
print('\nThis also works for multiple hidden states in a tensor:')
# expand repeats the same hidden state along the sequence dimension (seq_len=5)
multi_hidden = hidden.expand(5, 1, 10)
print(multi_hidden, multi_hidden.shape)
print('Directions can be split up just like this:')
multi_hidden = multi_hidden.view(5, 1, 2, 5)
print(multi_hidden, multi_hidden.shape)
Output:
Direction one:
tensor([0, 1, 2, 3, 4])
Direction two:
tensor([4, 3, 2, 1, 0])
Your hidden output:
tensor([[[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]]]) torch.Size([1, 1, 10])
Reshaped:
tensor([[[[0, 1, 2, 3, 4],
[4, 3, 2, 1, 0]]]]) torch.Size([1, 1, 2, 5])
This also works for multiple hidden states in a tensor:
tensor([[[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],
[[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],
[[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],
[[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],
[[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]]]) torch.Size([5, 1, 10])
Directions can be split up just like this:
tensor([[[[0, 1, 2, 3, 4],
[4, 3, 2, 1, 0]]],
[[[0, 1, 2, 3, 4],
[4, 3, 2, 1, 0]]],
[[[0, 1, 2, 3, 4],
[4, 3, 2, 1, 0]]],
[[[0, 1, 2, 3, 4],
[4, 3, 2, 1, 0]]],
[[[0, 1, 2, 3, 4],
[4, 3, 2, 1, 0]]]]) torch.Size([5, 1, 2, 5])
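The tensors above are hand-made stand-ins for LSTM outputs. As a minimal sketch on an actual torch.nn.LSTM (the input size, sequence length, batch size and random seed are arbitrary choices for illustration), the same view split can be checked against simply slicing the feature dimension in half:

```python
import torch

torch.manual_seed(0)  # only for reproducibility of this sketch

hidden_size = 5
lstm = torch.nn.LSTM(input_size=3, hidden_size=hidden_size, bidirectional=True)

seq_len, batch = 4, 2
x = torch.randn(seq_len, batch, 3)   # (seq_len, batch, input_size)
output, (h_n, c_n) = lstm(x)         # output: (seq_len, batch, 2 * hidden_size)

# Split the last dimension into (num_directions, hidden_size) with view
split = output.view(seq_len, batch, 2, hidden_size)
forward_out = split[:, :, 0]         # forward direction, every time step
backward_out = split[:, :, 1]        # backward direction, every time step

# Same values as cutting the feature dimension in the middle
print(torch.equal(forward_out, output[..., :hidden_size]))
print(torch.equal(backward_out, output[..., hidden_size:]))
print(forward_out.shape, backward_out.shape)
```

Either form works: view gives one tensor with an explicit direction axis, while slicing hands you the two halves directly.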
Hope this helps! :)
I know output[2, 0] will give me a 200-dim vector. Does this 200-dim vector represent the output of the 3rd input in both directions?
The answer is YES.
The output tensor of the LSTM module is the concatenation of the forward LSTM output and the backward LSTM output at the corresponding position in the input sequence.
And the h_n tensor is the hidden state at the last time step, which is the output for the last token in the forward LSTM but for the first token in the backward LSTM (because the backward LSTM reads the sequence in reverse).
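Both statements can also be checked directly. The sketch below assumes hidden_size=100 to reproduce the 200-dim vector from the question (input size, sequence length and seed are arbitrary). It copies each direction's weights out of a bidirectional torch.nn.LSTM into two ordinary one-directional LSTMs; PyTorch stores the backward direction's weights under the parameter names with the _reverse suffix. The backward copy is run on the flipped sequence and its result is flipped back, so index t refers to input token t in both:

```python
import torch

torch.manual_seed(0)  # reproducibility only

# hidden_size=100 gives the 200-dim vector from the question; other sizes are arbitrary
seq_len, batch, input_size, hidden_size = 5, 1, 10, 100
x = torch.randn(seq_len, batch, input_size)

bi = torch.nn.LSTM(input_size, hidden_size, bidirectional=True)
output, (h_n, c_n) = bi(x)

# Two ordinary one-directional LSTMs that reuse the weights of each direction
fwd = torch.nn.LSTM(input_size, hidden_size)
bwd = torch.nn.LSTM(input_size, hidden_size)
with torch.no_grad():
    for name in ("weight_ih_l0", "weight_hh_l0", "bias_ih_l0", "bias_hh_l0"):
        getattr(fwd, name).copy_(getattr(bi, name))
        getattr(bwd, name).copy_(getattr(bi, name + "_reverse"))

fwd_out, _ = fwd(x)            # reads tokens 0, 1, ..., seq_len - 1
bwd_out, _ = bwd(x.flip(0))    # reads tokens seq_len - 1, ..., 0
bwd_out = bwd_out.flip(0)      # re-align: index t now means input token t

# output[2, 0] (3rd token, 1st batch element) = forward half + backward half
expected = torch.cat((fwd_out[2, 0], bwd_out[2, 0]))
print(output[2, 0].shape)      # torch.Size([200])
print(torch.allclose(output[2, 0], expected, atol=1e-5))

# h_n: forward state after the last token, backward state after the first token
print(torch.allclose(h_n[0, 0], fwd_out[-1, 0], atol=1e-5))
print(torch.allclose(h_n[1, 0], bwd_out[0, 0], atol=1e-5))
```

The small atol only absorbs floating-point differences between the bidirectional and one-directional code paths.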
In [1]: import torch
...: lstm = torch.nn.LSTM(input_size=5, hidden_size=3, bidirectional=True)
...: seq_len, batch, input_size, num_directions = 3, 1, 5, 2
...: in_data = torch.randint(10, (seq_len, batch, input_size)).float()
...: output, (h_n, c_n) = lstm(in_data)
...:
In [2]: # output of shape (seq_len, batch, num_directions * hidden_size)
...:
...: print(output)
...:
tensor([[[ 0.0379, 0.0169, 0.2539, 0.2547, 0.0456, -0.1274]],
[[ 0.7753, 0.0862, -0.0001, 0.3897, 0.0688, -0.0002]],
[[ 0.7120, 0.2965, -0.3405, 0.0946, 0.0360, -0.0519]]],
grad_fn=<CatBackward>)
In [3]: # h_n of shape (num_layers * num_directions, batch, hidden_size)
...:
...: print(h_n)
...:
tensor([[[ 0.7120, 0.2965, -0.3405]],
[[ 0.2547, 0.0456, -0.1274]]], grad_fn=<ViewBackward>)
In [4]: output = output.view(seq_len, batch, num_directions, lstm.hidden_size)
...: print(output[-1, 0, 0]) # forward LSTM output of last token
...: print(output[0, 0, 1]) # backward LSTM output of first token
...:
tensor([ 0.7120, 0.2965, -0.3405], grad_fn=<SelectBackward>)
tensor([ 0.2547, 0.0456, -0.1274], grad_fn=<SelectBackward>)
In [5]: h_n = h_n.view(lstm.num_layers, num_directions, batch, lstm.hidden_size)
...: print(h_n[0, 0, 0]) # h_n of forward LSTM
...: print(h_n[0, 1, 0]) # h_n of backward LSTM
...:
tensor([ 0.7120, 0.2965, -0.3405], grad_fn=<SelectBackward>)
tensor([ 0.2547, 0.0456, -0.1274], grad_fn=<SelectBackward>)
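The same view trick extends to stacked LSTMs. A sketch assuming num_layers=2 and batch=2 (both arbitrary choices): output only contains the top layer, while h_n contains every layer, so it is the last layer's entries of the reshaped h_n that match output:

```python
import torch

torch.manual_seed(0)  # reproducibility only

# num_layers=2 and batch=2 are arbitrary choices for this sketch
seq_len, batch, input_size, hidden_size = 3, 2, 5, 3
num_layers, num_directions = 2, 2
lstm = torch.nn.LSTM(input_size, hidden_size, num_layers=num_layers, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)
output, (h_n, c_n) = lstm(x)

# output holds only the top layer: (seq_len, batch, num_directions * hidden_size)
output = output.view(seq_len, batch, num_directions, hidden_size)
# h_n holds every layer: (num_layers * num_directions, batch, hidden_size)
h_n = h_n.view(num_layers, num_directions, batch, hidden_size)

top = h_n[-1]  # last layer, shape (num_directions, batch, hidden_size)
print(torch.allclose(top[0], output[-1, :, 0], atol=1e-6))  # forward: last time step
print(torch.allclose(top[1], output[0, :, 1], atol=1e-6))   # backward: first time step
```

Indexing h_n[0] after the view would instead give the first layer, whose states never appear in output.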