PyTorch Distributed Training


Following the previous post on loading a Dataset, this post looks at how to distribute a Dataset across multiple GPUs for training. In the multi-GPU case, distributed data loading relies on these two classes:

torch.nn.parallel.DistributedDataParallel
torch.utils.data.distributed.DistributedSampler

The DataParallel approach simply splits each batch across the different GPUs. The sampler approach instead makes sure each DataLoader only loads a particular subset of the whole dataset: DistributedSampler carves out a partition of the dataset for every worker process, so that no data is duplicated between processes.
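To see what that partitioning actually looks like, here is a minimal sketch (the toy 8-sample dataset and the two simulated ranks are made up for illustration): with num_replicas=2, each rank receives a disjoint, strided slice of the index space.

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Hypothetical toy dataset with 8 samples, just to show the index split.
dataset = TensorDataset(torch.arange(8))

# Simulate 2 processes; passing num_replicas/rank explicitly means no
# initialized process group is needed for this demonstration.
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    print(rank, list(sampler))
# rank 0 -> [0, 2, 4, 6]
# rank 1 -> [1, 3, 5, 7]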

Example

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel

dataset = your_dataset()
# each process gets its own, non-overlapping shard of the dataset
datasampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
# the sampler replaces shuffle=True; batch_size here is the per-GPU batch size
dataloader = DataLoader(dataset, batch_size=batch_size_per_gpu, sampler=datasampler)
model = your_model()
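For completeness, here is a minimal end-to-end sketch of how these pieces are typically wired together inside one training process; the toy Linear model, the random tensor data, and the torchrun launch are assumptions for illustration, not code from this post. It would be launched with something like torchrun --nproc_per_node=2 train.py.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def main():
    # torchrun sets RANK / WORLD_SIZE / LOCAL_RANK for every process
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # toy dataset / model standing in for your_dataset() / your_model()
    dataset = TensorDataset(torch.randn(128, 16), torch.randn(128, 1))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)

    model = DistributedDataParallel(nn.Linear(16, 1).to(local_rank),
                                    device_ids=[local_rank])
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffle differently every epoch
        for x, y in loader:
            x, y = x.to(local_rank), y.to(local_rank)
            loss = nn.functional.mse_loss(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

Calling sampler.set_epoch(epoch) at the start of every epoch is what lets DistributedSampler produce a different shuffle each epoch; without it, every epoch sees the same ordering.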

Now we can look at the full picture of how SiamFC++ loads its data, in videoanalyst/data/builder.py:

logger.info("Build real AdaptorDataset")
py_dataset = AdaptorDataset(task,
                            cfg,
                            num_epochs=cfg.num_epochs,
                            nr_image_per_epoch=cfg.nr_image_per_epoch)
# use DistributedSampler in case of DDP
if world_size > 1:
    py_sampler = DistributedSampler(py_dataset)
    logger.info("Use dist.DistributedSampler, world_size=%d" % world_size)
else:
    py_sampler = None
# build real dataloader
dataloader = DataLoader(
    py_dataset,
    batch_size=cfg.minibatch // world_size,
    shuffle=False,
    pin_memory=cfg.pin_memory,
    num_workers=cfg.num_workers // world_size,
    drop_last=True,
    sampler=py_sampler,
)

AdaptorDataset builds the dataset; py_sampler defines how the dataset is partitioned across processes; dataloader performs the actual loading. Note that cfg.minibatch and cfg.num_workers are global values that get divided by world_size, as the quick check below illustrates.
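A small sanity check on that division (the concrete numbers are hypothetical, not taken from a SiamFC++ config): as GPUs are added, the per-GPU batch shrinks, so the effective global batch per optimization step stays equal to cfg.minibatch.

# hypothetical config values, not from the SiamFC++ repo
minibatch, num_workers, world_size = 64, 8, 4

batch_size_per_gpu = minibatch // world_size              # 16 samples per process
workers_per_gpu = num_workers // world_size               # 2 DataLoader workers per process
global_batch_per_step = batch_size_per_gpu * world_size   # 64, i.e. cfg.minibatch again
print(batch_size_per_gpu, workers_per_gpu, global_batch_per_step)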
