2.1 数据加载 数据的组织比较简单,按照以下格式组织:
# 2.1 Data loading. The data layout is simple:
#
#     data/
#       images/ 1.jpg 2.jpg ...
#       labels/ 1.txt 2.txt ...
#
# Dataset class for loading this layout; __getitem__ returns the image and
# its corresponding keypoint coordinates.
class KeyPointDatasets(Dataset):
    """Dataset yielding ``(image_tensor, keypoint_tensor)`` pairs.

    Each label file holds the point count on its first line, then one
    ``x y`` pair per line with coordinates normalized to [0, 1].
    Only the FIRST keypoint is returned, matching the model's 2-value output.
    """

    def __init__(self, root_dir="./data", transforms=None):
        super(KeyPointDatasets, self).__init__()
        self.img_path = os.path.join(root_dir, "images")
        # self.txt_path = os.path.join(root_dir, "labels")
        self.img_list = glob.glob(os.path.join(self.img_path, "*.jpg"))
        # Label files mirror the image paths: images/1.jpg -> labels/1.txt
        self.txt_list = [item.replace(".jpg", ".txt").replace("images", "labels")
                         for item in self.img_list]
        # BUG FIX: the original assigned self.transforms only when transforms
        # was not None, so __getitem__ raised AttributeError otherwise.
        self.transforms = transforms

    def __getitem__(self, index):
        img = cv2.imread(self.img_list[index])
        if self.transforms:
            img = self.transforms(img)

        label = []
        with open(self.txt_list[index], "r") as f:
            for i, line in enumerate(f):
                if i == 0:
                    # First line: number of keypoints in the file.
                    num_point = int(line.strip())
                else:
                    x1, y1 = [t.strip() for t in line.split()]
                    # Coordinates range from 0 to 1.
                    label.append((float(x1), float(y1)))
        return img, torch.tensor(label[0])

    def __len__(self):
        return len(self.img_list)

    @staticmethod
    def collect_fn(batch):
        """Stack a list of (img, label) samples into batched tensors."""
        imgs, labels = zip(*batch)
        return torch.stack(imgs, 0), torch.stack(labels, 0)
2.2 网络模型 import torch import torch.nn as nn
class KeyPointModel(nn.Module):
    """Tiny CNN regressing one normalized (x, y) keypoint.

    Structure: conv + pooling + conv + pooling + global pooling + Linear,
    returning a tensor of length 2 per sample (sigmoid keeps it in (0, 1)).

    NOTE(review): the surrounding text says "global average pooling", but
    the code uses AdaptiveMaxPool2d; original behavior is kept here.
    """

    def __init__(self):
        # BUG FIX: the garbled `init` is restored to `__init__`; without it
        # the submodules were never registered.
        super(KeyPointModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(6)
        self.relu1 = nn.ReLU(True)
        self.maxpool1 = nn.MaxPool2d((2, 2))

        self.conv2 = nn.Conv2d(6, 12, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(12)
        self.relu2 = nn.ReLU(True)
        self.maxpool2 = nn.MaxPool2d((2, 2))
        # Collapses each 12-channel feature map to a single value.
        self.gap = nn.AdaptiveMaxPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(12, 2),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x = self.gap(x)
        # Flatten (N, 12, 1, 1) -> (N, 12) before the linear head.
        x = x.view(x.shape[0], -1)
        return self.classifier(x)
# 2.3 Training
def train(model, epoch, dataloader, optimizer, criterion):
    """Run one training epoch, logging the loss every 4 steps."""
    model.train()
    for itr, (image, label) in enumerate(dataloader):
        bs = image.shape[0]

        # Forward pass and loss.
        output = model(image)
        loss = criterion(output, label)

        # Standard backward/update step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Periodic console + visdom logging.
        if itr % 4 == 0:
            print("epoch:%2d|step:%04d|loss:%.6f" % (epoch, itr, loss.item() / bs))
            vis.plot_many_stack({"train_loss": loss.item() * 100 / bs})


total_epoch = 300
bs = 10

########################################
# Preprocessing: resize to 360x480, then normalize with dataset statistics.
transforms_all = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((360, 480)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4372, 0.4372, 0.4373],
                         std=[0.2479, 0.2475, 0.2485]),
])
datasets = KeyPointDatasets(root_dir="./data", transforms=transforms_all)

data_loader = DataLoader(datasets, shuffle=True, batch_size=bs,
                         collate_fn=datasets.collect_fn)

model = KeyPointModel()

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Either Smooth L1 loss or MSE loss works for the loss term.
criterion = torch.nn.MSELoss()
# Decay the learning rate by 10x every 30 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# BUG FIX: ensure the checkpoint directory exists before torch.save below.
os.makedirs("weights", exist_ok=True)

for epoch in range(total_epoch):
    train(model, epoch, data_loader, optimizer, criterion)
    loss = test(model, epoch, data_loader, criterion)
    # BUG FIX: the original created the StepLR scheduler but never called
    # scheduler.step(), so the learning rate never actually decayed.
    scheduler.step()

    if epoch % 10 == 0:
        torch.save(model.state_dict(),
                   "weights/epoch_%d_%.3f.pt" % (epoch, loss * 1000))
MSE Loss.