卷积神经网络|迁移学习-猫狗分类完整代码实现

news2024/11/25 14:43:25

还记得这篇文章吗?迁移学习|代码实现

在这篇文章中,我们知道了在构建模型时,可以借助一些非常有名的模型,这些模型在ImageNet数据集上早已经得到了检验。

同时torchvision模块也提供了预训练好的模型。我们只需稍作修改,便可运用到自己的实际任务中!

我们仍然按照这个步骤开始我们的模型的训练

  • 准备一个可迭代的数据集

  • 定义一个神经网络

  • 将数据集输入到神经网络进行处理

  • 计算损失

  • 通过梯度下降算法更新参数

import torch import torchvisionimport torchvision.transforms as transformsimport torch.nn as nnimport torch.optim as optimimport matplotlib.pyplot as pltfrom torchvision import models

数据集准备

cifar10_train = torchvision.datasets.CIFAR10(    root = 'cifar10/',    train = True,    download = True)cifar10_test=torchvision.datasets.CIFAR10(    root = 'cifar10/',    train = False,    download = True)
transform = transforms.Compose([        transforms.ToTensor(),        transforms.Resize((224,224))    ])

cifar2_train=[(transform(img),[3,5].index(label)) for img,label in cifar10_train if label in [3,5]]
cifar2_test=[(transform(img),[3,5].index(label)) for img,label in cifar10_test if label in [3,5]]
train_loader = torch.utils.data.DataLoader(cifar2_train, batch_size=64,shuffle=True)test_loader = torch.utils.data.DataLoader(cifar2_test, batch_size=64,shuffle=True)

数据集使用CIFAR-10数据集中的猫和狗

CIFAR-10数据集类别

种类       标签

  • plane       0

  • car           1

  • bird         2

  • cat           3

  • deer         4

  • dog          5

  • frog         6

  • horse       7

  • ship         8

  • truck        9

可以看到其中cat和dog的标签分别为3和5

借助:

[3,5].index(label)

我们可以将cat标签变为0dog标签变为1,从而回到二分类问题。

举个例子:

>>> [3,5].index(3)0>>> [3,5].index(5)1

定义模型

参考这篇文章:迁移学习|代码实现

#网络搭建network=models.resnet18(pretrained=True)
for param in network.parameters():    param.requires_grad=False
network.fc=nn.Linear(512,2)#损失函数criterion=nn.CrossEntropyLoss()#优化器optimizer=optim.SGD(network.fc.parameters(),lr=0.01,momentum=0.9)
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")network=network.to(device)

训练模型:

for epoch in range(10):    total_loss = 0    total_correct = 0    for batch in train_loader:   # Get batch        images, labels =batch        images=images.to(device)        labels=labels.to(device)                    optimizer.zero_grad()  #告诉优化器把梯度属性中权重的梯度归零,否则pytorch会累积梯度        preds = network(images)        loss = criterion(preds, labels)        loss.backward()        optimizer.step()                total_loss += loss.item()        _,prelabels=torch.max(preds,dim=1)        total_correct += int((prelabels==labels).sum())    accuracy = total_correct/len(cifar2_train)    print("Epoch:%d  ,  Loss:%f  , Accuracy:%f "%(epoch,total_loss,accuracy))
  • Epoch:0  ,  Loss:78.549439  , Accuracy:0.788900

  • Epoch:1  ,  Loss:77.828066  , Accuracy:0.801500

  • Epoch:2  ,  Loss:66.151785  , Accuracy:0.828100

  • Epoch:3  ,  Loss:76.204446  , Accuracy:0.816800

  • Epoch:4  ,  Loss:68.886606  , Accuracy:0.828100

  • Epoch:5  ,  Loss:71.129405  , Accuracy:0.821200

  • Epoch:6  ,  Loss:66.096364  , Accuracy:0.829900

  • Epoch:7  ,  Loss:65.504227  , Accuracy:0.827700

  • Epoch:8  ,  Loss:76.303878  , Accuracy:0.817100

  • Epoch:9  ,  Loss:70.546953  , Accuracy:0.820700

测试模型:

