度学习pytorch实战六:ResNet50网络图像分类篇自建花数据集图像分类(5类)超详细代码

news2025/2/13 13:03:52

1.数据集简介、训练集与测试集划分
2.模型相关知识
3.model.py——定义ResNet50网络模型
4.train.py——加载数据集并训练,训练集计算损失值loss,测试集计算accuracy,保存训练好的网络参数
5.predict.py——利用训练好的网络参数后,用自己找的图像进行分类测试

一、数据集简介

1.自建数据文件夹

首先确定这次分类种类,采用爬虫、官网数据集和自己拍照的照片获取5类,新建个文件夹data,里面包含5个文件夹,文件夹名字取种类英文,每个文件夹照片数量最好一样多,五百多张以上。如我选了雏菊,蒲公英,玫瑰,向日葵,郁金香5类,如下图,每种类型有600~900张图像。如下图

花数据集下载链接https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
在这里插入图片描述
2.划分训练集与测试集

这是划分数据集代码,同一目录下运,复制改文件夹路径。

import os
from shutil import copy
import random


def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


# 获取 photos 文件夹下除 .txt 文件以外所有文件夹名(即3种分类的类名)
file_path = 'data/flower_photos'
flower_class = [cla for cla in os.listdir(file_path) if ".txt" not in cla]

# 创建 训练集train 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/' + cla)

# 创建 验证集val 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/' + cla)

# 划分比例,训练集 : 验证集 = 9 : 1
split_rate = 0.1

# 遍历3种花的全部图像并按比例分成训练集和验证集
for cla in flower_class:
    cla_path = file_path + '/' + cla + '/'  # 某一类别动作的子目录
    images = os.listdir(cla_path)  # iamges 列表存储了该目录下所有图像的名称
    num = len(images)
    eval_index = random.sample(images, k=int(num * split_rate))  # 从images列表中随机抽取 k 个图像名称
    for index, image in enumerate(images):
        # eval_index 中保存验证集val的图像名称
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)  # 将选中的图像复制到新路径

        # 其余的图像保存在训练集train中
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
    print()

print("processing done!")

二、模型相关知识

之前有文章介绍模型,如果不清楚可以点下链接转过去学习。

深度学习卷积神经网络CNN之ResNet模型网络详解说明(超详细理论篇)

在这里插入图片描述

三、model.py——定义ResNet50网络模型

import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
    注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。
    但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,
    这么做的好处是能够在top1上提升大概0.5%的准确率。
    可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)


def resnext50_32x4d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    groups = 32
    width_per_group = 4
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)


def resnext101_32x8d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    groups = 32
    width_per_group = 8
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)

四、model.py——定义ResNet34网络模型

batch_size = 16
epochs = 5

import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet50


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "zjdata", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=0)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    
    net = resnet50(num_classes=5, include_top=True)
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.1)

    epochs = 5
    best_acc = 0.0
    save_path = './resNet50.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

训练中截图
在这里插入图片描述

五、predict.py——利用训练好的网络参数后,用自己找的图像进行分类测试

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "./1.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = resnet34(num_classes=5).to(device)

    # load model weights
    weights_path = "./resNet50.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # prediction
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/639477.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

为Nomad Web使用添加快捷应用图标

大家好,才是真的好。 上次我们说到Nomad Web已经更新到了1.0.8版本,作为未来的“Notes客户机”(运行在浏览器中),Nomad Web的更新迭代很快。从1.0.5版本开始,就可以直接部署在Domino服务器上,而…

STM32MP157_PRO开发板的第一个驱动程序

文章目录 目的:为什么编译驱动程序之前要先编译内核?编译内核编译设备树编译安装内核模块编译内核模块安装内核模块到 Ubuntu 的NFS目录下备用 安装内核和模块到开发板上编译 led 驱动在开发板安装驱动模块下载驱动程序安装驱动模块 目的: 在…

Netty实战(十三)

WebSocket协议(一) 一、什么是WebSocket 协议二、简单的 WebSocket 程序示例2.1 程序逻辑2.2 添加 WebSocket 支持2.3 处理 HTTP 请求2.4 处理 WebSocket 帧 一、什么是WebSocket 协议 WebSocket 协议是完全重新设计的协议,旨在为 Web 上的双…

读书笔记-《ON JAVA 中文版》-摘要16[第十六章 代码校验]

文章目录 第十六章 代码校验1. 测试1.1 单元测试1.2 JUnit1.3 测试覆盖率的幻觉 2. 前置条件2.1 断言(Assertions)2.2 Java 断言语法2.3 Guava 断言2.4 使用断言进行契约式设计2.4.1 检查指令2.4.2 前置条件2.4.3 后置条件2.4.4 不变性2.4.5 放松 DbC 检…

Frida技术—逆向开发的屠龙刀

简介 Frida是一种基于JavaScript的动态分析工具,可以用于逆向开发、应用程序的安全测试、反欺诈技术等领域。Frida主要用于在已安装的应用程序上运行自己的JavaScript代码,从而进行动态分析、调试、修改等操作,能够绕过应用程序的安全措施&a…

路径规划算法:基于人工电场优化的路径规划算法- 附代码

