【深度之眼】Pytorch框架班第五期-Pytorch数据读取的代码调试

tech2025-03-27  1

torch.utils.data.DataLoader

Data(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_list=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

功能: 构建可迭代的数据装载器

dataset:Dataset类,决定数据从哪读取以及如何读取batchsize:批大小num_works: 是否多进程读取数据shuffle:每个epoch是否乱序drop_list:当样本数不能被batchsize整除时,是否舍弃最后一批数据

torch.utils.data.Dataset

class Dataset(object): def __getitem__(self, index): raise NotImplementedError def __add__(self, other): return ConcatDataset([self, other])

功能: Dataset抽象类,所有自定义的Dataset都要继承它,并且复写__getitem__()函数。 getitem: 接收一个索引,返回一个样本

Pytorch数据读取机制

读那些数据?从哪读数据?怎么读数据?

代码调试

Debug常用按钮

1、设置断点,进行Debug调试 2、采用C按钮,跳转到dataloader.py中DataLoader类的__iter__(self)函数中,该处代码表示是否使用多进程。 3、以单进程为例,点击B按钮然后点击C按钮,进入单进程的类当中,在该类中,最重要的函数为__next__(self),该函数会获取index和data。该函数告诉我们读哪些数据。 4、将光标放在345行,即index=self._next_index()上,点击F按钮,然后点击C按钮进入self._next_index()函数中,查看该函数是如何获取index的。 5、再点击以下C按钮,我们进入到sampler.py中的BatchSampler类中。Sampler就是一个采样器,他就是用来告诉我们每个Batchsize该读取那些数据

5、点击两次E按钮,跳出函数,然后点击B按钮,运行345行的代码,运行完成后我们的index就挑选出来了(Batchsize=16)。 6、有了index,接下来就是数据获取,我们点击B按钮进入self.dataset_fetcher.fetch()函数中 7、我们进入到fetch.py文件中的_MapDatasetFetcher类中,在第44行中正式调用了dataset,对dataset输入一个索引index,就会返回一个data,将一些列的data拼接为一个list。

8、先点击B按钮,运行到44行,然后点击C按钮(需要点击两次)进到self.dataset中。我们可以看到该函数跳转到了我们自己创建的my_dataset.py文件中的RMB数据集类中的__getitem__(self, index)函数中,self.data_inifo的每一项为图片的路径和标签,然后我们通过Image.open来读取图片,这就实现了一个数据的读取和标签的获取。 9、点击E按钮跳出该函数,进入到第7步的界面中,我们将光标放在47行并点击F按钮运行到该行,我们可以发现,在我们将读取的数据返回到第六步之前,我们会用self.collate_fn()函数来整理数据,该函数为数据的整理器,它会将我们读取的16个数据整理为一个Batch的形式,可以看到在运行self.collate_fn()函数之前,我们的data为list类型的数据。 10、点击两次B按钮,我们可以发现我们的data变成batch的形式,第一个元素里面为图片Tensor,第二个为标签。 11、点击F按钮返回数据并点击B按钮,此时我们可以看到我们的data为list的形式,第一个元素为图像,第二个元素为标签。有了图像和标签我们就可以对模型进行训练。这就是pytorch的数据读取机制。

现在我们回答上面的三个问题:

1、读那些数据? Sampler输出的Index 2、从哪读数据? Dataset中的data_dir 3、怎么读数据 Dataset中的getitem,根据索引读数据

流程图

最新回复(0)