correct=0total=0network.eval()with torch.no_grad():    for batch in test_loader:        imgs,labels=batch        imgs=imgs.cuda()        labels=labels.cuda()                preds=network(imgs)        _,prelabels=torch.max(preds,dim=1)        #print(prelabels.size())        total=total+labels.size(0)        correct=correct+int((prelabels==labels).sum())    #print(total)    accuracy=correct/total    print("Accuracy: ",accuracy)

Accuracy:  0.8025

这里使用的预训练模型是resnet18,我们也可以使用VGG16模型,同时记得改变最后一个全连接层的输出参数,使得其满足我们自己的任务。

除了预训练模型之外,我们还可以对一些超参数进行调整,使最后的效果变得更好!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1366462.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

外延炉及其相关的小知识

外延炉是一种用于生产半导体材料的设备,其工作原理是在高温高压环境下将半导体材料沉积在衬底上。 硅外延生长,是在具有一定晶向的硅单晶衬底上,生长一层具有和衬底相同晶向的电阻率且厚度不同的晶格结构完整性好的晶体。 外延生长的特点&…

Linux实验——页面置换算法模拟

页面置换算法模拟 【实验目的】 (1)理解虚拟内存管理的原理和技术。 (2)掌握请求分页存储管理的思想。 (3)理解常用页面置换算法的思想。 【实验原理/实验基础知识】 存储器是计算机系统的重要资源之…

腾讯面试总结

腾讯 一面 mysql索引结构?redis持久化策略?zookeeper节点类型说一下;zookeeper选举机制?zookeeper主节点故障,如何重新选举?syn机制?线程池的核心参数;threadlocal的实现&#xff…

揭开JavaScript数据类型的神秘面纱

🧑‍🎓 个人主页:《爱蹦跶的大A阿》 🔥当前正在更新专栏:《VUE》 、《JavaScript保姆级教程》、《krpano》 ​ ​ ✨ 前言 JavaScript作为一门动态类型语言,其数据类型一直是开发者们关注的话题。本文将深入探讨Jav…

C语言算法(二分查找、文件读写)

二分查找 前提条件&#xff1a;数据有序&#xff0c;随机访问 #include <stdio.h>int binary_search(int arr[],int n,int key);int main(void) {}int search(int arr[],int left,int right,int key) {//边界条件if(left > right) return -1;//int mid (left righ…

MidTool的AIGC与NFT的结合-艺术创作和版权保护的革新

在数字艺术和区块链技术的交汇点上&#xff0c;NFT&#xff08;非同质化代币&#xff09;正以其独特的方式重塑艺术品的收藏与交易。将MidTool&#xff08;https://www.aimidtool.com/&#xff09;的AIGC&#xff08;人工智能生成内容&#xff09;创作的图片转为NFT&#xff0c…

数据库基础知识1

目录 数据库的使用 登录mysql 命令语法 常用命令 ​编辑 navicat建立连接 mysql授权管理命令 ​编辑mysql权限 数据导入导出 实例 数据导出 未登录 已经登录 导出导入的代码对比 ​编辑 导入导出的一个坑 python的导入导出 数据库基础知识 特点 需要掌握的程…

嵌入式——循环队列

循环队列 (Circular Queue) 是一种数据结构(或称环形队列、圆形队列)。它类似于普通队列,但是在循环队列中,当队列尾部到达数组的末尾时,它会从数组的开头重新开始。这种数据结构通常用于需要固定大小的队列,例如计算机内存中的缓冲区。循环队列可以通过数组或链表实现,…

使用Docker-compose快速构建Nacos服务

在微服务架构中&#xff0c;服务的注册与发现扮演着至关重要的角色。Nacos&#xff08;Naming and Configuration Service&#xff09;是阿里巴巴开源的服务注册与发现组件&#xff0c;致力于支持动态配置管理和服务发现。最近&#xff0c;一位朋友表达了对搭建一套Nacos开发环…

速卖通店铺销量飙升:掌握自养号测评(补单),轻松提升销售量

很多卖家在经营速卖通店铺时&#xff0c;都希望能提高自己店铺的曝光率。但对于一些新手卖家来说&#xff0c;可能不太清楚曝光率的具体含义以及如何提升。那么&#xff0c;让我们一起来探讨一下这个问题。 曝光率&#xff0c;简而言之&#xff0c;是指您的店铺和产品展示给顾…

