解决pytorch模型加载时gpu id的限制

tech2023-01-13  101

解决pytorch模型加载时gpu id的限制

从接触pytorch很长时间来,发现每次调用训练好的模型时,总是被原来训练时使用的第几个gpu限制。如:我训练模型时用的是第3号gpu; 测试模型时,常规调用就会被限制使用第3号gpu才能运行。

为此困扰良久,经过师兄指导后发现可以这么做:

代码示例如下:

net = XXXXXX net.load_state_dict(torch.load(XXX.pth.gz, map_location="cpu")) # 先将模型放到CPU net.eval() net.to("cuda:0") # 然后再使用GPU

 

       
最新回复(0)