pytorch使用Apex混合精度加速训练

tech2023-06-09  111

Apex官网:https://nvidia.github.io/apex/amp.html 这篇博客讲的非常好 PyTorch必备神器 | 唯快不破:基于Apex的混合精度加速

1.安装

使用pip安装后会出错

TypeError: Class advice impossible in Python3. Use the @Implementer class decorator instead.

解决方法:

$ pip uninstall apex $ git clone https://www.github.com/nvidia/apex $ cd apex $ python setup.py install

2.使用

核心代码:

from apex import amp model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # “欧一”,不是“零一” with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward()

例子:

原始训练代码:

import torch ngpu=2 def traiin(): model = torch.nn.Linear(D_in, D_out).cuda() model = torch.nn.DataParallel(model, device_ids=[i for i in range(ngpu)]) optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) for img, label in dataloader: out = model(img.half()) loss = LOSS(out, label) loss.backward() optimizer.step() optimizer.zero_grad() #此时采用全精度32位来训练

半精度训练:

import torch ngpu=2 def traiin(): model = torch.nn.Linear(D_in, D_out).cuda().half() model = torch.nn.DataParallel(model, device_ids=[i for i in range(ngpu)]) optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) for img, label in dataloader: out = model(img.half()) loss = LOSS(out, label) loss.backward() optimizer.step() optimizer.zero_grad() #此时采用半精度16位来训练

显存基本可以降低为原来的一半,但训练速度降低,可能原因是,CUDNN只支持float32加速,半精度后,将不能加速

混合精度训练:

import torch ngpu=2 def train(): model = torch.nn.Linear(D_in, D_out).cuda() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) #设置混合精度模式为O1(欧1,不是零1,后面会解释各个模式区别) model, optimizer = amp.initialize(model, optimizer, opt_level="O1") model = torch.nn.DataParallel(model, device_ids=[i for i in range(ngpu)]) for img, label in dataloader: out = model(img) loss = LOSS(out, label) #将loss进行缩放,防止溢出 with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() optimizer.step() optimizer.zero_grad() def save_model(self, epoch): if self.mixed_precision: import apex.amp as amp amp_state_dict = amp.state_dict() else: amp_state_dict = None checkpoint = { 'epoch': epoch, 'params': self.params, 'model': self.model.module.state_dict() if self.ngpu > 1 else self.model.state_dict(), 'optimizer': self.optimizer.state_dict(), 'amp': amp_state_dict } torch.save(checkpoint, os.path.join(self.expdir,'model.pt')) def load_model(self, checkpoint): state_dict = torch.load(checkpoint) self.model.load_state_dict(state_dict['model']) if self.mixed_precision: import apex.amp as amp amp.load_state_dict(state_dict['amp'])

注意: 1.模型在amp.initialize前必须加载到GPU上。 2.amp.initialize前不能对模型进行任何分布式操作,如torch.nn.DataParallel必须放在之后。

opt_level解释O0纯 FP32 训练,可以作为 accuracy 的 baselineO1混合精度训练(推荐使用),根据黑白名单自动决定使用 FP16(GEMM, 卷积)还是 FP32(Softmax)进行计算O2几乎FP16混合精度训练,不存在黑白名单,除了 Batch Norm,几乎FP16 计算O3纯 FP16 训练,很不稳定,但是可以作为 speed 的 baseline

参考: PyTorch必备神器 | 唯快不破:基于Apex的混合精度加速 [] Apex [官网] Apex混合精度加速 [码农网]

最新回复(0)