collate_fn
collate_fn is a user-supplied function that the PyTorch DataLoader uses to merge a list of samples into a single batch. It is passed to the DataLoader through the collate_fn parameter. If collate_fn is not provided, PyTorch falls back to a default collate function (default_collate), which stacks tensors along a new leading batch dimension and converts NumPy arrays and Python numbers into tensors. The collate function receives a list of samples, each produced by one call to the dataset’s __getitem__ method; the list normally has batch_size elements, although the final batch may be shorter.
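As a rough illustration of the two paths described above, the sketch below calls default_collate directly on a small list of samples and then passes a custom function through the collate_fn parameter. The toy samples and the helper name keep_as_lists are assumptions made for this example, and importing default_collate from torch.utils.data assumes a reasonably recent PyTorch release.

```python
import torch
from torch.utils.data import DataLoader, default_collate

# Four toy samples, each a (feature tensor, integer label) pair.
samples = [(torch.full((3,), float(i)), i) for i in range(4)]

# Default behavior: tensors are stacked and integer labels become one tensor.
features, labels = default_collate(samples)
print(features.shape)  # torch.Size([4, 3])
print(labels)          # tensor([0, 1, 2, 3])


def keep_as_lists(batch):
    # Hypothetical custom collate: return plain lists instead of stacked tensors.
    return [f for f, _ in batch], [y for _, y in batch]


# A plain list works as a map-style dataset because it has __getitem__ and __len__.
loader = DataLoader(samples, batch_size=2, collate_fn=keep_as_lists)
first_features, first_labels = next(iter(loader))
```

With the custom function in place, each batch is whatever the function returns, so downstream code has to expect lists rather than tensors here.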
The primary purpose of collate_fn is to control how individual samples are combined into a batched structure. Common uses include:
- Stacking fixed-size tensors into a batch tensor (the default behavior for uniform data).
- Creating a dictionary or tuple batch from per-sample fields.
- Padding variable-length sequences to a common length within the batch.
- Modifying targets, such as converting a list of labels to a tensor.
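The padding case in the list above could be sketched as follows. This is a minimal example, assuming each sample is a (sequence, label) pair with a 1-D tensor as the sequence; the name pad_collate and the choice of 0 as the padding value are illustrative, not prescribed by PyTorch.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch):
    # batch: list of (sequence, label) pairs; sequences may differ in length.
    sequences = [seq for seq, _ in batch]
    labels = torch.tensor([label for _, label in batch])
    # Keep the original lengths so padded positions can be ignored later.
    lengths = torch.tensor([len(seq) for seq in sequences])
    # Pad every sequence to the longest one in this batch only.
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    return padded, lengths, labels


batch = [
    (torch.tensor([1, 2, 3]), 0),
    (torch.tensor([4]), 1),
    (torch.tensor([5, 6]), 0),
]
padded, lengths, labels = pad_collate(batch)
print(padded.shape)  # torch.Size([3, 3])
```

Padding to the longest sequence in each batch, rather than to a global maximum, keeps batches as small as the data allows.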
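For example, a collate function for samples that are dictionaries with 'input' and 'target' keys:

    import torch

    def collate_fn(batch):
        inputs = [item['input'] for item in batch]
        targets = [item['target'] for item in batch]
        return {'input': torch.stack(inputs), 'target': torch.tensor(targets)}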
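To put the dictionary-based code above in context, the following sketch wires such a function into a DataLoader. ToyDictDataset and dict_collate are hypothetical names introduced only for this illustration; the dataset simply produces five small samples.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDictDataset(Dataset):
    """Hypothetical dataset whose samples are dicts with 'input' and 'target'."""

    def __len__(self):
        return 5

    def __getitem__(self, idx):
        return {'input': torch.ones(2) * idx, 'target': idx % 2}


def dict_collate(batch):
    # batch is a list of the dicts returned by __getitem__.
    inputs = [item['input'] for item in batch]
    targets = [item['target'] for item in batch]
    return {'input': torch.stack(inputs), 'target': torch.tensor(targets)}


loader = DataLoader(ToyDictDataset(), batch_size=2, collate_fn=dict_collate)
shapes = [tuple(b['input'].shape) for b in loader]
print(shapes)  # [(2, 2), (2, 2), (1, 2)]
```

The last batch holds a single sample because five samples do not divide evenly by a batch size of two, so a collate function should not assume a fixed list length.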
- When num_workers is greater than zero, the collate function runs inside the worker processes, so it should be pure (free of side effects and shared mutable state).
- Ensure consistent output shapes and types across the batch.
- A custom collate_fn is often necessary for variable-length data, complex structures, or specialized padding schemes.
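One way to follow these considerations is to keep the collate function at module level and stateless, and to bind any configuration with functools.partial. The sketch below assumes batches of 1-D integer tensors; pad_and_mask and pad_value are hypothetical names chosen for this example. A module-level function bound with partial can generally be pickled, unlike a lambda, which is relevant on platforms where worker processes are started by spawning.

```python
from functools import partial

import torch


def pad_and_mask(batch, pad_value):
    # Pure function: the result depends only on the arguments,
    # and nothing outside the function is modified.
    max_len = max(len(seq) for seq in batch)
    padded = torch.full((len(batch), max_len), pad_value, dtype=batch[0].dtype)
    mask = torch.zeros(len(batch), max_len, dtype=torch.bool)
    for row, seq in enumerate(batch):
        padded[row, :len(seq)] = seq
        mask[row, :len(seq)] = True  # True marks real (non-padding) positions
    return padded, mask


# Bind the padding value once; the result can be passed as collate_fn.
collate = partial(pad_and_mask, pad_value=-1)
padded, mask = collate([torch.tensor([7, 8, 9]), torch.tensor([1])])
```

Returning a mask alongside the padded tensor gives every batch the same output structure, even though the padded length varies from batch to batch.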