论文代码复现,对pytorch训练模型代码的总体解读(以堆叠注意力网络模型为例))

tech2022-07-05  159

以论文Stacked Attention Networks for Image Question Answering提到的堆叠注意力模型为例 , 对pytorch训练模型代码的总体解读 pytorch模型训练步骤: 1,load data和preprocess 2, build model(本文SAN) 3, train 训练神经网络,其实就是在调参,训练完成后得到一组很好的参数(可以称之为训练后的模型),然后将这个模型应用到其他样本中去识别其他样本。

一,preprocess

加载数据集 数据集使用的是MSCOCO,用代码从官网下载比较慢,可以提前下载好。 读取数据集 PyTorch 读取图片,主要是通过 Dataset 类,Dataset类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它。 读取自己数据的基本流程就是:

获取图片的路径和标签信息。将这些信息转化为 list,该 list 每一个元素对应一个样本。通过 getitem 函数,读取数据和标签,并返回数据和标签。

在训练代码里有些操作感觉不到,只会看到通过 DataLoader 就可以获取数据,因此,在 PyTorch 代码中,能读取自己的数据集,只需要两步: 1, 制作图片数据的索引 就是读取图片路径,标签. 2, 构建 Dataset 子类 进行数据集预处理

# 对数据集进行处理 class MSCOCODataset(td.Dataset): def __init__(self, images_dir, q_dir, ans_dir, mode='train', image_size=(448, 448), top_num=1000): #用one-hot对每句话提取特征向量 def one_hot_answer(self, inp, mapping): return torch.Tensor([mapping[inp]]) #由于pytorch中没有string格式,使用onehot对答案和问题进行编码。 def one_hot_question(self, inp, mapping): vec = torch.zeros(len(inp.split(" "))) for i, word in enumerate(inp.split(" ")): vec[i] = mapping[word] return vec def __len__(self): return len(self.top_questions) def __getitem__(self, idx): #返回一张图片的数据

最主要的是getitem 函数:getitem 接收一个索引,然后返回一张图片数据的路径和标签信息。 在getitem 函数中,首先,根据idx获取问题,答案,图像相关信息。然后,利用 Image.open 对图片进行读取,最后再对图片进行处理,

def __getitem__(self, idx): #返回一条数据 q = self.top_questions[idx] a = self.top_answers[idx] img_id = self.top_images[idx] img_path = os.path.join(self.root_image+"/"+"COCO_%s2014_%s.jpg" % (self.mode, img_id.zfill(12))) #利用 Image.open 对图片进行读取: img = Image.open(img_path).convert("RGB") #如果不使用.convert('RGB')进行转换的话,读出来的图像是RGBA四通道的,A通道为透明通道 # 最后再对图片进行处理,这个 transform 里边还可以实现随机裁剪,旋转,翻转,放射变换等操作对图片进行处理。 transform = tv.transforms.Compose([tv.transforms.CenterCrop(self.image_size), #按image_size中心裁剪 tv.transforms.ToTensor(), #转化为Tensor格式 tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) #标准化,方便网络去优化 x = transform(img) one_hot_q = self.one_hot_question(q, self.vocab_q) one_hot_ans = self.one_hot_answer(a, self.vocab_a) target_q = torch.zeros(self.seq_question) target_q[:one_hot_q.shape[0]] = one_hot_q return x, target_q, len(one_hot_q), one_hot_ans

对问题和答案的数据集的预处理包括去除标点符号,空格替换等操作。相关方法函数可以参考python 中re.sub,replace(),strip()的区别

上面的Dataset的子类MSCOCODataset读取数据,是通过数据加载器 DataLoder触发的,本例在train.py中使用的DataLoder

#数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。 #在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。 #直至把所有的数据都抛出。就是做一个数据的初始化。 self.train_loader = torch.utils.data.DataLoader(self.train_set, batch_size=batch_size, sampler=train_sampler, collate_fn=collate_fn) self.val_loader = torch.utils.data.DataLoader(self.train_set, batch_size=batch_size, sampler=val_sampler, collate_fn=collate_fn)

二,model

模型的定义过程就是先继承,再构建组件,最后组装。 首先,必须继承 nn.Module 这个类,要让 PyTorch 知道这个类是一个Module。 其次,在__init__(self)中设置好需要的“组件"(如 conv、pooling、Linear、BatchNorm等)。其中基本组件可从 torch.nn 中获取,或者从 torch.nn.functional 中获取, 最后,在 forward(self, x)中使用定义好的“组件”进行组装,来搭建网络结构,这样一个模型就定义好了。

在本例中:

class VGGNet(nn.Module): #图像模型使用VGGNet def __init__(self, output_features, fine_tuning=False): class LSTM(nn.Module): #问题模型用到了LSTM def __init__(self, vocab_size, embedding_dim, batch_size, hidden_dim, num_layers=1):

