PyTorch之ResNet101模型与示例

news2024/9/20 22:29:42

【图书推荐】《PyTorch深度学习与企业级项目实战》-CSDN博客

ResNet101模型

ResNet101是一种深度残差网络,它是ResNet系列中的一种,下面详解ResNet101网络结构。

ResNet101网络结构中有101层,其中第一层是7×7的卷积层,然后是4个阶段(Stage),每个阶段包含若干残差块(Residual Block)。之后是全局平均池化(Global Average Pooling)层以及全连接层(Fully Connected Layer)。全连接层的作用是将全局平均池化层的输出展开成一个向量,并通过一个全连接层将其映射到类别数量的维度上。

ResNet101的每个残差块由两个3×3的卷积层组成,每个卷积层后面都跟有批量归一化(Batch Normalization)和ReLU激活函数。在残差块之间也有批量归一化和ReLU激活函数,但没有卷积层。每个阶段的第一个残差块使用1×1的卷积层将输入的通道数转换为输出的通道数,以便与后续的残差块进行加和操作。

ResNet101模型的主要贡献是引入了残差块的概念,使得网络可以更深,更容易训练。它在ImageNet数据集上的表现非常出色,达到了当时的最优水平。

实战项目代码分析

本项目使用的数据集是一个猴痘病毒分类的数据集,包含猴痘和其他病毒两类样本。数据集划分为训练集和验证集,猴痘类别的图片位于本书配套源码包的monkeypox目录下,其他病毒类别图片位于Others目录下。

本项目借鉴了迁移学习的思想,使用了在 ImageNet 上训练的 ResNet101 网络模型,ImageNet是供计算机视觉识别研究的大型可视化图像数据集,其中包含超过140万手动标注的图像数据,并包含1 000个图像类别,经过预训练的ResNet101网络的全连接层输出1 000个节点,在本实验中为了适应猴痘病数据集类别数,将节点数量由1 000改为2。但是注意需要冻结其他层的参数,防止训练过程中将其进行改动,然后训练微调最后一层即可。该方法既能提高模型的泛化能力和鲁棒性,也能够减少训练的时间,节约算力的开销。

我们知道损失函数是将随机事件或其有关随机变量的取值映射为非负实数,表示该随机事件的风险或损失的函数,在实际任务中则通过最小化损失函数求解和评估模型。本项目使用交叉熵损失表达预测值和真实值的不一致程度,交叉熵损失常用于在图像识别任务中作为损失函数,能够有效地衡量同一个随机变量中的两个不同概率分布的差异程度。

深度学习是以最小化损失函数为目标,其本质上是一种优化问题,目前应用于深度学习的优化算法均是由梯度下降算法发展而来的,其主要思想为利用链式求导法则计算损失函数值相对于神经网络中的每一个权重参数的梯度,通过更新权重参数达到降低损失函数值的效果。本项目使用的优化器Adam算法是一种基于梯度下降的优化算法。Adam算法的优点是收敛速度快,不需要手动调整学习率,兼顾了稳定性和速度。

我们使用PyTorch来搭建猴痘病毒识别模型,完整代码如下:

###############monkeypox.py#######
import torchvision
from torch import nn
import os
import pickle
import torch
from torchvision import transforms, datasets
from tqdm import tqdm
from PIL import Image

import matplotlib.pyplot as plt

epochs = 10
lr = 0.03
batch_size = 32
image_path = './monkeypoxdata'
model_path = './chk/resnet101-cd907fc2.pth'
save_path = './chk/monkeypox_model.pkl'

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 1.数据转换
data_transform = {
    # 训练中的数据增强和归一化
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),  # 随机裁剪
        transforms.RandomHorizontalFlip(),  # 左右翻转
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
        # 均值方差归一化
    ]),
    # 验证集不增强,仅进行归一化
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
}

# 2.形成训练集
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'),
                                     transform=data_transform['train'])

# 3.形成迭代器
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size,
                                           True)

print('using {} images for training.'.format(len(train_dataset)))

# 4.建立分类标签与索引的关系
cloth_list = train_dataset.class_to_idx
class_dict = {}
for key, val in cloth_list.items():
    class_dict[val] = key
with open('class_dict.pk', 'wb') as f:
    pickle.dump(class_dict, f)

# 5.加载ResNet101模型
model = torchvision.models.resnet101(
    weights=torchvision.models.ResNet101_Weights.DEFAULT)
# 加载预训练好的ResNet模型
model.load_state_dict(torch.load(model_path, 'cpu'))
# 冻结模型参数
for param in model.parameters():
    param.requires_grad = False

