build loader

Module for loading data from the dataset

source

collate

 collate (batch_list)

This function is designed to merge a batch of data examples into a format suitable for further processing.

Exported source
from collections import defaultdict

import numpy as np
import torch


def collate(batch_list):
    """This function is designed to merge a batch of data examples into a format suitable for further processing."""
    example_merged = defaultdict(list)
    for example in batch_list:
        for k, v in example.items():
            example_merged[k].append(v)
    ret = {}
    for key, elems in example_merged.items():
        if key == "token":
            # Tokens are kept as a plain Python list, one entry per example.
            ret[key] = elems
        elif 'point' in key:
            # Point arrays: prepend a batch-index column to every row, then
            # concatenate all examples into a single (N_total, 1 + D) tensor.
            coors = []
            for i, coor in enumerate(elems):
                coor_pad = np.pad(
                    coor, ((0, 0), (1, 0)), mode="constant", constant_values=i
                )
                coors.append(coor_pad)
            ret[key] = torch.tensor(np.concatenate(coors, axis=0))
        elif isinstance(elems[0], list):
            # Nested lists: stack the idx-th inner element across examples,
            # producing a list of tensors, one per inner-list position.
            ret[key] = defaultdict(list)
            res = []
            for elem in elems:
                for idx, ele in enumerate(elem):
                    ret[key][str(idx)].append(torch.tensor(ele))
            for kk, vv in ret[key].items():
                res.append(torch.stack(vv))
            ret[key] = res
        else:
            # Everything else: stack along a new batch dimension as float tensors.
            ret[key] = torch.tensor(np.stack(elems, axis=0)).float()

    return ret
# Sample batch list of examples
batch_list = [
    {
        "token": [1, 2, 3],
        "point1": np.array([[1.0, 2.0], [3.0, 4.0]]),
        "point2": np.array([[5.0, 6.0]]),
        "nested_list": [[1, 2], [3, 4]],
        "value": np.array([1.0, 2.0])
    },
    {
        "token": [4, 5, 6],
        "point1": np.array([[7.0, 8.0]]),
        "point2": np.array([[9.0, 10.0], [11.0, 12.0]]),
        "nested_list": [[5, 6], [7, 8]],
        "value": np.array([3.0, 4.0])
    }
]

# Using the collate function
collated_batch = collate(batch_list)

# Display the collated result
for key, value in collated_batch.items():
    print(f"{key}: {value}")
token: [[1, 2, 3], [4, 5, 6]]
point1: tensor([[0., 1., 2.],
        [0., 3., 4.],
        [1., 7., 8.]], dtype=torch.float64)
point2: tensor([[ 0.,  5.,  6.],
        [ 1.,  9., 10.],
        [ 1., 11., 12.]], dtype=torch.float64)
nested_list: [tensor([[1, 2],
        [5, 6]]), tensor([[3, 4],
        [7, 8]])]
value: tensor([[1., 2.],
        [3., 4.]])
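
The leading column that collate adds to each 'point' entry is the batch index of the example the row came from. A minimal sketch, reusing the collated_batch from above, of how that column recovers the points belonging to a single example:

# Select the rows of "point1" that belong to example 1 in the batch.
# The first column holds the batch index inserted by collate.
points = collated_batch["point1"]
example_idx = 1
mask = points[:, 0] == example_idx
print(points[mask])  # tensor([[1., 7., 8.]], dtype=torch.float64)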

Build DataLoader

The build_dataloader function is a utility for creating a PyTorch DataLoader with added support for distributed training. Here’s a breakdown of what the function does:

  1. Distributed Training Support:
    • The function first checks whether distributed training is initialized using dist.is_initialized(). If distributed training is active, it retrieves the rank and world size of the current process using dist.get_rank() and dist.get_world_size().
    • It then creates a DistributedSampler, which ensures that each process gets a different subset of the dataset. This sampler is used to handle data loading in a distributed manner.
    • If distributed training is not initialized, it defaults to using no sampler.
  2. Creating the DataLoader:
    • The function creates a DataLoader using the provided dataset, batch size, number of workers, shuffle, and pin memory options.
    • It uses the sampler if one was created; otherwise, it shuffles the data only when shuffle is set to True (see the sketch after this list).
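
For reference, a minimal sketch of how the returned loader would typically be driven in a distributed run. Calling set_epoch on the sampler each epoch is standard DistributedSampler practice and is not handled by build_dataloader itself; dataset and num_epochs are placeholders:

from torch.utils.data import DistributedSampler

loader = build_dataloader(dataset, batch_size=4, shuffle=True)
for epoch in range(num_epochs):
    # When running distributed, the loader was built with a DistributedSampler;
    # set_epoch reseeds its shuffling so every epoch sees a different ordering.
    if isinstance(loader.sampler, DistributedSampler):
        loader.sampler.set_epoch(epoch)
    for batch in loader:
        ...  # forward / backward pass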

Parameters Abstracted from PyTorch Direct Implementation

The function abstracts away the following details from a direct PyTorch DataLoader implementation:

  • DistributedSampler: Automatically handles creating and using a DistributedSampler when distributed training is initialized.
  • Sampler Management: Abstracts the logic for deciding when to use a sampler and whether to shuffle the data.
  • Collate Function: Assumes a specific collate_fn (collate) is used, simplifying the DataLoader creation by not requiring the user to specify it.
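
For comparison, creating the same loader directly in PyTorch, without the distributed handling, would look like the sketch below; build_dataloader layers the DistributedSampler logic and the sampler/shuffle rule on top of this call (parameter values are illustrative):

from torch.utils.data import DataLoader

# Single-process equivalent: the caller wires up collate_fn and shuffling directly.
loader = DataLoader(dataset,
                    batch_size=4,
                    shuffle=True,
                    num_workers=8,
                    collate_fn=collate,  # the collate function defined above
                    pin_memory=False)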

Limitations

  • Fixed Collate Function: The function uses a predefined collate_fn. If a different collate function is needed, the user must manually modify the function.
  • Limited Customization: The function only exposes a subset of DataLoader parameters (batch size, number of workers, shuffle, and pin memory). PyTorch's DataLoader also supports advanced options such as persistent_workers, worker_init_fn, and timeout; because these are not exposed, more complex use cases require modifying the function or creating a DataLoader directly.
  • Distributed Training Dependency: The function relies on PyTorch’s distributed package (torch.distributed) to determine if distributed training is initialized. If used in a non-distributed context without the appropriate setup, the distributed checks and sampler creation might add unnecessary complexity.

Further Enhancements

Some potential enhancements to the function, sketched below, include:

  • Custom Collate Function: Allow users to specify a custom collate_fn for more flexibility in data processing.
  • Expose Advanced DataLoader Parameters: Provide additional parameters for more advanced DataLoader configurations using **kwargs.
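
A sketch of what those two enhancements could look like; build_dataloader_v2 is a hypothetical name and is not part of the exported module:

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler


def build_dataloader_v2(dataset,
                        batch_size=4,
                        num_workers=8,
                        shuffle: bool = False,
                        pin_memory=False,
                        collate_fn=collate,   # enhancement 1: user-supplied collate function
                        **kwargs):            # enhancement 2: extra DataLoader options
    """Hypothetical variant of build_dataloader with a pluggable collate_fn
    and pass-through of advanced DataLoader arguments."""
    sampler = None
    if dist.is_initialized():
        sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(),
                                     rank=dist.get_rank(), shuffle=shuffle)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=(sampler is None and shuffle),
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        **kwargs,  # e.g. persistent_workers=True, timeout=60
    )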

source

build_dataloader

 build_dataloader (dataset, batch_size=4, num_workers=8,
                   shuffle:bool=False, pin_memory=False)

This function is designed to build a DataLoader object for a given dataset with optional distributed training support.

              Type   Default   Details
dataset                        Dataset object
batch_size    int    4         Batch size
num_workers   int    8         Number of workers
shuffle       bool   False     Shuffle the data
pin_memory    bool   False     Pin memory
Exported source
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler


def build_dataloader(dataset, # Dataset object
                     batch_size=4, # Batch size
                     num_workers=8, # Number of workers
                     shuffle:bool=False, # Shuffle the data
                     pin_memory=False # Pin memory
                     ): # A PyTorch DataLoader instance with the specified configuration.
    """This function is designed to build a DataLoader object for a given dataset with optional distributed training support."""
    if dist.is_initialized():
        # Distributed run: shard the dataset so each rank gets a distinct subset.
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)
    else:
        # Single-process run: no sampler, let the DataLoader handle shuffling.
        sampler = None

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=(sampler is None and shuffle),  # shuffle only when no sampler is in charge
        num_workers=num_workers,
        collate_fn=collate,
        pin_memory=pin_memory,
    )

    return data_loader
train_dataset = pillarnext_dataset.NuScenesDataset("infos_train_10sweeps_withvelo_filterZero.pkl",
                                "/root/nuscenes-dataset/v1.0-mini",
                                10,
                                class_names=[["car"], ["truck", "construction_vehicle"], ["bus", "trailer"], ["barrier"], ["motorcycle", "bicycle"], ["pedestrian", "traffic_cone"]],
                                resampling=True)

train_loader = build_dataloader(train_dataset)
print(f"Number of batches: {len(train_loader)}")
Number of batches: 303
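
A quick way to sanity-check the loader is to pull a single collated batch and inspect its keys and shapes. The exact keys depend on the NuScenesDataset pipeline, so no printout is shown here; this is a generic sketch:

# Pull one batch from the loader and report what collate produced for each key.
batch = next(iter(train_loader))
for key, value in batch.items():
    if torch.is_tensor(value):
        print(key, tuple(value.shape), value.dtype)
    elif isinstance(value, list):
        print(key, f"list with {len(value)} entries")
    else:
        print(key, type(value))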