How do I display a single image in PyTorch?
matplotlib can plot a tensor directly, without converting it to a NumPy array first, as long as it is a CPU tensor that does not require grad. But PyTorch image tensors are channel-first, (C, H, W), while matplotlib expects channel-last, (H, W, C), so you need to permute the dimensions (not reshape them):
Code:
try:
    from scipy.datasets import face  # SciPy >= 1.10 (downloads the image via pooch)
except ImportError:
    from scipy.misc import face      # older SciPy
import matplotlib.pyplot as plt
import torch
np_image = face()  # uint8 array of shape (H, W, C)
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# permute to channel first (C, H, W); .view() would only reinterpret
# the memory with new sizes and scramble the pixels instead of moving axes
tensor_image = tensor_image.permute(2, 0, 1)
print(type(tensor_image), tensor_image.shape)
# If you try to plot an image with shape (C, H, W)
# you will get a TypeError:
# plt.imshow(tensor_image)
# So we need to permute it back to (H, W, C):
tensor_image = tensor_image.permute(1, 2, 0)
print(type(tensor_image), tensor_image.shape)
plt.imshow(tensor_image)
plt.show()
Output:
<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
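.view() and .permute() can both produce a tensor of shape (C, H, W), but only .permute() actually moves the axes; .view() reads the same row-major memory in order and chops it into the new sizes. A minimal sketch on a toy 2 x 3 x 3 "image" (the values are made up so that mix-ups are visible) shows the difference:

```python
import torch

# Toy (H, W, C) "image" with H=2, W=3, C=3; value at [h, w, c] is h*9 + w*3 + c
hwc = torch.arange(2 * 3 * 3).reshape(2, 3, 3)

# Correct: move the channel axis to the front
chw_permute = hwc.permute(2, 0, 1)

# Incorrect: same shape, but values are just re-chunked in memory order
chw_view = hwc.view(3, 2, 3)

# Channel 0 should contain the first channel of every pixel, hwc[:, :, 0]
print(torch.equal(chw_permute[0], hwc[:, :, 0]))  # True
print(torch.equal(chw_view[0], hwc[:, :, 0]))     # False
print(chw_view[0].tolist())                       # [[0, 1, 2], [3, 4, 5]]
```

Both calls return views of the same storage, so neither copies the pixel data; the difference is only in how the indices map onto that storage.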
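Given the image is loaded as described and stored in the variable image as a (C, H, W) tensor:
from torchvision import transforms
import matplotlib.pyplot as plt
plt.imshow(transforms.ToPILImage()(image), interpolation="bicubic")
plt.show()
# transforms.ToPILImage()(image).show()  # alternatively, open it in the system image viewer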
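For a float tensor, ToPILImage roughly multiplies by 255, casts to 8-bit and transposes to (H, W, C) before building the PIL image, so it expects values in [0, 1]. If you would rather avoid the torchvision dependency, a torch-only sketch of a similar conversion might look like this; the helper name tensor_to_uint8_hwc is hypothetical, and the clamp is an extra safeguard that ToPILImage itself does not apply:

```python
import numpy as np
import torch

def tensor_to_uint8_hwc(img: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: (C, H, W) tensor -> (H, W, C) uint8 NumPy array."""
    img = img.detach().cpu()
    if img.is_floating_point():
        # assumes float images live in [0, 1]; clamp guards against small overshoots
        img = (img.clamp(0, 1) * 255).to(torch.uint8)
    return img.permute(1, 2, 0).numpy()

image = torch.rand(3, 4, 5)  # stand-in for a real (C, H, W) image
arr = tensor_to_uint8_hwc(image)
print(arr.shape, arr.dtype)  # (4, 5, 3) uint8
```

The resulting array can be passed straight to plt.imshow(arr).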
Or, as Soumith suggested:
import numpy as np
import matplotlib.pyplot as plt
def show(img):  # img: a (C, H, W) CPU tensor
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
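img.numpy() only works on a CPU tensor that does not require grad; for a CUDA tensor or one that is part of an autograd graph it raises a RuntimeError. A slightly more defensive variant of the same idea is sketched below; the name show_safe is ours, the Agg backend is only there so the sketch runs without a display, and the function returns the plotted array purely so it can be inspected:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop it for interactive use
import matplotlib.pyplot as plt
import numpy as np
import torch

def show_safe(img: torch.Tensor) -> np.ndarray:
    """Like show(), but also accepts CUDA tensors and tensors that require grad."""
    npimg = img.detach().cpu().numpy()      # (C, H, W)
    hwc = np.transpose(npimg, (1, 2, 0))    # (H, W, C) for matplotlib
    plt.imshow(hwc, interpolation="nearest")
    return hwc

img = torch.rand(3, 8, 8, requires_grad=True)  # img.numpy() would raise here
hwc = show_safe(img)
print(hwc.shape)  # (8, 8, 3)
```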
Given a tensor representing the image, use .permute() to put the channels as the last dimension:
plt.imshow(tensor_image.permute(1, 2, 0))
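permute(1, 2, 0) assumes a single 3-D (C, H, W) tensor. A batch is often (N, C, H, W) and a single-channel image is (1, H, W). A hypothetical helper that picks one image from a batch and drops a lone channel might look like this; the batch layout and the grayscale colormap are assumptions about your data:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import torch

def to_displayable(t: torch.Tensor, index: int = 0) -> torch.Tensor:
    """Hypothetical helper: return an (H, W, C) or (H, W) tensor for plt.imshow."""
    t = t.detach().cpu()
    if t.dim() == 4:          # (N, C, H, W): pick one image from the batch
        t = t[index]
    if t.dim() == 3:          # (C, H, W) -> (H, W, C)
        t = t.permute(1, 2, 0)
        if t.shape[-1] == 1:  # single channel -> plain 2-D image
            t = t.squeeze(-1)
    return t

batch = torch.rand(4, 3, 32, 32)  # stand-in for a batch of RGB images
gray = torch.rand(1, 28, 28)      # stand-in for a grayscale image

rgb_img = to_displayable(batch)
gray_img = to_displayable(gray)
plt.imshow(rgb_img)
plt.figure()
plt.imshow(gray_img, cmap="gray")  # 2-D data uses a colormap; "gray" makes it look grayscale
print(tuple(rgb_img.shape), tuple(gray_img.shape))  # (32, 32, 3) (28, 28)
```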
Note: permute() does not copy the data; it returns a view with reordered strides. from_numpy() does not copy either: the resulting tensor shares memory with the NumPy array.
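A quick check of that note on a tiny stand-in array: a write through the NumPy array shows up in the permuted tensor, and both tensors point at the same buffer. One consequence is that the permuted tensor is no longer contiguous, so calling .view() to flatten it fails, while .reshape() falls back to making a copy:

```python
import numpy as np
import torch

np_image = np.zeros((2, 2, 3), dtype=np.uint8)  # tiny (H, W, C) stand-in image
tensor_image = torch.from_numpy(np_image)       # shares np_image's buffer
chw = tensor_image.permute(2, 0, 1)             # view: same buffer, new strides

np_image[0, 0, 1] = 255                          # write through the NumPy array
print(chw[1, 0, 0].item())                       # 255
print(chw.data_ptr() == tensor_image.data_ptr()) # True: no copy was made
print(chw.is_contiguous())                       # False

try:
    chw.view(-1)          # view() needs strides compatible with the new shape
    view_failed = False
except RuntimeError:
    view_failed = True
flat = chw.reshape(-1)    # reshape() copies here because a view is impossible
print(view_failed, flat.shape)  # True torch.Size([12])
```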