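Contents
- Load a pretrained model
- Download the model weights
- Image preprocessing
- Open the image to classify
- Pass the image through the pipeline
- Run the model
- Download the label file
- Open imagenet_classes.txt
- Prediction result
- Top 5 most likely classes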
Load a pretrained model
pip3 install pillow
>>> from torchvision import models
>>> dir(models)
['AlexNet', 'DenseNet', 'Inception3', 'ResNet', 'SqueezeNet', 'VGG', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'alexnet', 'densenet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'inception', 'inception_v3', 'resnet', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'squeezenet', 'squeezenet1_0', 'squeezenet1_1', 'vgg', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn']
If you get this error:
ImportError: cannot import name 'PILLOW_VERSION' from 'PIL' (/usr/lib64/python3.11/site-packages/PIL/__init__.py)
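Fix: in the file named in the traceback, replace PILLOW_VERSION with __version__.
Newer Pillow releases no longer provide PILLOW_VERSION (it was first removed in Pillow 7.0.0); use __version__ instead.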
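Instead of simply renaming the import, the offending line can be patched so it works whether or not the constant exists. A minimal sketch of such a shim; where it goes depends on the file named in your traceback.

```python
# PILLOW_VERSION is missing from newer Pillow releases; fall back to __version__.
try:
    from PIL import PILLOW_VERSION  # older Pillow
except ImportError:
    from PIL import __version__ as PILLOW_VERSION  # newer Pillow

print("Pillow", PILLOW_VERSION)
```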
Download the model weights
>>> resnet=models.resnet101(pretrained=True)
Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /home/spx/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
100.0%
>>> resnet
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): Bottleneck(
....
Image preprocessing
>>> from torchvision import transforms
>>> preprocess = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
...     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
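1. transforms.Compose chains the transform steps into a single pipeline, a quick way to define the basic preprocessing functions.
2. Resize(256) scales the input image so that its shorter side is 256 pixels (the aspect ratio is kept), then CenterCrop(224) cuts out a 224×224-pixel region around the center.
3. ToTensor converts the image into a tensor (channels first, values scaled to [0, 1]), and Normalize normalizes each RGB channel with the given mean and standard deviation.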
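To see what these steps do to the 1920x1200 photo used in this post, here is a pure-Python sketch of the size arithmetic and the per-pixel normalization. It follows torchvision's documented behaviour (shorter side to 256, centered 224 crop, (x - mean) / std) but is not torchvision's own code; the helper names are hypothetical.

```python
def resize_shorter_side(width, height, size=256):
    """Like Resize(size) with an int: scale the shorter side to `size`, keep aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """(left, top, right, bottom) of a centered crop x crop box."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(value, mean, std):
    """ToTensor maps 0..255 to 0..1, then Normalize computes (x - mean) / std."""
    return (value / 255.0 - mean) / std


MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

w, h = resize_shorter_side(1920, 1200)  # the 1920x1200 photo from this post
print("resized:", (w, h))
print("crop box:", center_crop_box(w, h))
print("white pixel, normalized:",
      [round(normalize_pixel(255, m, s), 4) for m, s in zip(MEAN, STD)])
```

The 224×224 crop therefore keeps the middle of the image and discards most of the left and right edges of a wide photo.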
Open the image to classify
>>> from PIL import Image
>>> img=Image.open('/home/spx/learn/pic/1.jpg')
>>> img
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1920x1200 at 0x7F342D54FF50>
>>> img.show()
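Image.open keeps the file's own mode, for example RGBA for a PNG with transparency or L for a grayscale JPEG, while the Normalize step above expects exactly three channels. A minimal sketch of a guard; to_rgb is a hypothetical helper, and the in-memory images stand in for real files.

```python
from PIL import Image


def to_rgb(img):
    """Return an RGB version of img; ImageNet models expect exactly 3 channels."""
    return img if img.mode == "RGB" else img.convert("RGB")


# In-memory stand-ins for files that Image.open might return (hypothetical examples).
rgba = Image.new("RGBA", (64, 48), (255, 0, 0, 128))  # e.g. a PNG with transparency
gray = Image.new("L", (64, 48), 128)                  # e.g. a grayscale JPEG

for original in (rgba, gray):
    rgb = to_rgb(original)
    print(original.mode, "->", rgb.mode, len(rgb.getbands()), "bands")
```

With a real file the same idea is a one-liner: Image.open(path).convert("RGB").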
Pass the image through the pipeline
>>> img_t=preprocess(img)
>>> import torch
>>> batch_t=torch.unsqueeze(img_t,0)
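The model expects a batch with shape (N, C, H, W), which is why unsqueeze adds a leading dimension of size 1. A sketch that uses a random tensor as a stand-in for preprocess(img):

```python
import torch

img_t = torch.rand(3, 224, 224)      # stand-in for preprocess(img): C x H x W
batch_t = torch.unsqueeze(img_t, 0)  # add a batch dimension at position 0
print(img_t.shape, "->", batch_t.shape)

# Equivalent spellings of the same operation.
same = torch.equal(batch_t, img_t.unsqueeze(0)) and torch.equal(batch_t, img_t[None])

# Several preprocessed images can be batched at once with torch.stack.
batch_3 = torch.stack([img_t, img_t, img_t])
print(batch_3.shape)
```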
Run the model
>>> resnet.eval()
>>> out = resnet(batch_t)
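resnet.eval() puts the model in inference mode, so BatchNorm layers use their running statistics instead of statistics of the current batch. Wrapping the forward pass in torch.no_grad() additionally skips gradient bookkeeping, which saves memory. A sketch with a tiny hypothetical stand-in model (same input and output shapes as ResNet-101) so it runs without downloading any weights:

```python
import torch
from torch import nn

# Hypothetical tiny stand-in for resnet: (N, 3, 224, 224) in, (N, 1000) logits out.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 1000),
)
batch_t = torch.rand(1, 3, 224, 224)

model.eval()           # inference mode: BatchNorm uses its running statistics
with torch.no_grad():  # no autograd graph is built for this forward pass
    out = model(batch_t)
    out_again = model(batch_t)

print(out.shape, out.requires_grad)
```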
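Download the label file
Find imagenet_classes.txt on https://image-net.org/ and download it. This is the label file: line i holds the class name for output index i of the model.
Alternatively, open one of these copies:
https://gitee.com/lonerlin/classification/blob/master/imagenet_classes.txt
https://github.com/ethereon/caffe-tensorflow/blob/master/examples/imagenet/imagenet-classes.txt
Copy its contents, create a new file named imagenet_classes.txt, and paste them in.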
Open imagenet_classes.txt
>>> with open('/home/spx/learn/pic/imagenet_classes.txt') as f:
... labels=[line.strip() for line in f.readlines()]
...
>>> labels
['tench, Tinca tinca', 'goldfish, Carassius auratus', 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', 'tiger shark, Galeocerdo cuvieri', 'hammerhead, hammerhead shark', 'electric ray, crampfish, numbfish, torpedo', 'stingray', 'cock', 'hen', 'ostrich, Struthio camelus', 'brambling, Fringilla montifringilla', 'goldfinch, Carduelis carduelis', 'house finch, linnet, Carpodacus mexicanus', 'junco, snowbird', 'indigo bunting, indigo f
....
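Output index i of the model corresponds to line i of the file, so a file with missing or extra lines silently shifts every label. A minimal sketch of a loader that checks the count; load_labels is a hypothetical helper, demonstrated on a generated file in a temporary directory.

```python
import os
import tempfile


def load_labels(path, expected=1000):
    """Read one label per line (blank lines skipped); line i names output index i."""
    with open(path, encoding="utf-8") as f:
        labels = [line.strip() for line in f if line.strip()]
    if len(labels) != expected:
        raise ValueError(f"expected {expected} labels, got {len(labels)}")
    return labels


# Demo with a hypothetical 1000-line file.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "imagenet_classes.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(f"class_{i}" for i in range(1000)) + "\n")
    labels = load_labels(path)

print(len(labels), labels[0], labels[-1])
```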
Prediction result
>>> _,index=torch.max(out,1)
>>> percentage=torch.nn.functional.softmax(out,dim=1)[0]*100
>>> labels[index[0]]
'hog, pig, grunter, squealer, Sus scrofa'
>>> percentage[index[0]].item()
60.67759323120117
The top prediction, pig, is correct (about 60.7%).
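torch.max(out, 1) returns the largest value and its index along the class dimension, and softmax turns the raw scores (logits) into probabilities that sum to 1, multiplied by 100 here to get percentages. Because softmax preserves order, the highest logit is also the highest percentage, which is why the index taken from out can be used to read percentage. A pure-Python sketch of the same arithmetic on three hypothetical logits:

```python
import math


def softmax_percent(logits):
    """Percentages: exp(x_i - max) / sum_j exp(x_j - max) * 100."""
    m = max(logits)  # subtracting the max avoids overflow and does not change the result
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total * 100 for e in exps]


logits = [2.0, 1.0, 0.1]  # hypothetical scores for three classes
percentage = softmax_percent(logits)
index = max(range(len(logits)), key=logits.__getitem__)  # the index torch.max(out, 1) returns
print(index, [round(p, 2) for p in percentage])
```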
Top 5 most likely classes
>>> _,indices=torch.sort(out,descending=True)
>>> [(labels[idx],percentage[idx].item()) for idx in indices[0][:5]]
[('hog, pig, grunter, squealer, Sus scrofa', 60.67759323120117), ('weasel', 9.75589656829834), ('guinea pig, Cavia cobaya', 8.112009048461914), ('black-footed ferret, ferret, Mustela nigripes', 5.257884979248047), ('polecat, fitch, foulmart, foumart, Mustela putorius', 4.569345474243164)]
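torch.sort orders all 1000 scores, while torch.topk returns only the k largest values and their indices, sorted largest first. Sorting out (the logits) or percentage gives the same order because softmax is monotonic. A sketch with hypothetical labels and percentages for six classes:

```python
import torch

# Hypothetical labels and percentages for six classes (a real run has 1000).
labels = ["class_a", "class_b", "class_c", "class_d", "class_e", "class_f"]
percentage = torch.tensor([5.0, 60.0, 10.0, 1.0, 8.0, 16.0])

values, indices = torch.topk(percentage, k=5)  # largest first
top5 = [(labels[i], v) for i, v in zip(indices.tolist(), values.tolist())]
print(top5)
```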