How to convert RGB images to grayscale in PyTorch dataloader?

When you use the ImageFolder class without a custom loader, PyTorch loads each image with PIL and converts it to RGB. This is the default loader when the torchvision image backend is PIL:

from PIL import Image

def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

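You can use torchvision's Grayscale transform (torchvision.transforms.Grayscale). With num_output_channels=1 it converts a 3-channel RGB image into a 1-channel grayscale image; with num_output_channels=3 it returns a grayscale image whose three channels are identical. Find out more at https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Grayscale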

Here is sample code using CIFAR-10:

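import torch.utils.data as data
import torchvision as tv

dataDir         = 'D:\\general\\ML_DL\\datasets\\CIFAR'
# Grayscale(num_output_channels=1) yields a 1-channel image, so Normalize needs
# one mean and one std value; three values would raise a shape error.
trainTransform  = tv.transforms.Compose([tv.transforms.Grayscale(num_output_channels=1),
                                         tv.transforms.ToTensor(),
                                         tv.transforms.Normalize((0.5,), (0.5,))])
# download=False assumes CIFAR-10 is already present in dataDir.
trainSet        = tv.datasets.CIFAR10(dataDir, train=True, download=False, transform=trainTransform)
dataloader      = data.DataLoader(trainSet, batch_size=1, shuffle=False, num_workers=0)
# Use the built-in next(); the .next() method is not available on current PyTorch iterators.
images, labels  = next(iter(dataloader))
print(images.size())  # torch.Size([1, 1, 32, 32])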