路径规划算法:基于人工电场优化的路径规划算法- 附代码 文章目录 路径规划算法:基于人工电场优化的路径规划算法- 附代码1.算法原理1.1 环境设定1.2 约束条件1.3 适应度函数 2.算法结果3.MATLAB代码4.参考文献 摘要:本文主要介绍利用智能优化…

【Leetcode60天带刷】day06哈希表——242.有效的字母异位词,349. 两个数组的交集,202题. 快乐数,1. 两数之和

题目:242.有效的字母异位词 Leetcode原题链接:242. 有效的字母异位词 思考历程与知识点: 如果一个字母一个字母的找,也就是暴力,用两个for的话时间复杂度是O(N^2); 我们可以换个思路,a~z一共…

Telerik Report Server R2 2023

Telerik Report Server R2 2023 仪表报告项-使用仪表或类似表盘的显示提供数据的可视化表示。 报告项上的AccessibleRole属性-ARIA(可访问的富Internet应用程序)支持已显著改进。在Web上,当启用了辅助功能时,呈现的报表项包含预定义的辅助功能角色。这样…

(七)矢量数据的空间分析——叠置分析①

矢量数据的空间分析——叠置分析 叠置分析是将代表不同主题的各个数据层面进行叠置,产生一个新的数据层面,叠置结果综合了原来两个或多个层面要素所具有的属性。 叠置分析不仅生成了新的空间关系,而且还将输入的多个数据层的属性联系起来产…

随机的乐趣和游戏

1、猜数字游戏 #GuessingGame.py import random the_number random.randint(1, 10) print("计算机已经在1到10之间随机生成了一个数字,") guess int(input("请你猜猜是哪一个数字: ")) while guess ! the_number:if guess > the_number:p…

【MySQL】数据库基本知识小结

数据库的基本概念 数据库:DataBase 简称 DB,就是信息的集合或者说数据库是由数据库管理系统管理的数据的集合。数据库管理系统:DataBase Management System 简称 DBMS,是一种操纵和管理数据库的大型软件,通常用于建立…

数据结构 一绪论

第一章:绪论 1.1数据结构的基本概念 1.数据:数据是信息的载体,是描述客观事物属性的数、字符以及所有能输入到计算机中并被程序识别 和处理的符号的集合。 2.数据元素:数据元素是数据的基本单位,通常作为一个整体进行…

软件项目质量跟踪控制的3大方法

1、质量度量法 质量度量法包括尺度度量和二元度量两种方法,而尺度度量是定量度量,适用可直接度量的特性。如缺陷率 而二元度量是定性度量,适用间接度量的质量特性。如使用性,灵活性。 软件项目质量跟踪控制的3大方法:质…

ReentrantLock实现原理-公平锁

在ReentrantLock实现原理(1)一节中,我们了解了ReentrantLock非公平锁的获取流程,在本节中我们来看下ReentrantLock公平锁的创建以及锁管理流程 创建ReentrantLock公平锁 创建公平锁代码如下: ReentrantLock reentrantLock new ReentrantL…

elementui 自定义loading动画加载层

elementui 自定义loading动画加载层。main.js中添加 import { Loading } from element-ui /* 自定义加载层 */ Vue.prototype.openLoading function(wer) {const loading Loading.service({lock: true, // 是否锁屏text: , // 加载动画的文字// spinner: inner-circles-loade…

数据库期末复习(8)并发控制

笔记 数据库DBMS并发控制(1)_旅僧的博客-CSDN博客 数据库 并发控制(2)死锁和意向锁_旅僧的博客-CSDN博客 冲突可串行化和锁 怎么判断是否可以进行冲突可串行化:简便的方法是优先图 只有不同对象和同一对象都是读才不能发生非串行化调度。我真傻 两个节点其实也可以算是一…

粤海街道后海村城市更新单元旧改项目,规划建面约42.6万平

项目名称:粤海街道后海村城市更新单元 项目地址:南山区粤海街道,西临天后西路,东至后海大道,南近东滨路,北近创业路。 开发商:远洋集团 申报主体:深圳市蛇口后海实业股份公司 拆…

SEMICON China 2023| 加速科技将携全系新品重磅亮相,欢迎打卡加速科技展台

2023年6月29日-7月1日,全球规模最大、规格最高的半导体行业盛会—SEMICON China 2023将在上海新国际博览中心盛大举行。作为业内领先的半导体测试设备供应商,杭州加速科技将携全系重磅新品及全系列测试解决方案受邀参展。 展位信息:E5馆 5643…

pytest+allure+jenkins持续集成及生成测试报告

目录 前言 一、jenkins安装 二、插件安装 三、构建项目 四、查看运行结果 总结: 前言 前面,讲了“Pycharmpytestallure打造高逼格的测试报告”,但实际工作中,往往需要通过jenkins进行自动化测试用例的持续集成并自动生成测试…

KW 喜报 | KaiwuDB 斩获 2023 数博会“优秀科技成果”奖

5月26日,大数据领域的国家级盛会——2023 中国国际大数据产业博览会(以下简称“2023 数博会”)在贵阳盛大开幕。作为大会最重磅的环节之一,“2023 领先科技成果发布会”于数博发布中心场地举办,向全行业发布 70 余项兼…