PyTorch code example: a custom "concat" dataset

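Example: a "concat" dataset that pairs the i-th samples of several datasets (unlike `torch.utils.data.ConcatDataset`, which chains them end to end)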

class ConcatDataset(torch.utils.data.Dataset):
    """Pair up samples from several datasets index by index ("zip" semantics).

    Despite the name, this does NOT chain datasets end to end the way
    ``torch.utils.data.ConcatDataset`` does. Item ``i`` is the tuple
    ``(datasets[0][i], datasets[1][i], ...)``. The length is that of the
    shortest dataset, so trailing items of longer datasets are never returned.

    Args:
        *datasets: Indexable, sized datasets to pair up.
    """

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        """Return the ``i``-th sample of every dataset as a tuple.

        Negative indices count back from ``len(self)`` (the shortest
        dataset's length), so every dataset is read at the same position.

        Raises:
            IndexError: If ``i`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        # Normalize against the shared length. Passing a raw negative index to
        # each dataset would misalign them whenever their lengths differ.
        idx = i + n if i < 0 else i
        if not 0 <= idx < n:
            raise IndexError(f"index {i} out of range for dataset of length {n}")
        return tuple(d[idx] for d in self.datasets)

    def __len__(self):
        """Return the length of the shortest dataset (0 if none were given)."""
        # default=0 avoids min() raising ValueError on an empty sequence.
        return min((len(d) for d in self.datasets), default=0)

# Build the training loader over index-aligned (A, B) sample pairs.
# shuffle=True permutes the shared index, so A[k] is always paired with B[k].
# Only the order in which the pairs are visited changes between epochs.
# NOTE(review): assuming `datasets` is torchvision.datasets, ImageFolder with
# no `transform` returns PIL images, which the default collate function cannot
# stack into a batch. A transform such as ToTensor() (plus a resize/crop to a
# common size) presumably belongs here; confirm.
train_loader = torch.utils.data.DataLoader(
    ConcatDataset(  # zips the two folders rather than chaining them
        datasets.ImageFolder(traindir_A),
        datasets.ImageFolder(traindir_B),
    ),
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    pin_memory=True,
)

# Each batch holds one collated group per wrapped dataset, in the order the
# datasets were passed to ConcatDataset (A first, then B). Assuming each
# dataset yields (input, target) pairs, as ImageFolder does, the default
# collate gives [[inputs_A, targets_A], [inputs_B, targets_B]]. Unpack
# accordingly, rather than treating batch A as "input" and batch B as "target".
for i, ((input_a, target_a), (input_b, target_b)) in enumerate(train_loader):
    ...