【MATLAB第49期】基于MATLAB的深度学习ResNet-18网络不平衡图像数据分类识别模型
一、基本介绍
这篇文章展示了如何使用不平衡训练数据集对图像进行分类,其中每个类的图像数量在类之间不同。两种最流行的解决方案是down-sampling降采样和over-sampling过采样。
在降采样中,每个类别的图像数量减少到所有类别中的最小图像数量。降采样的实现很容易:只需使用splitEachLabel函数并指定类的最小数量,
另一方面,当执行过采样时,每个类别的图像数量增加。这两种策略对于不平衡的数据集都是有效的。然而,过采样需要更复杂的过程。
本篇文章采用过采样平衡数据 。
Label Count
_____________ _____
caesar_salad 13
caprese_salad 8
french_fries 91
greek_salad 12
hamburger 119
hot_dog 16
pizza 150
sashimi 20
sushi 62
过采样结果:
Label Count
_____________ _____
caesar_salad 150
caprese_salad 150
french_fries 150
greek_salad 150
hamburger 150
hot_dog 150
pizza 150
sashimi 150
sushi 150
二、数据情况
食品图像数据集包含九类食物的978张照片(ceaser_salad、caprese_salad,french_fries、greek_saland、汉堡包、hot_dog、披萨、生鱼片和寿司)。
数据集可在下列地址下载
https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip
本文为了提高运行速度 ,选择80%训练, 10%验证,10%测试。
三、代码展示
1.导入数据
imds = imageDatastore('ExampleFoodImageDataset');
2.图像数据展示
numExample=16;
idx = randperm(numel(imds.Files),numExample);
for i=1:numExample
I=readimage(imds,idx(i));
I_tile{i}=insertText(I,[1,1],string(imds.Labels(idx(i))),'FontSize',20);
end
% use imtile function to tile out the example images
I_tile = imtile(I_tile);
figure()
imshow(I_tile);title('examples of the dataset')
3.数据集划分 (训练80%,验证10%,测试10%)
[imdsTrain, imdsValid,imdsTest]=splitEachLabel(imds,0.8,0.1,0.1);
4.选取最大样本数
PerClass是所有类中的最大样本数。
PerClass = max(numObservations);
5.平衡数据
randReplicateFiles是一个仅对文件进行混洗的支持功能。
要选择的图像数量由PerClass定义。从数据库中找到不同类别的图像目录,然后随机复制对应的图像至对应的数量,以平衡类中的图像数量。
files = splitapply(@(x){randReplicateFiles(x,desiredNumObservationsPerClass)},imdsTrain.Files,G);
6.构建网络
加载预先训练的模型,ResNet-18
net = resnet18;
inputSize = net.Layers(1).InputSize;
lgraph = layerGraph(net);
learnableLayer='fc1000';
classLayer='ClassificationLayer_predictions';
7.图像增强
定义图像增强器
pixelRange = [-30 30];
RotationRange = [-30 30];
scaleRange = [0.8 1.2];
imageAugmenter = imageDataAugmenter( ...
'RandXReflection',true, ...
'RandXTranslation',pixelRange, ...
'RandYTranslation',pixelRange, ...
'RandXScale',scaleRange, ...
'RandYScale',scaleRange, ...
'RandRotation',RotationRange ...
);
8.设置网络参数
miniBatchSize = 64;
valFrequency = max(floor(numel(augimdsTest.Files)/miniBatchSize)*10,1);
options = trainingOptions('sgdm', ...
'MiniBatchSize',miniBatchSize, ...
'MaxEpochs',5, ...%30
'InitialLearnRate',1e-2, ...%3e-4
'Shuffle','every-epoch', ...
'ValidationData',augimdsValid, ...
'ValidationFrequency',valFrequency, ...
'Verbose',false, ...
'Plots','training-progress');
9.训练网络
net = trainNetwork(augimdsTrain,lgraph,options);
10.分类评估
[YPred,probs] = classify(net,augimdsTest);
accuracy = mean(YPred == imdsTest.Labels)
YValidation = imdsTest.Labels;
YTrue=imdsTest.Labels;
figure;cm=confusionchart(YTrue,YPred);
%当我运行这个代码时,主要的错误分类是生鱼片和寿司,
%它们看起来很相似。请尝试使用此代码进行过度采样,并希望它对您的工作有所帮助。
四、运行效果
五、代码获取
后台私信回复“49期”,即可获取下载链接。