PyTorch reshape tensor dimension
There are multiple ways of reshaping a PyTorch tensor. You can apply these methods on a tensor of any dimensionality.
Let's start with a 2-dimensional 2 x 3
tensor:
import torch
x = torch.Tensor(2, 3)  # uninitialized 2 x 3 float tensor
print(x.shape)
# torch.Size([2, 3])
To make the example more general, let's reshape the 2 x 3 tensor by adding a new dimension at the front and another dimension in the middle, producing a 1 x 2 x 1 x 3 tensor.
Approach 1: add dimension with None
Use NumPy-style indexing with None (a.k.a. np.newaxis) to add dimensions anywhere you want.
print(x.shape)
# torch.Size([2, 3])
y = x[None, :, None, :] # Add new dimensions at positions 0 and 2.
print(y.shape)
# torch.Size([1, 2, 1, 3])
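As a small sketch of the same idea (the variable names here are illustrative, not part of the original answer), None can be combined with an Ellipsis to append a trailing dimension without spelling out every axis. Indexing with None returns a view, so a write through the result should show up in the original tensor:

```python
import torch

x = torch.zeros(2, 3)

# Ellipsis stands for "all existing dimensions"; None appends a new one.
trailing = x[..., None]
print(trailing.shape)  # torch.Size([2, 3, 1])

# Indexing with None returns a view: modifying it also modifies x.
y = x[None, :, None, :]
y[0, 0, 0, 0] = 7.0
print(x[0, 0].item())  # 7.0
```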
Approach 2: unsqueeze
Use torch.Tensor.unsqueeze(i) (a.k.a. torch.unsqueeze(tensor, i) or the in-place version unsqueeze_()) to add a new dimension at position i. The returned tensor shares the same data as the original tensor. In this example, we can use unsqueeze() twice to add the two new dimensions.
print(x.shape)
# torch.Size([2, 3])
# Use unsqueeze twice.
y = x.unsqueeze(0) # Add new dimension at position 0
print(y.shape)
# torch.Size([1, 2, 3])
y = y.unsqueeze(2) # Add new dimension at position 2
print(y.shape)
# torch.Size([1, 2, 1, 3])
In practice with PyTorch, you often need to add an extra dimension for the batch, so you will frequently see unsqueeze(0).
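To illustrate the batch use case, here is a minimal sketch assuming a hypothetical single image stored as channels x height x width (the tensor name and sizes are made up for the example). It also shows that negative positions count from the end and that squeeze() undoes the operation:

```python
import torch

# Hypothetical single image: channels x height x width.
image = torch.zeros(3, 32, 32)

batch = image.unsqueeze(0)     # add a batch dimension at the front
print(batch.shape)             # torch.Size([1, 3, 32, 32])

last = image.unsqueeze(-1)     # negative indices count from the end
print(last.shape)              # torch.Size([3, 32, 32, 1])

restored = batch.squeeze(0)    # squeeze(0) removes the size-1 dimension again
print(restored.shape)          # torch.Size([3, 32, 32])
```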
Approach 3: view
Use torch.Tensor.view(*shape)
to specify all the dimensions. The returned tensor shares the same data as the original tensor.
print(x.shape)
# torch.Size([2, 3])
y = x.view(1, 2, 1, 3)
print(y.shape)
# torch.Size([1, 2, 1, 3])
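A short sketch of two details that are easy to miss with view(): one dimension may be given as -1 so that PyTorch infers it from the element count, and the result shares storage with the original. The values below are arbitrary example data:

```python
import torch

x = torch.arange(6.0).view(2, 3)

# -1 asks PyTorch to infer that dimension from the number of elements.
y = x.view(1, 2, 1, -1)
print(y.shape)     # torch.Size([1, 2, 1, 3])

flat = x.view(-1)
print(flat.shape)  # torch.Size([6])

# The view shares storage with x, so an in-place change to x is visible in y.
x[0, 0] = 100.0
print(y[0, 0, 0, 0].item())  # 100.0
```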
Approach 4: reshape
Use torch.Tensor.reshape(*shape) (a.k.a. torch.reshape(tensor, shapetuple)) to specify all the dimensions. If the original tensor is contiguous or has compatible strides, the returned tensor will be a view of the input (sharing the same data); otherwise it will be a copy. This function is similar to the NumPy reshape() function in that it lets you define all the dimensions and can return either a view or a copy.
print(x.shape)
# torch.Size([2, 3])
y = x.reshape(1, 2, 1, 3)
print(y.shape)
# torch.Size([1, 2, 1, 3])
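The view-or-copy behaviour can be checked roughly by comparing data_ptr(), which returns the address of a tensor's first element. This is only a sketch of one way to inspect it, using a transposed tensor as the non-contiguous case:

```python
import torch

x = torch.arange(6.0).reshape(2, 3)

# Contiguous input: reshape can return a view, so the storage is shared.
a = x.reshape(1, 2, 1, 3)
print(a.data_ptr() == x.data_ptr())  # True

# A transposed tensor is not contiguous; flattening it needs a copy.
t = x.t()
b = t.reshape(6)
print(t.is_contiguous())             # False
print(b.data_ptr() == x.data_ptr())  # False
print(b)                             # tensor([0., 3., 1., 4., 2., 5.])
```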
Furthermore, in the 2019 O'Reilly book Programming PyTorch for Deep Learning, the author writes:
Now you might wonder what the difference is between view()
and reshape()
. The answer is that view()
operates as a view on the original tensor, so if the underlying data is changed, the view will change too (and vice versa). However, view()
can throw errors if the required view is not contiguous; that is, it doesn’t share the same block of memory it would occupy if a new tensor of the required shape was created from scratch. If this happens, you have to call tensor.contiguous()
before you can use view()
. However, reshape()
does all that behind the scenes, so in general, I recommend using reshape()
rather than view()
.
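The situation described in the quote can be reproduced with a transposed tensor. The following is a minimal sketch (the flag name view_failed is only for illustration) showing view() raising an error, and the two workarounds giving the same result:

```python
import torch

x = torch.arange(6.0).view(2, 3)
t = x.t()  # transposing changes the strides without moving any data

try:
    t.view(6)  # the requested shape is incompatible with t's strides
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)  # True

# Either make a contiguous copy first, or let reshape() handle it.
via_contiguous = t.contiguous().view(6)
via_reshape = t.reshape(6)
print(torch.equal(via_contiguous, via_reshape))  # True
```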
Approach 5: resize_
Use the in-place function torch.Tensor.resize_(*sizes)
to modify the original tensor. The documentation states:
WARNING. This is a low-level method. The storage is reinterpreted as C-contiguous, ignoring the current strides (unless the target size equals the current size, in which case the tensor is left unchanged). For most purposes, you will instead want to use view()
, which checks for contiguity, or reshape()
, which copies data if needed. To change the size in-place with custom strides, see set_()
.
print(x.shape)
# torch.Size([2, 3])
x.resize_(1, 2, 1, 3)
print(x.shape)
# torch.Size([1, 2, 1, 3])
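One practical difference from reshape() is that resize_() does not require the element count to stay the same. The sketch below, using arbitrary example values, shows values being preserved for an equal-sized target, the leading elements being kept when shrinking, and reshape() refusing a mismatched shape:

```python
import torch

a = torch.arange(6)
a.resize_(1, 2, 1, 3)        # same number of elements: values are preserved
print(a.flatten().tolist())  # [0, 1, 2, 3, 4, 5]

# Shrinking keeps the leading elements of the underlying storage.
a.resize_(2, 2)
print(a.tolist())            # [[0, 1], [2, 3]]

# reshape() insists on a matching element count and raises otherwise.
try:
    torch.arange(6).reshape(2, 2)
    reshape_failed = False
except RuntimeError:
    reshape_failed = True
print(reshape_failed)        # True
```

Growing a tensor with resize_() leaves the new elements uninitialized, which is one more reason to prefer reshape() or view() when the data is meant to stay the same.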
My observations
If you want to add just one dimension (e.g. to add a 0th dimension for the batch), then use unsqueeze(0)
. If you want to totally change the dimensionality, use reshape()
.
See also:
What's the difference between reshape and view in pytorch?
What is the difference between view() and unsqueeze()?
In PyTorch 0.4, is it recommended to use reshape
than view
when it is possible?
You might use
a.view(1,5)
Out:
1 2 3 4 5
[torch.FloatTensor of size 1x5]
Use torch.unsqueeze(input, dim, out=None)
:
>>> import torch
>>> a = torch.Tensor([1, 2, 3, 4, 5])
>>> a
1
2
3
4
5
[torch.FloatTensor of size 5]
>>> a = a.unsqueeze(0)
>>> a
1 2 3 4 5
[torch.FloatTensor of size 1x5]
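For comparison, here is a sketch of the function form with both possible positions, so the difference between a row (1 x 5) and a column (5 x 1) is visible. Note that PyTorch 0.4 and later print tensors as tensor(...) rather than in the older format shown above:

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

row = torch.unsqueeze(a, 0)   # same as a.unsqueeze(0)
col = torch.unsqueeze(a, 1)   # same as a.unsqueeze(1)

print(row.shape)  # torch.Size([1, 5])
print(col.shape)  # torch.Size([5, 1])
print(row)        # tensor([[1., 2., 3., 4., 5.]])
```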
For in-place modification of the shape of the tensor, you should use
tensor.resize_()
:
In [23]: a = torch.Tensor([1, 2, 3, 4, 5])
In [24]: a.shape
Out[24]: torch.Size([5])
# tensor.resize_(new_shape)
In [25]: a.resize_((1,5))
Out[25]:
1 2 3 4 5
[torch.FloatTensor of size 1x5]
In [26]: a.shape
Out[26]: torch.Size([1, 5])
In PyTorch, if the name of an operation ends with an underscore (like tensor.resize_()), then that operation modifies the original tensor in place.
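A brief sketch of the underscore convention, contrasting unsqueeze() with unsqueeze_() and showing that arithmetic methods such as add_() follow the same pattern (the values are arbitrary):

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])

b = a.unsqueeze(0)       # out-of-place: a keeps its shape
print(a.shape, b.shape)  # torch.Size([3]) torch.Size([1, 3])

a.unsqueeze_(0)          # in-place: a itself now has the extra dimension
print(a.shape)           # torch.Size([1, 3])

a.add_(1)                # the same convention applies to arithmetic
print(a.tolist())        # [[2.0, 3.0, 4.0]]
```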
Also, you can index a torch Tensor with np.newaxis (which is simply an alias for None) to add a dimension. Here is an example, assuming NumPy has been imported as np:
In [34]: list_ = range(5)
In [35]: a = torch.Tensor(list_)
In [36]: a.shape
Out[36]: torch.Size([5])
In [37]: new_a = a[np.newaxis, :]
In [38]: new_a.shape
Out[38]: torch.Size([1, 5])
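Since np.newaxis is just None, the NumPy import is optional for this trick. A self-contained sketch that checks the equivalence and shows both a leading and a trailing new dimension:

```python
import numpy as np
import torch

a = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])

print(np.newaxis is None)  # True

row = a[np.newaxis, :]     # new dimension at the front
col = a[:, None]           # new dimension at the end, without NumPy
print(row.shape)           # torch.Size([1, 5])
print(col.shape)           # torch.Size([5, 1])
```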