IterableDataset shuffle code example
Example: shuffling an IterableDataset
class ShuffleDataset(torch.utils.data.IterableDataset):
    """Approximately shuffle an iterable dataset through a bounded buffer.

    Items are pulled from the wrapped dataset into a buffer holding at most
    ``buffer_size`` elements.  Once the buffer is full, every new item
    replaces a randomly chosen buffered item, which is yielded.  When the
    source is exhausted, the remaining buffered items are yielded in random
    order.  Every source item is yielded exactly once.

    This class performs no per-worker sharding of its own; it simply
    iterates ``dataset`` each time ``__iter__`` is called.
    """

    def __init__(self, dataset, buffer_size):
        """Store the wrapped dataset and the shuffle buffer capacity.

        Args:
            dataset: Any iterable to draw items from.
            buffer_size: Maximum number of items held in the shuffle buffer.
                Must be at least 1.

        Raises:
            ValueError: If ``buffer_size`` is smaller than 1.
        """
        super().__init__()
        # Fail at construction time instead of deep inside iteration, where
        # an empty buffer would make the random index draw raise.
        if buffer_size < 1:
            raise ValueError(
                f"buffer_size must be at least 1, got {buffer_size!r}"
            )
        self.dataset = dataset
        self.buffer_size = buffer_size

    def __iter__(self):
        """Yield every item of the wrapped dataset in buffer-shuffled order.

        Exceptions raised by the wrapped dataset propagate to the caller;
        they are not mistaken for the end of the data.  Iterating does not
        modify ``self.buffer_size``.
        """
        capacity = self.buffer_size
        buffer = []
        for item in self.dataset:
            if len(buffer) < capacity:
                # Still filling the buffer: nothing to emit yet.
                buffer.append(item)
                continue
            # Buffer is full: emit a random resident and let the new item
            # take its slot.
            evict_idx = random.randrange(len(buffer))
            yield buffer[evict_idx]
            buffer[evict_idx] = item
        # Shuffle the leftovers so a source that fits entirely in the buffer
        # is still returned in random order rather than in a fixed order.
        random.shuffle(buffer)
        yield from buffer