In PyTorch, DataLoader splits a dataset into batches of a fixed size, with additional options such as shuffling, which one can then loop over.
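For reference, here is the fixed-size behaviour described above as a minimal sketch. It reuses the 1000×80 random tensor from the question and assumes the tensor is wrapped in a TensorDataset (an assumption, not something the question specifies):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 1000 samples with 80 features each, as in the question
data = torch.randn(1000, 80)
dataset = TensorDataset(data)

# Fixed batch size: every batch has 50 rows because 1000 is divisible by 50;
# shuffle=True draws a new random order every time the loader is iterated
loader = DataLoader(dataset, batch_size=50, shuffle=True)

# TensorDataset yields 1-tuples, so each batch is a 1-element list of tensors
batch_sizes = [xb.shape[0] for (xb,) in loader]
print(len(batch_sizes), set(batch_sizes))  # 20 {50}
```

Every batch has the same size, which is exactly the limitation the question is about.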

But if I need the batch size to increase, for example the first 10 batches of size 50, the next 5 batches of size 100, and so on, what's the best way to do it?

I tried splitting the tensor and then concatenating the pieces:

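import torch
# 10x50 + 5x100 = 1000 rows
originalTensor = torch.randn(1000, 80)
split1 = torch.split(originalTensor, 500, dim=0)  # two halves of 500 rows each
split2 = torch.split(split1[0], 50, dim=0)        # first half: 10 batches of 50
split3 = torch.split(split1[1], 100, dim=0)       # second half: 5 batches of 100
batches = split2 + split3                         # tuple concatenation: 15 batches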
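As a side note, `torch.split` also accepts a list of section sizes instead of a single int, so the same 15 chunks can be produced in one call without splitting in two stages. A minimal sketch (the names `original_tensor` and `sizes` are illustrative):

```python
import torch

# 10 batches of 50 rows followed by 5 batches of 100 rows: 10*50 + 5*100 = 1000
original_tensor = torch.randn(1000, 80)
sizes = [50] * 10 + [100] * 5

# With a list, torch.split cuts consecutive sections of exactly these lengths;
# the sizes must sum to the tensor's length along `dim`. The result is a tuple of views.
batches = torch.split(original_tensor, sizes, dim=0)

print(len(batches))                          # 15
print([b.shape[0] for b in batches][9:11])   # [50, 100]
```

This only slices the rows in their original order; it does not shuffle.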

After that, is there a way to pass the concatenated result to DataLoader, or some other way to turn it directly into a generator (even if that means losing shuffling and other functionality)?
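The "plain generator" option from the question can be sketched without DataLoader by slicing a random permutation of row indices according to a list of sizes. This keeps per-epoch shuffling, because a fresh permutation is drawn each time the generator is created. `varying_size_batches` is a hypothetical helper, not a PyTorch API:

```python
import torch


def varying_size_batches(data, sizes, shuffle=True):
    """Yield mini-batches of `data` whose lengths follow `sizes`.

    Hypothetical helper: a new permutation is drawn on every call, so calling
    it once per epoch reshuffles. Rows beyond sum(sizes) are not yielded.
    """
    n = data.shape[0]
    # random order of row indices, or the natural order without shuffling
    order = torch.randperm(n) if shuffle else torch.arange(n)
    start = 0
    for size in sizes:
        if start >= n:
            break  # ran out of rows before the schedule ended
        idx = order[start:start + size]
        yield data[idx]  # advanced indexing copies the selected rows
        start += size


data = torch.randn(1000, 80)
sizes = [50] * 10 + [100] * 5
for epoch in range(2):
    for batch in varying_size_batches(data, sizes):
        pass  # training step goes here
```

What this loses compared to DataLoader is mainly multi-process loading, pinned memory and custom collation.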



I think you can do that by simply providing a non-default batch_sampler to your DataLoader.
For instance:

from torch.utils.data import Sampler


class VaryingSizeBatchSampler(Sampler):
    r"""Wraps another sampler to yield a varying-size mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size_fn (function): Maps the index of the current mini-batch
            (0, 1, 2, ...) to its size.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than the size returned by ``batch_size_fn``.
    """

    def __init__(self, sampler, batch_size_fn, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        self.sampler = sampler
        self.batch_size_fn = batch_size_fn
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        batch_counter = 0  # restart the size schedule at the start of every epoch
        cur_batch_size = self.batch_size_fn(batch_counter)  # size of the first batch
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == cur_batch_size:
                yield batch
                batch_counter += 1
                cur_batch_size = self.batch_size_fn(batch_counter)  # size of the next batch
                batch = []
        # leftover indices form a final, smaller batch unless drop_last is set
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        raise NotImplementedError('You need to implement it yourself!')
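To reproduce the schedule from the question, `batch_size_fn` only has to map a batch index to a size. A minimal sketch with a hypothetical helper `step_schedule`; since the question's "and so on" is not specified, repeating the last stage's size once all stages are used up is an assumption:

```python
def step_schedule(batch_idx, stages=((10, 50), (5, 100))):
    """Map a mini-batch index (0, 1, 2, ...) to its size.

    Hypothetical helper: `stages` is a sequence of (num_batches, batch_size)
    pairs; after the last stage, its size is reused (an assumption).
    """
    for num_batches, size in stages:
        if batch_idx < num_batches:
            return size
        batch_idx -= num_batches  # move on to the next stage
    return stages[-1][1]


print([step_schedule(i) for i in (0, 9, 10, 14, 15)])  # [50, 50, 100, 100, 100]
```

It can then be passed to the sampler, e.g. `DataLoader(dataset, batch_sampler=VaryingSizeBatchSampler(RandomSampler(dataset), step_schedule, drop_last=False))`. `RandomSampler` keeps the shuffling, and because `batch_sampler` is given, `batch_size`, `shuffle`, `sampler` and `drop_last` must stay at their defaults on the DataLoader.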
    That's really helpful! Thanks for providing substantial details.
    – santoku
    Commented Sep 22, 2019 at 16:53
