加载模型时,模型也不大,GPU内存也完全够,但就是出现这个CUDA内存溢出问题。
究其原因,在于model.load_state_dict(torch.load(‘pretrain-model.pth’, map_location=device))这个代码省略了map_location=device
通过torch.load加载预训练模型pretrain-model.pth,map_location=device 是一个参数,用于指定模型参数加载到哪个设备上
如果map_location=device 不指定,PyTorch 会根据以下规则自动决定将模型加载到哪里
这就很容易出现内存不足的情况