lstm和VGGNet都是已有的模型,直接写就行了,不用改太多,这两个模型,在train.py中调用时,还会再设置参数。

class AttentionNet(nn.Module): #用于论文中的堆叠注意力网络模型 def __init__(self, num_classes, batch_size, input_features=1024, output_features=512): ,,, def forward(self, image, question): # image_vec = batchx196x1024 # question_vec = batchx1024 irep_1 = self.image1(image) #图片向量 qrep_1 = self.question1(question).unsqueeze(dim=1) #问题向量 ha_1 = self.tanh(irep_1 + qrep_1) #图像与问题结合 ha_1 = self.dropout(ha_1) pi_1 = self.softmax(self.attention1(ha_1)) #生成区域注意力分布概率 u_1 = (pi_1 * image).sum(dim=1) + question #每个图像区域的概率与该图像向量相乘得到最新图像向量 #新的图像向量与问题结合,形成新的查询向量u_1 irep_2 = self.image2(image) qrep_2 = self.question2(u_1).unsqueeze(dim=1) #原来的问题向量变成了新的查询向量u_1 ha_2 = self.tanh(irep_2 + qrep_2) ha_2 = self.dropout(ha_2) pi_2 = self.softmax(self.attention2(ha_2)) u_2 = (pi_2 * image).sum(dim=1) + u_1 # 使用了两层注意层,论文指出两层注意层效果最好。 w_u = self.answer_dist(self.dropout(u_2)) # 推断答案 return w_u

三,train

对于一个普通的训练模型,train的主要功能就是利用网络训练模型,优化参数并计算损失loss和精确度acc,流程如下: 1,out=net(x) #将输入通过网络得到输出, 2,loss= CrossEntropyLoss(out,y) # 用损失函数计算输出与真实值的差距, 3,self.optimizer.zero_grad() # 计算梯度之前将梯度清零, 4,loss.backward() # 通过前传,后传操作,梯度计算, 5,self.optimizer.step() # 使用优化器进行梯度更新, 最后得到一组更好的参数,这只是一次过程,还要重复很多次epoch。 训练神经网络,其实就是在调参,训练完成后得到一组很好的参数(可以称之为训练后的模型),然后将这个模型应用到其他样本中去识别其他样本。

而对于复杂一点的模型或数据集的训练,在训练之前还要进行数据集的处理。 本例中:

# train and validation splits train_ind = self.indices[:int(len(self.indices)*0.8)] #训练集和验证集按4:1分配 val_ind = self.indices[int(len(self.indices)*0.8):] train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_ind) #无放回地按照给定的索引列表采样样本元素 val_sampler = torch.utils.data.sampler.SubsetRandomSampler(val_ind) #数据加载器,结合了数据集和取样器,处理数据集。 self.train_loader = torch.utils.data.DataLoader(self.train_set, batch_size=batch_size, sampler=train_sampler, collate_fn=collate_fn) self.val_loader = torch.utils.data.DataLoader(self.train_set, batch_size=batch_size, sampler=val_sampler, collate_fn=collate_fn) #初始化模型,给出了一些参数值 self.image_model = VGGNet(output_features=1024).to(self.device) self.question_model = LSTM(vocab_size=len(self.train_set.vocab_q), embedding_dim=1000, batch_size=batch_size, hidden_dim=1024).to(self.device) #图像和问题都是1024维,通过关注层变成512维。 self.attention = AttentionNet(num_classes=1000, batch_size=batch_size, input_features=1024, output_features=512).to(self.device) #参数组 self.optimizer_parameter_group = [{'params': self.question_model.parameters()}, {'params': self.image_model.parameters()}, {'params': self.attention.parameters()}] pytorh大概有17种损失函数和10种优化器,本例中使用的是损失函数是交叉熵函数和优化器是RMSprop均方根优化器。 self.criterion = nn.CrossEntropyLoss() #交叉熵函数将输入经过 softmax 激活函数之后,再计算其与 target 的交叉熵损失。 self.optimizer = torch.optim.RMSprop(self.optimizer_parameter_group, lr=4e-4, alpha=0.99, eps=1e-8, momentum=0.9) #RMS 是均方根(root meam square)的意思。RMSprop 采用均方根作为分母,可缓解 Adagrad 学习率下降较快的问题,并且引入均方根,可以减少摆动

参考

1,论文链接 1,python 中re.sub,replace(),strip()的区别 2,onehot 3,init.xavier_uniform()的用法 https://blog.csdn.net/luoxuexiong/article/details/95772045 4,快速上手笔记,PyTorch模型训练实用教程(附代码) 5, torch.utils.data.DataLoader

最新回复(0)