The collate function merges a batch of data examples into a format suitable for further processing: "token" entries are collected into a list as-is, any key containing "point" has its arrays padded with a leading batch-index column and concatenated, nested lists are stacked position-wise, and all other values are stacked into a float tensor.
Exported source
from collections import defaultdict

import numpy as np
import torch


def collate(batch_list):
    """This function is designed to merge a batch of data examples into a format suitable for further processing."""
    # Group the values of every example under their shared key.
    example_merged = defaultdict(list)
    for example in batch_list:
        for k, v in example.items():
            example_merged[k].append(v)
    ret = {}
    for key, elems in example_merged.items():
        if key == "token":
            # Keep tokens as a plain list, one entry per example.
            ret[key] = elems
        elif 'point' in key:
            # Prepend a batch-index column to each point array, then concatenate.
            coors = []
            for i, coor in enumerate(elems):
                coor_pad = np.pad(
                    coor, ((0, 0), (1, 0)), mode="constant", constant_values=i
                )
                coors.append(coor_pad)
            ret[key] = torch.tensor(np.concatenate(coors, axis=0))
        elif isinstance(elems[0], list):
            # Stack nested lists position-wise across the batch.
            ret[key] = defaultdict(list)
            res = []
            for elem in elems:
                for idx, ele in enumerate(elem):
                    ret[key][str(idx)].append(torch.tensor(ele))
            for kk, vv in ret[key].items():
                res.append(torch.stack(vv))
            ret[key] = res
        else:
            # Stack everything else along a new batch dimension as floats.
            ret[key] = torch.tensor(np.stack(elems, axis=0)).float()
    return ret
# Sample batch list of examples
batch_list = [
    {
        "token": [1, 2, 3],
        "point1": np.array([[1.0, 2.0], [3.0, 4.0]]),
        "point2": np.array([[5.0, 6.0]]),
        "nested_list": [[1, 2], [3, 4]],
        "value": np.array([1.0, 2.0]),
    },
    {
        "token": [4, 5, 6],
        "point1": np.array([[7.0, 8.0]]),
        "point2": np.array([[9.0, 10.0], [11.0, 12.0]]),
        "nested_list": [[5, 6], [7, 8]],
        "value": np.array([3.0, 4.0]),
    },
]

# Using the collate function
collated_batch = collate(batch_list)

# Display the collated result
for key, value in collated_batch.items():
    print(f"{key}: {value}")
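Tracing collate over this sample batch gives roughly the following result; the summary below is a sketch of the expected structure rather than verbatim program output:

# Expected structure of collated_batch (sketch, not verbatim output):
# token:       [[1, 2, 3], [4, 5, 6]]                        -- token lists kept as-is
# point1:      tensor of shape (3, 3); first column is the batch index (0 or 1)
# point2:      tensor of shape (3, 3); first column is the batch index (0 or 1)
# nested_list: [tensor([[1, 2], [5, 6]]), tensor([[3, 4], [7, 8]])]
# value:       tensor([[1., 2.], [3., 4.]])                   -- stacked as floats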
The build_dataloader function is a utility for creating a PyTorch DataLoader with added support for distributed training. Here’s a breakdown of what the function does:
Distributed Training Support:
The function first checks whether distributed training is initialized using dist.is_initialized(); if it is, it retrieves the rank and world size of the current process with dist.get_rank() and dist.get_world_size() (a sketch of how distributed mode is typically initialized appears after this overview).
It then creates a DistributedSampler, which ensures that each process gets a different subset of the dataset. This sampler is used to handle data loading in a distributed manner.
If distributed training is not initialized, it defaults to using no sampler.
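For context, here is a minimal sketch of how a training script might enable distributed mode before calling build_dataloader. The torchrun launcher and the backend choice are assumptions about how the job is run, not part of the documented function:

import torch
import torch.distributed as dist

# A minimal sketch, assuming the script is launched with
# `torchrun --nproc_per_node=N train.py`, which sets the RANK / WORLD_SIZE /
# MASTER_ADDR environment variables that init_process_group reads by default.
if not dist.is_initialized():
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)

# dist.is_initialized() is now True, so build_dataloader will create a
# DistributedSampler instead of relying on plain shuffling:
# loader = build_dataloader(dataset, batch_size=4, shuffle=True)

When the sampler is in use, PyTorch also recommends calling loader.sampler.set_epoch(epoch) at the start of each epoch so that the shuffling order changes across epochs.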
Creating the DataLoader:
The function creates a DataLoader using the provided dataset, batch size, number of workers, shuffle, and pin memory options.
It uses the sampler if one was created; otherwise, it shuffles the data if shuffle is set to True. The two are kept mutually exclusive because PyTorch's DataLoader raises an error when both a sampler and shuffle=True are supplied.
Parameters Abstracted from a Direct PyTorch Implementation
The function abstracts away the following details from a direct PyTorch DataLoader implementation:
- DistributedSampler: Automatically handles creating and using a DistributedSampler when distributed training is initialized.
- Sampler Management: Abstracts the logic for deciding when to use a sampler and whether to shuffle the data.
- Collate Function: Assumes a specific collate_fn (collate) is used, simplifying the DataLoader creation by not requiring the user to specify it.
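To make the comparison concrete, here is a sketch of roughly what the same configuration looks like when written against the PyTorch API directly; dataset and the literal values are illustrative:

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

# Manual setup that build_dataloader wraps (illustrative sketch):
if dist.is_initialized():
    sampler = DistributedSampler(dataset,
                                 num_replicas=dist.get_world_size(),
                                 rank=dist.get_rank(),
                                 shuffle=True)
else:
    sampler = None

loader = DataLoader(dataset,
                    batch_size=4,
                    sampler=sampler,              # per-process shard when distributed
                    shuffle=(sampler is None),    # shuffle only when no sampler is set
                    num_workers=8,
                    collate_fn=collate,           # must be wired up by hand
                    pin_memory=False)

# versus the wrapped call:
# loader = build_dataloader(dataset, batch_size=4, num_workers=8, shuffle=True)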
Limitations
Fixed Collate Function: The function uses a predefined collate_fn. If a different collate function is needed, the user must manually modify the function.
Limited Customization: The function only exposes a subset of possible DataLoader parameters (batch size, number of workers, shuffle, and pin memory). For more advanced customization, the user might need to modify the function or fall back to creating a DataLoader directly. PyTorch's DataLoader supports advanced options such as persistent_workers, worker_init_fn, and timeout; because the function does not expose them, its flexibility for more complex use cases is limited.
Distributed Training Dependency: The function relies on PyTorch’s distributed package (torch.distributed) to determine if distributed training is initialized. If used in a non-distributed context without the appropriate setup, the distributed checks and sampler creation might add unnecessary complexity.
Further Enhancements
Some potential enhancements to the function include:
Custom Collate Function: Allow users to specify a custom collate_fn for more flexibility in data processing (see the combined sketch after this list).
Expose Advanced DataLoader Parameters: Provide additional parameters for more advanced DataLoader configurations using **kwargs.
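A minimal sketch of what both enhancements could look like, keeping the existing collate as the default and forwarding extra keyword arguments to DataLoader. The name build_dataloader_v2 and the pass-through style are illustrative suggestions, not part of the exported module:

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

def build_dataloader_v2(dataset,
                        batch_size=4,
                        num_workers=8,
                        shuffle: bool = False,
                        pin_memory=False,
                        collate_fn=collate,    # enhancement 1: user-supplied collate function
                        **dataloader_kwargs):  # enhancement 2: e.g. persistent_workers, timeout
    """Sketch of build_dataloader with a pluggable collate_fn and pass-through kwargs."""
    if dist.is_initialized():
        sampler = DistributedSampler(dataset,
                                     num_replicas=dist.get_world_size(),
                                     rank=dist.get_rank(),
                                     shuffle=shuffle)
    else:
        sampler = None
    return DataLoader(dataset,
                      batch_size=batch_size,
                      sampler=sampler,
                      shuffle=(sampler is None and shuffle),
                      num_workers=num_workers,
                      collate_fn=collate_fn,
                      pin_memory=pin_memory,
                      **dataloader_kwargs)

# Example: build_dataloader_v2(dataset, shuffle=True, persistent_workers=True)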
This function is designed to build a DataLoader object for a given dataset with optional distributed training support.
|             | Type | Default | Details           |
|-------------|------|---------|-------------------|
| dataset     |      |         | Dataset object    |
| batch_size  | int  | 4       | Batch size        |
| num_workers | int  | 8       | Number of workers |
| shuffle     | bool | False   | Shuffle the data  |
| pin_memory  | bool | False   | Pin memory        |
Exported source
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler


def build_dataloader(dataset,                 # Dataset object
                     batch_size=4,            # Batch size
                     num_workers=8,           # Number of workers
                     shuffle: bool = False,   # Shuffle the data
                     pin_memory=False         # Pin memory
                     ):                       # A PyTorch DataLoader instance with the specified configuration.
    """This function is designed to build a DataLoader object for a given dataset with optional distributed training support."""
    if dist.is_initialized():
        # Each process samples a distinct shard of the dataset.
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)
    else:
        sampler = None
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=(sampler is None and shuffle),  # DataLoader forbids sampler + shuffle together
        num_workers=num_workers,
        collate_fn=collate,
        pin_memory=pin_memory,
    )
    return data_loader
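As a quick usage sketch, assuming a map-style dataset whose __getitem__ returns dictionaries shaped like the sample batch above; the ToyDataset class and its field names are illustrative, not part of the exported module:

import numpy as np
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    """Illustrative dataset yielding dict examples compatible with collate."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {
            "token": [idx, idx + 1, idx + 2],
            "point1": np.random.rand(3, 2),   # 3 points with 2 coordinates each
            "value": np.random.rand(2),
        }

loader = build_dataloader(ToyDataset(), batch_size=2, num_workers=0, shuffle=True)
for batch in loader:
    print(batch["token"])          # list of two token lists
    print(batch["point1"].shape)   # (6, 3): batch-index column + 2 coordinates
    break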