How to iterate over two DataLoaders simultaneously using PyTorch?

If you want to iterate over two datasets simultaneously, there is no need to define your own dataset class: just use TensorDataset, as below. Note that TensorDataset expects tensors that share the same size along the first dimension, so this works when dataset1 and dataset2 are tensors of equal length:

import torch
from torch.utils.data import DataLoader

# dataset1 and dataset2 must be tensors with the same first dimension
dataset = torch.utils.data.TensorDataset(dataset1, dataset2)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
for index, (xb1, xb2) in enumerate(dataloader):
    ...

If you want labels, or want to iterate over more than two datasets, just pass them as additional arguments to TensorDataset after dataset2.
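For instance, a minimal sketch with a labels tensor added as a third argument (labels is an assumed name here; any tensor with the same first dimension works):

# labels is assumed to be a tensor of the same length as dataset1/dataset2
dataset = torch.utils.data.TensorDataset(dataset1, dataset2, labels)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
for xb1, xb2, yb in dataloader:
    ...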


Further to what is already mentioned, cycle() and zip() might create a memory leak - especially when using image datasets!
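The snippets below use a DummyDataset that is never defined in the thread; a minimal sketch, assuming it simply wraps a range of integers, could be:

from torch.utils.data import Dataset

class DummyDataset(Dataset):
    # stand-in dataset yielding the integers in [start, end)
    def __init__(self, start, end):
        self.data = list(range(start, end))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

To solve that, instead of iterating like this: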

from itertools import cycle

dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10

for epoch in range(num_epochs):
    # cycle() saves a copy of every batch it consumes so it can replay
    # them, which is what leaks memory with large (e.g. image) batches
    for i, (data1, data2) in enumerate(zip(cycle(dataloaders1), dataloaders2)):
        do_cool_things()

you could use:

dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10

for epoch in range(num_epochs):
    dataloader_iterator = iter(dataloaders1)

    for i, data1 in enumerate(dataloaders2):
        try:
            data2 = next(dataloader_iterator)
        except StopIteration:
            # restart the shorter dataloader once it is exhausted
            dataloader_iterator = iter(dataloaders1)
            data2 = next(dataloader_iterator)

        do_cool_things()

Bear in mind that if you use labels as well, you should replace data1 in this example with (inputs1, targets1) and data2 with (inputs2, targets2), as @Sajad Norouzi said.
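For concreteness, a sketch of that unpacking (this assumes the underlying datasets return (input, target) pairs, unlike the DummyDataset sketch above):

for epoch in range(num_epochs):
    dataloader_iterator = iter(dataloaders1)

    # each batch is assumed to be an (inputs, targets) pair
    for i, (inputs1, targets1) in enumerate(dataloaders2):
        try:
            inputs2, targets2 = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(dataloaders1)
            inputs2, targets2 = next(dataloader_iterator)

        do_cool_things()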

KUDOS to this one: https://github.com/pytorch/pytorch/issues/1917#issuecomment-433698337