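Deep Learning with PyTorch in Practice, Part 5: Image Classification with ResNet34 Transfer Learning on a Self-Built Flower Dataset (5 Classes), with Detailed Code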

news | 2024/11/13 9:40:03

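1. Dataset overview
2. Background on the model
3. split_data.py: splitting the data into training and validation (test) sets
4. model.py: defining the ResNet34 network
5. train.py: loading the data and training; the loss is computed on the training set, the accuracy on the validation set, and the trained weights are saved
6. predict.py: classifying images you collected yourself with the trained weights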

I. Dataset Overview

1. Building the data folder

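First decide which categories to classify. I collected images for 5 classes from a web crawler, a public dataset, and my own photos. Create a folder named data containing 5 subfolders, each named after its class in English. The split script below reads them from data/flower_photos. Each folder should ideally hold roughly the same number of images, at least five hundred or so. For example, I chose daisy, dandelion, rose, sunflower, and tulip, with 600 to 900 images per class, as shown in the figure below.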

Flower dataset download link: https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
[Figure: the class folders of the flower dataset]
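You can check the "same number of images per class" advice before splitting. The sketch below is my own addition and not part of the original post. `count_images_per_class` and `imbalance_ratio` are hypothetical helpers, and the list of accepted extensions is an assumption. The first helper counts the image files in each class folder; the second reports how far apart the largest and smallest classes are.

```python
import os

# Assumed set of image extensions; extend it if your photos use other formats.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images_per_class(root):
    """Return {class_name: number_of_images} for every subfolder of root."""
    counts = {}
    for name in sorted(os.listdir(root)):
        class_dir = os.path.join(root, name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files such as LICENSE.txt
        counts[name] = sum(
            1 for f in os.listdir(class_dir)
            if os.path.splitext(f)[1].lower() in IMAGE_EXTS
        )
    return counts


def imbalance_ratio(counts):
    """Largest class size divided by smallest non-empty one; 1.0 means perfectly balanced."""
    sizes = [n for n in counts.values() if n > 0]
    return max(sizes) / min(sizes)
```

Usage: `count_images_per_class("data/flower_photos")` prints nothing by itself. It returns a dict you can print. A ratio far above 1 suggests collecting more images for the smallest class.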
2. Splitting the training and validation sets

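Below is the dataset-splitting code (split_data.py). Run it from the directory that contains the data folder. After copying it, change the folder paths to match your own.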

import os
from shutil import copy
import random


def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


# Get all folder names under flower_photos, i.e. the 5 class names (files such as LICENSE.txt are skipped)
file_path = 'data/flower_photos'
flower_class = [cla for cla in os.listdir(file_path) if os.path.isdir(os.path.join(file_path, cla))]

# Create the training-set folder train, with one subfolder per class
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/' + cla)

# Create the validation-set folder val, with one subfolder per class
mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/' + cla)

# Split ratio, training set : validation set = 9 : 1 (split_rate is the validation fraction)
split_rate = 0.1

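# Go through all images of each class and split them into training and validation sets by the ratio
for cla in flower_class:
    cla_path = file_path + '/' + cla + '/'  # subfolder of the current class
    images = os.listdir(cla_path)  # the images list holds the names of all images in this folder
    num = len(images)
    eval_index = set(random.sample(images, k=int(num * split_rate)))  # randomly pick k image names (a set for fast lookup)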
    for index, image in enumerate(images):
        # eval_index holds the names of the validation (val) images
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)  # copy the selected image to the new path

        # the remaining images go into the training set train
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
    print()

print("processing done!")
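The script above has two practical pitfalls. It draws from the global random state and depends on the order `os.listdir` returns, so every run produces a different split. It also never empties flower_data, so rerunning it without deleting that folder first can leave the same image in both train and val.

Below is a minimal sketch of a reproducible alternative. It is my own, not the author's script, and `split_filenames` is a hypothetical helper. It performs the same 9:1 split as a pure function with a fixed seed.

```python
import random


def split_filenames(filenames, val_rate=0.1, seed=0):
    """Return (train, val) name lists; val holds int(len * val_rate) names."""
    names = sorted(filenames)        # sort so the result does not depend on os.listdir order
    rng = random.Random(seed)        # private RNG: leaves the global random state untouched
    val = set(rng.sample(names, k=int(len(names) * val_rate)))
    train = [n for n in names if n not in val]
    return train, sorted(val)
```

In split_data.py you would call `split_filenames(os.listdir(cla_path), split_rate)` for each class and copy the two lists into flower_data/train/<class> and flower_data/val/<class>.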

II. Background on the Model

I introduced this model in an earlier article. If anything is unclear, follow the link below to review it.

Deep learning CNNs: a detailed explanation of the ResNet network (super-detailed theory)
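As a compact reminder, here is what one residual BasicBlock computes, written to match `BasicBlock.forward` in model.py below. σ is ReLU, * is convolution, and BN is batch normalization. W_1 and W_2 are the two 3x3 convolutions, and W_s is the 1x1 convolution in the `downsample` branch. This is my own summary, not a formula from the original post.

```latex
\begin{aligned}
\mathcal{F}(x) &= \mathrm{BN}\big(W_2 * \sigma(\mathrm{BN}(W_1 * x))\big) \\
y &= \sigma\big(\mathcal{F}(x) + x\big) && \text{(identity shortcut, solid line: downsample is None)} \\
y &= \sigma\big(\mathcal{F}(x) + \mathrm{BN}(W_s * x)\big) && \text{(projection shortcut, dashed line)}
\end{aligned}
```

The projection shortcut is only needed when the block changes the spatial size (stride 2) or the number of channels. Otherwise the two tensors could not be added element-wise.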


III. model.py: Defining the ResNet34 Network

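Again, the original model code is given here for you to copy as-is; no parameters need to be changed. The file contains the ResNet34, ResNet50, and ResNet101 models, plus ResNeXt50 and ResNeXt101.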

import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
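    Note: in the original paper, on the main branch of the dashed-line (downsampling) residual block, the first 1x1 conv layer has stride 2 and the second 3x3 conv layer has stride 1.
    In the official PyTorch implementation, however, the first 1x1 conv layer has stride 1 and the second 3x3 conv layer has stride 2,
    which improves top-1 accuracy by roughly 0.5%.
    See ResNet v1.5: https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch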
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)


def resnext50_32x4d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    groups = 32
    width_per_group = 4
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)


def resnext101_32x8d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    groups = 32
    width_per_group = 8
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)
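The following sketch shows how the pieces of `ResNet` fit together. It is plain Python with no PyTorch needed, and `conv_out`, `trace_resnet`, and `weighted_layers` are my own hypothetical helpers. It traces a 224x224 input through the stem and the four stages using the standard output-size formula floor((in + 2*padding - kernel) / stride) + 1. It also reuses the `downsample` condition from `_make_layer` and counts the weighted layers that give ResNet34/50/101 their names.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a conv/pool layer (floor division, as in PyTorch)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_resnet(blocks_num, expansion, size=224):
    """Return [(stage, out_channels, spatial_size, first_block_needs_downsample), ...]."""
    size = conv_out(size, 7, 2, 3)   # conv1: 7x7, stride 2, padding 3
    size = conv_out(size, 3, 2, 1)   # maxpool: 3x3, stride 2, padding 1
    in_channel = 64
    rows = []
    for i, (channel, _n) in enumerate(zip([64, 128, 256, 512], blocks_num)):
        stride = 1 if i == 0 else 2  # layer1 keeps the size, layer2-4 halve it
        # same condition as ResNet._make_layer
        downsample = stride != 1 or in_channel != channel * expansion
        size = conv_out(size, 3, stride, 1)  # the strided 3x3 conv of the first block
        in_channel = channel * expansion
        rows.append(("layer%d" % (i + 1), in_channel, size, downsample))
    return rows


def weighted_layers(blocks_num, convs_per_block):
    """conv1 + all convs inside the blocks + fc (1x1 shortcut convs are not counted)."""
    return 1 + convs_per_block * sum(blocks_num) + 1
```

For ResNet34, `trace_resnet([3, 4, 6, 3], expansion=1)` gives spatial sizes 56, 28, 14, 7, and only layer2 to layer4 need a downsample branch. For ResNet50 (`expansion=4`), layer1 needs one too, because 64 input channels do not match 64 * 4 = 256. The last stage has 512 * expansion channels, which is exactly the `in_features` of `self.fc`.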

IV. train.py: Training, Computing the Loss and Accuracy, and Saving the Trained Weights

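Step 1: download the pretrained weights in advance. Open the link below in a browser to download the file directly. Put it in the same project folder and rename it, because the new name is used later.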

ResNet34 weights: https://download.pytorch.org/models/resnet34-333f7ec4.pth

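Step 2: in train.py, change three lines. Line 71 sets the number of classes, line 63 holds the name of the weights file you downloaded, and line 83 holds the name of the final saved weights file: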

net.fc = nn.Linear(in_channel, 5)  # 5 = number of classes; change it to match your dataset
model_weight_path = "./resnet34-pre.pth"
save_path = './resNext34.pth'
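train.py sets `requires_grad = False` on every pretrained weight before it replaces `fc`, so the new fc layer is the only part that gets trained. The arithmetic check below shows how small that layer is. It is plain Python, and `linear_params` is a hypothetical helper. It relies on the fact that ResNet34's fc input is 512 * BasicBlock.expansion = 512.

```python
def linear_params(in_features, out_features, bias=True):
    """Number of parameters in nn.Linear(in_features, out_features)."""
    return in_features * out_features + (out_features if bias else 0)


print(linear_params(512, 5))     # new head for 5 flower classes -> 2565
print(linear_params(512, 1000))  # original ImageNet head -> 513000
```

Because only about 2.5 thousand weights are updated, each epoch is cheap, which makes CPU training, as in this post, practical.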

Other parameters: batch_size = 16 (choose 32, 64, etc. depending on your CPU or GPU)
Learning rate: 0.01
Epochs: 5
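batch_size also determines two other quantities in train.py: how many optimizer steps one epoch takes (`train_steps = len(train_loader)`) and how many DataLoader workers are started. Both calculations are sketched below. The helpers are hypothetical, and the image counts in the test are only examples.

```python
import math
import os


def steps_per_epoch(num_images, batch_size):
    """len(train_loader): DataLoader's default drop_last=False keeps the last partial batch."""
    return math.ceil(num_images / batch_size)


def num_workers(batch_size, cpu_count=None, cap=8):
    """Same rule as train.py: min(CPU cores, batch_size if batch_size > 1 else 0, 8)."""
    cpu_count = cpu_count or os.cpu_count() or 1  # os.cpu_count() may return None
    return min(cpu_count, batch_size if batch_size > 1 else 0, cap)
```

For example, 3,306 training images with batch_size = 16 give 207 steps per epoch. Doubling batch_size roughly halves the step count.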

import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet34,resnet101


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "zjdata", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflowers':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    for param in net.parameters():
        param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.01)

    epochs = 5
    best_acc = 0.0
    save_path = './resNext34.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
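The validation bookkeeping in train.py works in three steps. It takes the argmax of each row of logits, compares the result with the labels, and divides the number of correct predictions by `val_num`. It then overwrites `save_path` only when the accuracy improves. The plain-Python sketch below mirrors that logic without PyTorch; the function names are hypothetical.

```python
def argmax(row):
    """Index of the largest value (first one on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(logits, labels):
    """Fraction of rows whose argmax equals the label, like acc / val_num in train.py."""
    correct = sum(argmax(row) == y for row, y in zip(logits, labels))
    return correct / len(labels)


def best_epochs(val_accs):
    """1-based epochs at which train.py would save weights (strict '>', so ties don't save)."""
    best, saved = 0.0, []
    for epoch, acc in enumerate(val_accs, start=1):
        if acc > best:
            best = acc
            saved.append(epoch)
    return saved
```

The saved file therefore always holds the weights from the best validation epoch, not necessarily from the last one.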

Screenshot at the start of training (I trained on a CPU):
[Figure: training progress screenshot]

V. predict.py: Classifying Your Own Images with the Trained Weights

Make sure the image path and the weights file name match your own files.
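predict.py preprocesses your image exactly like the "val" transform in train.py: Resize(256), CenterCrop(224), ToTensor, then Normalize. The sketch below reproduces the arithmetic in plain Python. It is based on my reading of torchvision's behaviour for an integer size: the shorter edge becomes 256 and the aspect ratio is kept. Check it against your torchvision version. The helper names are hypothetical.

```python
def resize_shorter_edge(w, h, size=256):
    """(new_w, new_h) after Resize(size) with an int size: shorter edge -> size."""
    if w <= h:
        return size, int(size * h / w)
    return int(size * w / h), size


def center_crop_box(w, h, crop=224):
    """(left, top, right, bottom) of the centred crop."""
    left = int(round((w - crop) / 2.0))
    top = int(round((h - crop) / 2.0))
    return left, top, left + crop, top + crop


MEAN = [0.485, 0.456, 0.406]  # ImageNet statistics used in both scripts
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """ToTensor maps 0-255 to [0, 1]; Normalize then computes (x - mean) / std per channel."""
    return [((v / 255.0) - m) / s for v, m, s in zip(rgb, MEAN, STD)]
```

For a 500x375 photo, the resize gives 341x256, and the 224x224 crop starts at (58, 16). Only the centre of a wide image is seen, so keep the flower roughly centred in your test image.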

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "./1.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path).convert('RGB')  # force 3 channels (e.g. PNG with alpha, grayscale images)
    plt.imshow(img)
    # [C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = resnet34(num_classes=5).to(device)

    # load model weights
    weights_path = "./resNext34.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # prediction
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
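The last step of predict.py turns the 5 logits into probabilities with softmax, picks the most likely class, and looks up its name in class_indices.json. The sketch below redoes this in plain Python. The class names assume the stock flower_photos folder names, and the helper names are hypothetical. It also shows why predict.py indexes `class_indict[str(predict_cla)]`: train.py writes the dict with int keys, but JSON object keys always come back as strings.

```python
import json
import math


def softmax(logits):
    m = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Round-trip through JSON, as train.py (write) and predict.py (read) do: int keys become strings.
class_indict = json.loads(json.dumps({0: "daisy", 1: "dandelion", 2: "roses",
                                      3: "sunflowers", 4: "tulips"}))


def top1(logits, class_indict):
    """Return (class_name, probability) of the most likely class."""
    probs = softmax(logits)
    idx = max(range(len(probs)), key=probs.__getitem__)
    return class_indict[str(idx)], probs[idx]
```

Softmax does not change which class wins; it only rescales the scores so that the printed probabilities sum to 1.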

Screenshot of the prediction result:
[Figure: prediction result screenshot]

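This article was contributed by an Internet user. The views expressed are the author's own and do not represent this site's position. This site only provides information storage space, does not own the rights, and bears no related legal responsibility. If you repost it, please credit the source: http://www.coloradmin.cn/o/638985.html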

