pytorch版本的SRCNN代码一共分为6个.py文件,结构如下:
datasets.pymodels.pyprepare.pyutils.pytest.pytrain.py以上文件不分先后,执行时通过import…或者from…import…语句进行调用。以下解释import部分均省略,个别例外。
readme.md中给出了不同放大倍数下的训练数据,验证数据和测试数据的下载地址。如果下载了直接把对应的路径写好就可以执行了,这里我们使用自己下载的数据通过使用prepare.py来制作训练和验证的h5格式的数据集。
import argparse import glob import h5py import numpy as np import PIL.Image as pil_image from utils import convert_rgb_to_y #该函数用来创建自己的h5数据,包括俩个函数:对训练数据的处理和验证部分的处理。 def train(args): h5_file = h5py.File(args.output_path, 'w') ''' def是python的关键字,用来定义函数。这里通过def定义名为train的函数,函数的参数为args,args这个参数通过外部命令行传入output 的路径,通过h5py.File()方法的w模式--创建文件自己自写,已经存在的文件会被覆盖,文件的路径是通过args.output_path来传入 ''' lr_patches = [] hr_patches = [] ''' 创建俩个空列表:lr_patches和hr_patches(通过ctrl左键该变量名查看在其他位置的引用) ''' for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))): ''' 这部分代码的目的就是搜索指定文件夹下的文件并排序,for这一句包含了几个知识点: 1.{}.format():-->格式化输出函数,从args.images_dir路径中格式化输出路径 2.glob.glob():-->返回所有匹配的文件路径列表,将1得到的路径中的所有文件返回 3.sorted():-->排序,将2得到的所有文件按照某种顺序返回,,默认是升序 4.for x in *: -->循换输出 ''' hr = pil_image.open(image_path).convert('RGB') ''' 1.***.open():是PIL图像库的函数,用来从image_path中加载图像 2.***.convert():是PIL图像库的函数,用来转换图像的模式 ''' hr_width = (hr.width // args.scale) * args.scale hr_height = (hr.height // args.scale) * args.scale hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC) lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC) lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)#缩放处理 hr = np.array(hr).astype(np.float32) lr = np.array(lr).astype(np.float32) hr = convert_rgb_to_y(hr) lr = convert_rgb_to_y(lr) ''' " / " 表示浮点数除法,返回浮点结果;" // " 表示整数除法,返回不大于结果的一个最大的整数,也就是向下取整 这里的hr是输入的原图,先进行mod和缩放的预处理,lr是hr在mod之后经过scale的结果,得到的lr再经过缩放处理得到最终要用的lr的图片 resize():缩放操作 np.array():将列表list或元组tuple转换为ndarray数组 astype():转换数组的数据类型 convert_rgb_to_y():将图像从RGB格式转换为Y通道格式的图片 假设原始输入图像为(321,481,3)-->依次为高,宽,通道数 1.先mod,之后hr的图像尺寸为(320,480,3) 2.对hr图像进行双三次上采样放大操作 3.将hr//scale进行双三次上采样放大操作之后×scale得到lr 4.接着进行通道数转换和类型转换 ''' for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride): ''' 图像的shape是宽度、高度和通道数,shape[0]是指图像的高度=320;shape[1]是图像的宽度=480; shape[2]是指图像的通道数 ''' for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride): lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size]) hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size]) lr_patches = np.array(lr_patches) hr_patches = np.array(hr_patches) #把得到的数据转化为数组类型 h5_file.create_dataset('lr', data=lr_patches) h5_file.create_dataset('hr', data=hr_patches) h5_file.close() def eval(args): h5_file = h5py.File(args.output_path, 'w') lr_group = h5_file.create_group('lr') hr_group = h5_file.create_group('hr') for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))): hr = pil_image.open(image_path).convert('RGB') hr_width = (hr.width // args.scale) * args.scale hr_height = (hr.height // args.scale) * args.scale hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC) lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC) lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC) hr = np.array(hr).astype(np.float32) lr = np.array(lr).astype(np.float32) hr = convert_rgb_to_y(hr) lr = convert_rgb_to_y(lr) lr_group.create_dataset(str(i), data=lr) hr_group.create_dataset(str(i), data=hr) h5_file.close() if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--images-dir', type=str,default='/home/dushuai/word/SRCNN_pytorch/evaldata') parser.add_argument('--output-path', type=str,default='/home/dushuai/word/SRCNN_pytorch/evalout/evalout.h5') parser.add_argument('--patch-size', type=int, default=33) parser.add_argument('--stride', type=int, default=14) parser.add_argument('--scale', type=int, default=2) parser.add_argument('--eval', action='store_true') args = parser.parse_args() if not args.eval: train(args) else: eval(args) ''' 最后这个if..else..要注意一下,是和parser传入的最后一个参数有关的,它是用来决定使用哪个函数来生成h5文件,因为有俩个不同的函数train和eval生成对应的h5文件。该参数的具体使用方法如下 '''在我看来这是个很鸡肋的参数设置,但是存在即合理,我们只需要明白它就ok了。
import argparse parser = argparse.ArgumentParser() parser.add_argument('--eval', action='store_false') args = parser.parse_args() def main(): x = args.eval print(x) if __name__ == '__main__': main()可以看到我上边的action=‘store_false’,但是边一个是直接在IDE中run的结果是True,而我通过命令行运行得到的结果却是false,这是为什么? 顾名思义,store_flase就是存储一个bool值false,也就是说在该参数在被激活时它会输出store存储的值也就是这里我通过命令行得到的值,而IDE得到的值没有激活该参数,得到的是它的默认值True.
import argparse parser = argparse.ArgumentParser() parser.add_argument('--eval', action='store_false') args = parser.parse_args() def a(): print('a') def b(): print('b') def main(): x = args.eval print(x) if not args.eval: print(args.eval) a() else: print(args.eval) b() if __name__ == '__main__': main()在SRCNN的预处理中可以通过修改action中store的值也可以通过if not args.eval来调整函数运行哪个函数来得到对应的结果。
一共包含俩个类TrainDataset()和EvalDataset(),分别用来加载prepare.py制作的训练和验证俩个数据集的。这部分想自己写,但是发现了一篇不错的博客,传送门在此
这部分更为简单,首先定义了模型类SRCNN,它继承自父类nn.Module。super这句是对继承自父类的属性进行初始化。接下来就是对卷积层的定义和前向传播的定义。
这个utils.py相当于是工具类,定义网络需要使用的各种函数。这个文件一共包括了四个函数和一个类,至于test和train都很简单,很容易看懂,略
参考文献: 1.if name == ‘main’: 2.Python之argparse