springboot git配置文件自动刷新失败问题排查

http://{ip}:{port}/refresh 说明&#xff1a;springBoot版本是1.5.9&#xff0c;接口路径与2.x&#xff0c;不同 路径区别&#xff1a;/refresh VS /actuator/refresh 用postman调用refresh接口刷新git配置&#xff0c;报错如下&#xff0c;没有权限 在服务本地启动&#…

【Java】2023年业务实践中遇到的所有OOM情况及实战总结

OOM分析&实战 引言&#xff1a;一、JVM内存结构二、JVM OOM错误情况三、实践案例一案例二案例三 四、总结五、分析工具推荐六、参考文献 引言&#xff1a; 在Java开发中&#xff0c;随着应用程序变得越来越复杂&#xff0c;内存管理问题也变得愈加重要。而在JVM中的"O…

笔试案例2

文章目录 1、笔试案例22、思维导图 1、笔试案例2 09&#xff09;查询学过「张三」老师授课的同学的信息 selects.*,c.cname,t.tname,sc.score from t_mysql_teacher t, t_mysql_course c, t_mysql_student s, t_mysql_score sc where t.tidc.cid and c.cidsc.cid and sc.sids…

数据结构-测试6

一、判断题 1.若一个栈的输入序列为{1, 2, 3, 4, 5}&#xff0c;则不可能得到{3, 4, 1, 2, 5}这样的出栈序列。&#xff08;T&#xff09; 3比4先进&#xff0c;所以3比4后出&#xff0c;所以不可能得到 2. 在二叉排序树中&#xff0c;每个结点的关键字都比左孩子关键字大&…

【Qt开发】PyQt6--标签控件

标签控件 Qlabel设置标签文本文本的对齐方式为标签设置超链接为标签设置图片获取标签文本 Qlabel QLabel标签控件&#xff0c;用于显示用户不能编辑的文本&#xff0c;主要起提示的作用 设置标签文本 文本的对齐方式 通过这可以设置文本对齐方式 为标签设置超链接 勾选以上…

NGS基因测序(panel)报告解读数据库汇总

今天我们来梳理一下肿瘤基因报告解读常见的数据库&#xff0c;大家有机会可以自己查询并且解读&#xff0c;涉及到的数据库有dbSNP数据库 、gnomAD数据库、ExAC数据库、1000 Genomes、HGMD 数据库、OMIM数据库、ClinVar数据库、InterVar数据库 、ClinGen数据库、GeneReviews数据…

大创项目推荐 深度学习图像风格迁移

文章目录 0 前言1 VGG网络2 风格迁移3 内容损失4 风格损失5 主代码实现6 迁移模型实现7 效果展示8 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 深度学习图像风格迁移 - opencv python 该项目较为新颖&#xff0c;适合作为竞赛课题…

自学编程资源收集

Java&#xff0c;Python&#xff0c;C&#xff0c;JavaScript,SpringBoot&#xff0c;Vue,MySql等各种编程资料收集 mksz712-系统玩转OpenGLAI&#xff0c;实现各种酷炫视频特效mksz709-从0到1训练私有大模型 &#xff0c;企业急迫需求&#xff0c;抢占市场先机~8mksz702-Chat…

布偶猫必囤主食冻干有哪些?三款K9、sc、希喂主食冻干深度测评!

喂养布偶猫的小诀窍&#xff1a;既要满足其食肉习性&#xff0c;又需关注其敏感肠胃。主食冻干是理想选择&#xff0c;它既符合猫咪天然的饮食结构&#xff0c;又采用新鲜生肉为原料。搭配其他营养元素&#xff0c;既美味又营养&#xff0c;还能增强抵抗力。我们将为您测评市场…

IPv6路由协议---IPv6动态路由(RIPng)

IPv6动态路由协议 动态路由协议有自己的路由算法,能够自动适应网络拓扑的变化,适用于具有一定数量三层设备的网络。缺点是配置对用户要求比较高,对系统的要求高于静态路由,并将占用一定的网络资源和系统资源。 路由表和FIB表 路由器转发数据包的关键是路由表和FIB表,每…