目录
背景和需求
一、模型的参数量统计
二、模型检查点大小查看
三、检查点大小和模型参数量之间的关系
总结
背景和需求
一个Pytorch模型的大小可以从两个方面来衡量:检查点大小和模型的参数量。现在我从两个方面都拿到了具体数值,想要验证它们两个是否一致,但在此过程中遇到了一些问题,在此奉上自己的一些观察和思考。
提示:以下是本篇文章正文内容,下面案例可供参考
一、模型的参数量统计
这里参考了博客深度学习模型参数量以及FLOPs计算工具_非晚非晚的博客-CSDN博客_flops计算工具
具体我使用如下代码计算参数的个数(注意是个数,单位为个,而非字节或比特位数!):
with torch.no_grad():
params = sum(p.numel() for p in trainer.gen_a.parameters() if p.requires_grad)
params += sum(p.numel() for p in trainer.gen_b.parameters() if p.requires_grad)
print("learneable params=", params)
得到结果是:
learneable params= 5141804
二、模型检查点大小查看
在Linux服务器上已经得到了一个训练完毕的Pytorch模型,并以.pt的格式保存为一个检查点。
使用ll命令查询gen_00009990.pt文件的字节数:
(这里补充一句,如果使用ll -h命令进行单位的规范化,如得到M、G等单位,其含义应该是字节数,即MB、GB等。)
三、检查点大小和模型参数量之间的关系
使用arch命令可知,我们的Linux系统是64位的:x86_64,因此一个字符使用64位比特的空间来存储。而一个字节占8位比特,因此一个字符使用8个字节来存储。
如上所述,我们的检查点占用了20627569个字节空间,按照如上计算方式可知,这相当于含有20627569 * 8 / 64 = 2578446个参数,约等于我们统计出来的个数5141804的一半!而2578446 * 2 = 5156892,比5141804稍大一些,即20627569 * 8 / 32的计算结果与真实统计量5141804差不多。
这不禁让我想到,虽然我们的Linux系统是64位的,但文件可能是以32位的形式存储的。因为Pytorch中默认浮点数类型为torch.float32,因为2^32非常大,一般深度学习中的可学习参数也不至于到这么大。并且使用file命令也可以得知,检查点文件本质上也是一种Zip压缩文件,因此很有可能为了节省存储空间,系统自动使用32位来存储这个文件,这也与上述观察相符合。
不过我目前并没有找到指令能够验证这个猜想,并且通过检查点计算出来的参数个数之所以比代码统计出来的稍微多一些,是因为检查点中除了保存模型中的各个参数,还保存了一些其他信息,例如字典类型的键、模型的结构信息等等。
总结
虽然Linux系统是64位的,但Pytorch模型的检查点可能是使用32位保存的。