根据pytorch张量部分中提到的机器学习模型训练的五大步骤:数据、模型、损失函数、优化器和迭代训练。我们这部分主要介绍模型训练的第一个步骤:数据模块。其中,数据模块通常进一步划分为四个子模块,分别为:数据收集、数据划分、数据读取和数据预处理。 这次主要介绍的部分是数据模块中的数据读取子模块DataLoader,其中DataLoader还可以划分为Sampler和DataSet,Sampler是用来生成索引,而DataSet是根据索引获取Img和Label。
torch.utils.data.DataLoader: 功能:构建可迭代的数据装载器
dataset:Dataset类,决定数据从哪儿读取及如何读取batchsize:批大小num_works:是否多进程读取数据shuffle:每个epoch是否乱序drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据 接下来看看在模型训练中必不可少的epoch、Iteration和Batchsize三者都代表什么:Epoch:所有训练样本都已经输入到模型中,称为一个EpochIteration:一批样本输入到模型中,称之为一个IterationBatchsize:批大小,决定一个Epoch有多少个Iteration比如,我们有样本总数为80,其中设置Batchsize为8,那么有:1 Epoch=10 Iteration;如果样本总数为87,这时其无法被Batchsize整除,那么有如下两种方式解决: (1)1 Epoch=10 Iteration 需要设置drop_last=True(余下的7个样本将被忽略) (2)1 Epoch=11 Iteration 需要设置drop_last=False
torch.utils.data.DataSet: 功能:DataSet抽象类,所有自定义的DataSet需要继承它,并且复写__getitem__(),其中__getitem__()的作用是接收一个索引,返回一个样本。 在通过人民币二分类问题进行学习之前,我们需要考虑三个问题:
读哪些数据?从哪读数据?怎么读数据?首先我们是先将人民币图像压缩包文件进行解压,之后在通过python的os模块对图像进行读取,这个过程中还需对数据集进行划分,将数据集进一步划分为训练集、验证集和测试集(train、valid、test),这也就解决了数据读取的第一个问题,即读哪些数据。之后我们再根据划分好的训练集、验证集和测试集依次对人民币图像数据进行读取,这部分是解决了数据读取的第二个问题,即从哪读数据。 最主要的部分是构建MyDataset实例,这里的RMBDataset是需要我们自己构建的,其中传入的两个参数分别为读取数据的文件路径,以及用于数据预处理的transform,这个transform之后会详细介绍。 按住Ctrl键再单击RMBDataset可以使程序跳转到具体的实现类: 可以看到,在RMBDataset实现类里有用于初始化的__init__()方法,以及__getitem__()函数会根据索引值返回一个图像和对应的标签值,len()函数可以用来查看数据的长度,即样本的数量。这里我们是构建两个Dataset,一个是训练数据的Dataset,另一个是验证数据的Dataset,有了Dataset我们就可以构建数据装载器DataLoader了。在DataLoader中我们需要传入之前创建好的Dataset,同时还有batch_size,也就是批处理的样本量,还有是否将数据进行打乱,在训练集中需要将数据进行shuffle而在验证集中则不用。之后我们采用经典的深度神经网络模型LeNet来解决这个人民币二分类问题,接下来我们通过代码的调试来观察pytorch是如何读取数据的: 首先在循环读取数据起始位置设置断点并进行Debug: 之后选择stepinto就能跳转到这个函数中, 可以看到首先函数会获取indices也就是索引,以及batch也就是数据,这个函数也正好告诉了我们数据读取问题中读取哪些数据的问题。之后我们将光标放在获取indices这一行代码上并点击Run to Cursor,将代码运行到光标这一行。接着我们再stepinto光标所在行的代码,看看pytorch是如何获取indices的,进入到了Sampler: Sampler就是一个采样器,它用来告诉我们每一个batchsize该读取哪些数据。尤其是在__iter__()函数中,这个函数是用产生迭代索引值的,也就是指定每个step需要读取哪些数据。 之后我们再stepover,stepout跳出这个函数,并将光标放在indices下面那一行,即可看到已获取的索引值:
可以看到,已经获得了indices索引值,由于我们的batch_size设置的是16,所以获得的索引数量也就为16。在得到索引之后,就可以根据索引获取相应的数据了。 再在data这一行设置断点,并stepinto查看究竟在这个函数中究竟发生了什么: 可以看到我们stepinto到了__getitem__()这个函数,并根据索引值获取相应的img和label,而获取数据的关键点就在这里。 stepout跳出这个函数之后,我们可以发现函数collate_fn(),这个函数是对获取到的数据进行整理,整理成一个batch: 经过代码调试之后,我们可以更好地解决之前数据读取的三个问题: 读哪些数据?是根据Sampler输出的Index进行读取;从哪读数据?是根据Dataset中的data_dir读取,简单地说就是从相关的硬盘文件读取;怎么读数据?这个就是根据比较重要的Dataset中的getitem函数来读取了。将这一数据读取过程用流程图表示如下: 首先我们根据for循环进入DataLoader,之后判断是采用单进程还是多进程进入DataLoaderIter,进入到DataLoaderIter之后我们会使用Sampler去获取index(indices)索引,拿到索引之后给到DatasetFetcher,在DatasetFetcher中会调用Dataset,Dataset根据我们给定的索引在getitem当中从硬盘当中去读取我们实际的图像和标签,在读取了一个batchsize大小的数据之后,通过一个collate_fn将获取的这些数据进行整理,整理成一个batchdata的形式,然后即可输入到我们的模型当中去训练了。
在这部分,我们需要安装计算机视觉工具包——torchvision,其中常用的几个方法包括:
torchvision.transforms:常用的图像预处理方法torchvision.datasets:常用数据集的datasets实现,包括MNIST、CIFAR-10、ImageNet等torchvision.model:常用的模型预训练,AlexNet,VGG,ResNet,GoogLeNet等torchvision.transforms:常用的图像预处理方法,主要包括数据中心化、数据标准化、缩放、裁剪、旋转、翻转、填充、噪声添加、灰度变换、线性变换、仿射变换、亮度、饱和度及对比度变换。 这里我们依然采用人民币二分类来对transforms进行说明: 这里的Compose是将我们的一系列transforms方法进行有序的组合,在具体实现的时候会依次地将这些方法对图像进行操作。在train_transform那里,可以看到首先先对图像进行resize把图像缩放到32x32大小的尺寸,接着对图像进行一个随机裁剪,再使用ToTensor将图像转换成张量形式,同时会进行一个归一化操作,把像素值的区间从0-255归一化到0-1,最后一步是标准化将均值变为0,标准差变为1。 现在依然通过代码调试的方式来查看其运行机制,先通过设置断点并stepinto到如下这个函数: 之后我们在self.transforms处进行Run on Cursor并stepinto,即可进入到transforms.py文件中的__call__()函数: 可以看到__call__()函数里是一个for循环,也就是依次地执行对图像的一系列变换操作,这个循环结束后最终返回执行完变换的图像。从这里可以看到,transform是在__getitem__()函数中被调用,再通过__getitem__()返回一个样本,之后不断根据索引总共获取batchsize大小的数据。这就是pytorch数据读取和transform的运行机制。 这次把transform加入到数据读取流程图中。
transforms.Normalize: 功能:逐channel的对图像进行标准化 output = (input - mean) / std
mean:各通道的均值std:各通道的标准差inplace:是否原地操作 对数据进行标准化之后可以加快模型的收敛。数据增强(Data Augmentation):数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力。 裁剪——transforms.CenterCrop: 功能:从图像中心裁剪图片
size:所需裁剪图片尺寸随机裁剪——transforms.RandomCrop: 功能:从图片中随机裁剪出尺寸为size的图片
size:所需裁剪图片尺寸
padding:设置填充大小 当padding=a时,上下左右均填充a个像素 当padding=(a,b)时,上下填充b个像素,左右填充a个像素 当padding=(a,b,c,d)时,左,上,右,下分别填充a,b,c,d
pad_if_need:若图像小于设定size,则填充
padding_mode:填充模式,有4种模式 1.constant:像素值由fill设定 2.edge:像素值由图像边缘像素设定 3.reflect:镜像填充,最后一个像素不镜像 4.symmetric:镜像填充,最后一个像素镜像
fill:constant时,设置填充的像素值 transforms——Crop RandomResizedCrop: 功能:随机大小、长宽比裁剪图片
size:所需裁剪图片尺寸
scale:随机裁剪面积比例,默认(0.08,1)
ratio:随机长宽比,默认(3/4,4/3)
interpolation:插值方法 PIL.Image.NEAREST PIL.Image.BILINEAR PIL.Image.BICUBIC FiveCrop、TenCrop: 功能:在图像的上下左右以及中心裁剪出尺寸为size的5张图片,TenCrop对这5张图片进行水平或垂直镜像获得10张图片。
size:所需裁剪图片尺寸
vertical_flip:是否垂直翻转
1.RandomHorizontalFlip 2.RandomVerticalFlip 功能:依概率水平(左右)或垂直(上下)翻转图片
p:翻转概率 3.RandomRotation 功能:随机旋转图片
degrees:旋转角度 当degrees=a时,在(-a,a)之间选择旋转角度 当degrees=(a,b)时,在(a,b)之间选择旋转角度
resample:重采样方法
expand:是否扩大图片,以保持原图信息
center:旋转点位置,默认中心旋转
1.Pad 功能:对图片边缘进行填充
padding:设置填充大小 当padding=a时,上下左右均填充a个像素 当padding=(a,b)时,上下填充b个像素,左右填充a个像素 当padding=(a,b,c,d)时,左,上,右,下分别填充a,b,c,d
padding_mode:填充模式,有4种模式,constant,edge,reflect和symmetric
fill:constant时,设置填充的像素值(R,G,B) 或(Gray)
2.ColorJitter 功能:调整亮度、对比度、饱和度和色相
brightness:亮度调整因子 当brightness=a时,从[max(0,1-a),1+a]中随机选择 当brightness=(a,b)时,从[a,b]中选择
contrast:对比度参数,同brightness
saturation:饱和度参数,同brightness
hue:色相参数 当hue=a时,从[-a,a]中选择参数,注:0<=a<=0.5 当hue=(a,b)时,从[a,b]中选择参数,注:-0.5<=a<=b<=0.5
3.Grayscale 4.RandomGrayscale
功能:依概率将图片转换为灰度图
num_output_channels:输出通道数,只能设1或3p:概率值,图像被转换为灰度图的概率5.RandomAffine 功能:对图像进行仿射变换,仿射变换是二维的线性变换,由五种基本原子变换构成,分别是旋转、平移、缩放、错切和翻转
degrees:旋转角度设置translate:平移区间设置,如(a,b),a设置宽(width),b设置高(height),图像在宽维度平移的区间为-img_width * a< dx<img_width *ascale:缩放比例(以面积为单位)fill_color:填充颜色设置shear:错切角度设置有水平错切和垂直错切resample:重采样方式,有NEAREST,BILINEAR,BICUBIC6.RandomErasing 功能:对图像进行随机遮挡(接收张量,是对Tensor进行操作的)
p:概率值,执行该操作的概率scale:遮挡区域的面积ratio:遮挡区域的长宽比value:设置遮挡区域的像素值,(R,G,B)或(Gray)7.transforms.Lambda 功能:用户自定义lambda方法
lambda:lambda匿名函数 lambda [arg1,[,arg2,…,argn]]:expression1.transforms.RandomChoice 功能:从一系列transforms方法中随机挑选一个 2.transforms.RandomApply 功能:依据概率执行一组transforms操作 3.transforms.RandomOrder 功能:对一组transforms操作打乱顺序
自定义transforms要素:
仅接收一个参数,返回一个参数注意上下游的输出与输入 通过类实现多参数传入: 椒盐噪声: 椒盐噪声又称为脉冲噪声,是一种随机出现的白点或者黑点,白点称为盐噪声,黑色称为椒噪声。 信噪比(Signal-Noise Rate,SNR)是衡量噪声的比例,图像中为图像像素的占比。transforms methods总结:
数据增强实战应用 原则:让训练集与测试集更接近
空间位置:平移色彩:灰度图,色彩抖动形状:仿射变换上下文场景:遮挡、填充… …深度之眼训练营——Pytorch课程