# 修改最后一层的全连接层
model.fc = nn.Linear(model.fc.in_features, 2)

# 将模型加载到cpu中
model = model.to(device)

criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # 优化器

# 6.模型训练
best_acc = 0  			# 最优精确率
best_model = None  		# 最优模型参数

for epoch in range(epochs):
    model.train()
    running_loss = 0  	# 损失
    epoch_acc = 0  		# 每个epoch的准确率
    epoch_acc_count = 0 	# 每个epoch训练的样本数
    train_count = 0  	# 用于计算总的样本数,方便求准确率
    train_bar = tqdm(train_loader)
    for data in train_bar:
        images, labels = data
        optimizer.zero_grad()
        output = model(images.to(device))
        loss = criterion(output, labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                 epochs,
                                                                 loss)
        # 计算每个epoch正确的个数
        epoch_acc_count += (output.argmax(axis=1) == labels.view(-1)).sum()
        train_count += len(images)

    # 每个epoch对应的准确率
    epoch_acc = epoch_acc_count / train_count

    # 打印信息
    print("【EPOCH: 】%s" % str(epoch + 1))
    print("训练损失为%s" % str(running_loss))
    print("训练精度为%s" % (str(epoch_acc.item() * 100)[:5]) + '%')

    if epoch_acc > best_acc:
        best_acc = epoch_acc
        best_model = model.state_dict()

    # 在训练结束保存最优的模型参数
    if epoch == epochs - 1:
        # 保存模型
        torch.save(best_model, save_path)

print('Finished Training')

# 加载索引与标签映射字典
with open('class_dict.pk', 'rb') as f:
    class_dict = pickle.load(f)

# 数据变换
data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406],
                          [0.229, 0.224, 0.225])])

# 图片路径
img_path = r'./monkeypoxdata/test/test_01.jpg'

# 打开图像
img = Image.open(img_path)

# 对图像进行变换
img = data_transform(img)

plt.imshow(img.permute(1, 2, 0))
plt.show()

# 将图像升维,增加batch_size维度
img = torch.unsqueeze(img, dim=0)

# 获取预测结果
pred = class_dict[model(img).argmax(axis=1).item()]
print('【预测结果分类】:%s' % pred)

运行结果如下:

using 2142 images for training.
train epoch[1/10] loss:0.882: 100%|██████████| 67/67 [05:46<00:00,  5.17s/it]
【EPOCH: 】1
训练损失为38.31112961471081
训练精度为73.90%
train epoch[2/10] loss:0.460: 100%|██████████| 67/67 [06:33<00:00,  5.88s/it]
【EPOCH: 】2
训练损失为37.73484416306019
训练精度为77.40%
train epoch[3/10] loss:0.225: 100%|██████████| 67/67 [06:00<00:00,  5.38s/it]
  0%|          | 0/67 [00:00<?, ?it/s]【EPOCH: 】3
训练损失为31.319448485970497
训练精度为80.06%
train epoch[4/10] loss:0.490: 100%|██████████| 67/67 [06:41<00:00,  6.00s/it]
  0%|          | 0/67 [00:00<?, ?it/s]【EPOCH: 】4
训练损失为36.781765565276146
训练精度为78.94%
train epoch[5/10] loss:0.440: 100%|██████████| 67/67 [06:16<00:00,  5.62s/it]
  0%|          | 0/67 [00:00<?, ?it/s]【EPOCH: 】5
训练损失为29.949161008000374
训练精度为81.93%
train epoch[6/10] loss:0.253: 100%|██████████| 67/67 [06:17<00:00,  5.63s/it]
【EPOCH: 】6
训练损失为27.939718201756477
训练精度为82.63%
train epoch[7/10] loss:0.341: 100%|██████████| 67/67 [06:25<00:00,  5.75s/it]
【EPOCH: 】7
训练损失为29.68729281425476
训练精度为82.77%
train epoch[8/10] loss:0.337: 100%|██████████| 67/67 [06:57<00:00,  6.22s/it]
【EPOCH: 】8
训练损失为28.97513736784458
训练精度为82.77%
train epoch[9/10] loss:0.089: 100%|██████████| 67/67 [06:17<00:00,  5.63s/it]
  0%|          | 0/67 [00:00<?, ?it/s]【EPOCH: 】9
训练损失为26.791129417717457
训练精度为83.05%
train epoch[10/10] loss:0.625: 100%|██████████| 67/67 [06:06<00:00,  5.46s/it]
【EPOCH: 】10
训练损失为33.004408583045006
训练精度为80.85%
Finished Training
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

