用torchvision.utils.save_image()保存图片时出现异常
有些像素点会显示为全黑(灰度图),如下图所示,第一张和第三张图
刚开始以为是图像数据分布范围的问题,在保存之前输出图像tensor的最大max和最小min值,出现了 -0.0x和1.0x的数值,说明图像的像素范围超出了0-1。
读源码
可是通过读utils.save_image()的源码发现,就算超出0-1也不应该出现这种问题,源码中存在如下部分代码
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
grid可以理解为图片张量,这段代码
- 首先将 grid 张量中的每个元素乘以255。这一步将原来在0到1范围内的图像数据转换到0到255的范围内
- 对 grid 张量中的每个元素加上0.5。这一步可能是为了进行亮度调整或将值偏移至正数范围内
- 将 grid 张量中的每个元素限制在0到255的范围内。小于0的值将被设置为0,大于255的值将被设置为255
- 后面的不重要
源码在将所有像素乘255之后,已经将数据每个像素范围限制在了0-255之间。
问题解决
经过查看其他成功的代码和源码中的注释发现。大多在使用 torchvision.utils.save_image时直接将4Dtensor图片和保存路径传入给 save_image()函数就行,不会出现问题。
且utils.save_image接收四维tensor ,B C H W
如源码所示
而我在保存之前进行了降维处理,降成了三维(squeeze(0)是降维)
于是删掉后面的squeeze(0),问题解决
结果如图所示