开启torch新篇章:Pytorch创建Dataset,并加载DataLoader

tech2024-07-13  64

这是首篇关于siamfc++中dataloader的实现,包括接下来的三篇文章都是,需要连续看。

torch.utils.data.Dataset

Dataset类是Pytorch中图像数据集中最为重要的一个类,也是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数必须被重载,否则将会触发错误提示: 其中__len__应该返回数据集的大小,而__getitem__应该编写支持数据集索引的函数,例如通过dataset[i]可以得到数据集中的第i+1个数据。 在继承了这个Dataset类之后,我们需要实现的核心功能便是__getitem__()函数,getitem()是Python中类的默认成员函数,我们通过实现这个成员函数实现可以通过索引来返回图像数据的功能。

实例

首先继承上面的dataset类。然后在__init__()方法中得到图像的路径,然后将图像路径组成一个数组,这样在__getitim__()中就可以直接读取。

class ShipDataset(Dataset): """ root:图像存放地址根路径 augment:是否需要图像增强 """ def __init__(self, root, augment=None): # 这个list存放所有图像的路径 self.image_files = np.array([x.path for x in os.scandir(root) if x.name.endswith(".jpg") or x.name.endswith(".png") or x.name.endswith(".JPG")] def __getitem__(self, index): # 读取图像数据并返回 return cv2.imread(self.image_files[index]) def __len__(self): # 返回图像的数量 return len(self.image_files)

torch.utils.data.DataLoader

之前所说的Dataset类是读入数据集数据并且对读入的数据进行了索引。但是光有这个功能是不够用的,在实际的加载数据集的过程中,我们的数据量往往都很大,对此我们还需要一下几个功能:

可以分批次读取:batch-size可以对数据进行随机读取,可以对数据进行洗牌操作(shuffling),打乱数据集内数据分布的顺序可以并行加载数据(利用多核处理器加快载入数据的效率)

这时候就需要Dataloader类了,Dataloader这个类并不需要我们自己设计代码,我们只需要利用DataLoader类读取我们设计好的Dataset即可:

# 利用之前创建好的ShipDataset类去创建数据对象 ship_train_dataset = ShipDataset(data_path, augment=transform) # 利用dataloader读取我们的数据对象,并设定batch-size和工作现场 ship_train_loader = DataLoader(ship_train_dataset, batch_size=16, num_workers=4, shuffle=False, **kwargs)

这时候通过ship_train_loader返回的数据就是按照batch-size来返回特定数量的训练数据的tensor,而且此时利用了多线程,读取数据的速度相比单线程快很多。

我们这样读取:

for image in train_loader: image = image.to(device) # 将tensor数据移动到device当中 optimizer.zero_grad() output = model(image) # model模型处理(n,c,h,w)格式的数据,n为batch-size

siamfc++

main/train.py首次定义dataset

dataloader = dataloader_builder.build(task, task_cfg.data)

videoanalyst/data/builder.py创建并加载了Dataset

def build(task: str, cfg: CfgNode, seed: int = 0) -> DataLoader: r""" Arguments --------- task: str task name (track|vos) cfg: CfgNode node name: data seed: int seed for random """ if task in ["track", "vos"]: # build dummy dataset for purpose of dataset setup (e.g. caching path list) logger.info("Build dummy AdaptorDataset") dummy_py_dataset = AdaptorDataset( task, cfg, num_epochs=cfg.num_epochs, nr_image_per_epoch=cfg.nr_image_per_epoch, seed=seed, ) logger.info("Read dummy training sample") dummy_sample = dummy_py_dataset[0] # read dummy sample del dummy_py_dataset, dummy_sample gc.collect(generation=2) logger.info("Dummy AdaptorDataset destroyed.") # get world size in case of DDP world_size = dist_utils.get_world_size() # build real dataset logger.info("Build real AdaptorDataset") py_dataset = AdaptorDataset(task, cfg, num_epochs=cfg.num_epochs, nr_image_per_epoch=cfg.nr_image_per_epoch) # use DistributedSampler in case of DDP if world_size > 1: py_sampler = DistributedSampler(py_dataset) logger.info("Use dist.DistributedSampler, world_size=%d" % world_size) else: py_sampler = None # build real dataloader dataloader = DataLoader( py_dataset, batch_size=cfg.minibatch // world_size, shuffle=False, pin_memory=cfg.pin_memory, num_workers=cfg.num_workers // world_size, drop_last=True, sampler=py_sampler, ) return dataloader

videoanalyst/data/adaptor_dataset.py定义了Dataset具体实现

from loguru import logger import torch import torch.multiprocessing from torch.utils.data import Dataset class AdaptorDataset(Dataset): _EXT_SEED_STEP = 30011 # better to be a prime number _SEED_STEP = 10007 # better to be a prime number _SEED_DIVIDER = 1000003 # better to be a prime number def __init__( self, task, cfg, num_epochs=1, nr_image_per_epoch=1, seed: int = 0, ): self.datapipeline = None self.task = task self.cfg = cfg self.num_epochs = num_epochs self.nr_image_per_epoch = nr_image_per_epoch self.ext_seed = seed def __getitem__(self, item): if self.datapipeline is None: # build datapipeline with random seed the first time when __getitem__ is called # usually, dataset is already spawned (into subprocess) at this point. seed = (torch.initial_seed() + item * self._SEED_STEP + self.ext_seed * self._EXT_SEED_STEP) % self._SEED_DIVIDER self.datapipeline = datapipeline_builder.build(self.task, self.cfg, seed=seed) logger.info("AdaptorDataset #%d built datapipeline with seed=%d" % (item, seed)) training_data = self.datapipeline[item] return training_data def __len__(self): return self.nr_image_per_epoch * self.num_epochs
最新回复(0)