如图16-3所示,预测结果分类为Monkeypox。

图16-3

这个项目能够使有相关症状的感染者有效地识别出是否为猴痘病,提出了一种改进的基于迁移学习残差网络的图像自动识别方法。该方法使用了ResNet101网络并使用该网络预训练权重进行迁移学习,在猴痘病数据集上进行了网络的训练,可以增加训练轮次,最终达到了比较高的识别准确率。

《PyTorch深度学习与企业级项目实战(人工智能技术丛书)》(宋立桓,宋立林)【摘要 书评 试读】- 京东图书 (jd.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1947468.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Nacos 配置中心配置加载源码分析

前言&#xff1a;上一篇我们分析 Nacos 配置中心服务端源码的时候&#xff0c;多次看到有去读取本地配置文件&#xff0c;那本地配置文件是何时加载的&#xff1f;本篇我们来进行详细分析。 Nacos 系列文章传送门&#xff1a; Nacos 初步认识和 Nacos 部署细节 Nacos 配置管…

https改造-python https 改造

文章目录 前言https改造-python https 改造1.1. https 配置信任库2. 客户端带证书https发送,、服务端关闭主机、ip验证 前言 如果您觉得有用的话&#xff0c;记得给博主点个赞&#xff0c;评论&#xff0c;收藏一键三连啊&#xff0c;写作不易啊^ _ ^。   而且听说点赞的人每…

遗传算法与深度学习实战——进化深度学习

遗传算法与深度学习实战——进化深度学习 0. 前言1. 进化深度学习1.1 进化深度学习简介1.2 进化计算简介 2. 进化深度学习应用场景3. 深度学习优化3.1 优化网络体系结构 4. 通过自动机器学习进行优化4.1 自动机器学习简介4.2 AutoML 工具 5. 进化深度学习应用5.1 模型选择&…

Java给定一些元素随机从中选择一个

文章目录 代码实现java.util.Random类实现随机取数(推荐)java.util.Collections实现(推荐)Java 8 Stream流实现(不推荐) 完整代码参考&#xff08;含测试数据&#xff09; 在Java中&#xff0c;要从给定的数据集合中随机选择一个元素&#xff0c;我们很容易想到可以使用 java.…

【Stable Diffusion】(基础篇四)—— 模型

模型 本系列博客笔记主要参考B站nenly同学的视频教程&#xff0c;传送门&#xff1a;B站第一套系统的AI绘画课&#xff01;零基础学会Stable Diffusion&#xff0c;这绝对是你看过的最容易上手的AI绘画教程 | SD WebUI 保姆级攻略_哔哩哔哩_bilibili 本文主要讲解如何下载和使…

C++【泛型编程】【string类常用接口】学习

目录 泛型编程 推演实例化 显示实例化 类模板 类模板的声明和定义分离 STL string string的构造和拷贝构造 选取特定字符串拷贝 解析&#xff1a; 关于npos的解析 验证 从一个字符串中拷贝前几个字符 解析&#xff1a; 注意&#xff1a; 验证&#xff1a; size…

AI应用行业落地100例 | 移民公司Envoy Global引入AI员工赋能,效率飙升80%,开启服务新篇章

《AI应用行业落地100例》专题汇集了人工智能技术在金融、医疗、教育、制造等多个关键行业中的100个实际应用案例&#xff0c;深入剖析了AI如何助力行业创新、提升效率&#xff0c;并预测了技术发展趋势&#xff0c;旨在为行业决策者和创新者提供宝贵的洞察和启发。 Envoy Globa…

Pytorch使用教学2-Tensor的维度

在PyTorch使用的过程中&#xff0c;维度转换一定少不了。而PyTorch中有多种维度形变的方法&#xff0c;我们该在什么场景下使用什么方法呢&#xff1f; 本小节我们使用的张量如下&#xff1a; # 一维向量 t1 torch.tensor((1, 2)) # 二维向量 t2 torch.tensor([[1, 2, 3], …

【Unity PC端打包exe封装一个并添加安装引导】

Unity PC端打包exe封装一个并添加安装引导 比特虫在线制作ico图标ico图标转换工具 选中打包出来的所有文件和ico图标 右键 使用RAR软件 添加到压缩文件 两个名称要相同 设置完点击确认等待压缩完成 然后就可以使用 Smart Install Maker制作引导安装程序了

Matlab进阶绘图第64期—三维分组针状图

三维分组针状图可以看作是三维分组散点图的升级&#xff0c;能够直观地展示各组分、各元素的位置、对比情况。 由于Matlab中未收录三维分组针状图的绘制函数&#xff0c;因此需要大家自行设法解决。 本文使用自制的groupedstem3小工具进行三维分组针状图的绘制&#xff0c;先…

数据结构之深入理解简单选择排序:原理、实现与示例(C,C++)

文章目录 一、简单选择排序原理二、C/C代码实现总结&#xff1a; 在计算机科学中&#xff0c;排序算法是一种非常基础且重要的算法。简单选择排序&#xff08;Selection Sort&#xff09;作为其中的一种&#xff0c;因其实现简单、易于理解而受到许多初学者的喜爱。本文将详细介…

Maven概述

目录 1.Maven简介 2.Maven开发环境搭建 2.1下载Maven服务器 2.2安装&#xff0c;配置Maven 1.配置本地仓库地址 2.配置阿里云镜像地址 2.3在idea中配置maven 2.4在idea中创建maven项目 3.pom.xml配置 1.项目基本信息 2.依赖信息 3.构建信息 4.Maven命令 5.打包Jav…

华杉研发九学习日记17 正则表达式 异常

华杉研发九学习日记17 一&#xff0c;正则表达式 ^ $ 作用&#xff1a; 测试字符串内的模式(匹配) 例如&#xff0c;可以测试输入字符串&#xff0c;以查看字符串内是否出现电话号码模式或信用卡号码模式。这称为数据验证. 替换文本&#xff08;替换》 可以使用正则表达式来…

知识工程经典语言 PROLOG基本介绍

定义 PROLOG语言是一种基于Horn子句的逻辑型程序设计语言&#xff0c;也是一种陈述性语言。 PROLOG的语句 PROLOG语言仅有三种语句&#xff0c;称为事实、规则和问题。 事实 格式 <谓词名>(<项表>). 其中谓词名是以小写英文字母开头的字母、数字、下划线等组成的…

使用js实现常见的数据结构---链表,队列,栈,树

注&#xff1a;本文只作为数据结构的实现参考和个人理解 链表 链表是由多个节点&#xff08;node&#xff09;连接起来的&#xff0c;每个节点包含了一些存储的数据和指向下一个节点的指针&#xff0c; 链表&#xff1a;多个连续的节点构成&#xff0c;节点&#xff1a;包含一…

spring-boot3.x整合Swagger 3 (OpenAPI 3) +knife4j

1.简介 OpenAPI阶段的Swagger也被称为Swagger 3.0。在Swagger 2.0后&#xff0c;Swagger规范正式更名为OpenAPI规范&#xff0c;并且根据OpenAPI规范的版本号进行了更新。因此&#xff0c;Swagger 3.0对应的就是OpenAPI 3.0版本&#xff0c;它是Swagger在OpenAPI阶段推出的一个…

大数据-47 Redis 缓存过期 淘汰删除策略 LRU LFU 基础概念

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; 目前已经更新到了&#xff1a; Hadoop&#xff08;已更完&#xff09;HDFS&#xff08;已更完&#xff09;MapReduce&#xff08;已更完&am…

试过可道云teamOS的权限管理,才知道团队协作可以这么顺

在快节奏的工作环境中&#xff0c;团队协作的顺畅与否往往决定了项目的成败。作为团队中的一员&#xff0c;我深知权限管理在团队协作中的重要性。 我们的团队在协作过程中总是被权限问题所困扰。文件共享、资料访问、任务分配……每一个环节都需要小心翼翼地处理权限设置&…

学术研讨 | 区块链与隐私计算领域专用硬件研讨会顺利召开

学术研讨 近日&#xff0c;国家区块链技术创新中心主办&#xff0c;长安链开源社区支持的“区块链与隐私计算领域专用硬件研讨会”顺利召开&#xff0c;会议围绕基于区块链与隐私计算的生成式AI上链、硬件加速、软硬协同等主题展开讨论&#xff0c;来自复旦大学、清华大学、北京…

主题公园- 海豹主题式风格餐厅设计【AIGC应用】

业务背景&#xff1a;海洋馆针对细分客群增设一个打卡主题点位&#xff0c;以海豹主题式餐厅为打卡卖点&#xff0c;效果参见海豹主题式风格。 AIGC概念图制作平台&#xff1a;&#xff08;可灵&#xff09; https://klingai.kuaishou.com/ 关键词&#xff1a; 海豹主题餐厅…