解决pytorch模型加载时gpu id的限制
从接触pytorch很长时间来,发现每次调用训练好的模型时,总是被原来训练时使用的第几个gpu限制。如:我训练模型时用的是第3号gpu; 测试模型时,常规调用就会被限制使用第3号gpu才能运行。
为此困扰良久,经过师兄指导后发现可以这么做:
代码示例如下:
net = XXXXXX
net.load_state_dict(torch.load(XXX.pth.gz, map_location="cpu")) # 先将模型放到CPU
net.eval()
net.to("cuda:0") # 然后再使用GPU