一、ResNet50工具箱安装
(1)下载工具箱
https://ww2.mathworks.cn/matlabcentral/fileexchange/64626-deep-learning-toolbox-model-for-resnet-50-network
(2)在matlab打开下载的resnet50.mlpkginstall文件
(3)使用下面代码进行测试,出现结果说明安装成功
clear
clc
% Access the trained model
net = resnet50();
% See details of the architecture
net.Layers
% Read the image to classify
I = imread('peppers.png');
% Adjust size of the image
sz = net.Layers(1).InputSize;
I = I(1:sz(1),1:sz(2),1:sz(3));
% Classify the image using Resnet-50
label = classify(net, I);
% Show the image and the classification results
figure
imshow(I)
text(10,20,char(label),'Color','white')
二、训练猫狗数据集
(1)数据集下载链接:
https://pan.quark.cn/s/e043408353a5
(2)将数据集按照如下目录进行放置
(3)生成预训练模型
在命令行窗口输入 deepNetworkDesigner(resnet50)
然后点击导出→使用初始参数生成代码
保存生成的网络初始化参数,生成的mlx文件可以叉掉:
修改文件路径,类别数目以及相关参数:
clear
clc
filename = "datasets";
%% 加载用于网络初始化的参数。对于迁移学习,网络初始化参数是初始预训练网络的参数。
trainingSetup = load("resnet-50.mat");
%% 设置图像文件夹路径和标签
nc = 2; %类别
imdsTrain = imageDatastore(filename,"IncludeSubfolders",true,"LabelSource","foldernames");
[imdsTrain, imdsValidation] = splitEachLabel(imdsTrain,0.8); % 80的训练集
%% 调整图像大小以匹配网络输入层。
augimdsTrain = augmentedImageDatastore([224 224 3],imdsTrain);
augimdsValidation = augmentedImageDatastore([224 224 3],imdsValidation);
%% 设置训练选项
opts = trainingOptions("sgdm",...
"ExecutionEnvironment","gpu",...
"InitialLearnRate",0.01,...
"MaxEpochs",20,...
"MiniBatchSize",64,...
"Shuffle","every-epoch",...
"Plots","training-progress",...
"ValidationData",augimdsValidation);
三、训练及测试结果
(1)训练结果
(2)导入一张图片进行测试
clear
clc
load result\net.mat
load result\traininfo.mat
%% 随便选一张进行测试
[file,path] = uigetfile('*.jpg');
if isequal(file,0)
disp('User selected Cancel');
else
filename = fullfile(path,file);
end
I = imread(filename);
I = imresize(I, [224 224]);
[YPred,probs] = classify(net,I);
imshow(I)
label = YPred;
title(string(label) + ", " + num2str(100*max(probs),3) + "%");
四、完整代码获取(链接文末)
MATLAB卷积神经网络——基于ResNet-50进行图像分类
如需绘制混淆矩阵图、输出单类别的准确度等等...,可私聊小编,为你量身进行定制。关注公众号,每日更新更多精彩内容!!!
最后:
如果你想要进一步了解更多的相关知识,可以关注下面公众号联系~会不定期发布相关设计内容包括但不限于如下内容:信号处理、通信仿真、算法设计、matlab appdesigner,gui设计、simulink仿真......希望能帮到你!