pytorch cv自带预训练模型再微调

news2024/11/24 3:10:06

参考:
https://pytorch.org/docs/0.4.1/torchvision/models.html
https://zhuanlan.zhihu.com/p/436574436
https://blog.csdn.net/u014297502/article/details/125884141
在这里插入图片描述

Network	Top-1 error	Top-5 error
AlexNet	43.45	20.91
VGG-11	30.98	11.37
VGG-13	30.07	10.75
VGG-16	28.41	9.62
VGG-19	27.62	9.12
VGG-11 with batch normalization	29.62	10.19
VGG-13 with batch normalization	28.45	9.63
VGG-16 with batch normalization	26.63	8.50
VGG-19 with batch normalization	25.76	8.15
ResNet-18	30.24	10.92
ResNet-34	26.70	8.58
ResNet-50	23.85	7.13
ResNet-101	22.63	6.44
ResNet-152	21.69	5.94
SqueezeNet 1.0	41.90	19.58
SqueezeNet 1.1	41.81	19.38
Densenet-121	25.35	7.83
Densenet-169	24.00	7.00
Densenet-201	22.80	6.43
Densenet-161	22.35	6.20
Inception v3	22.55	6.44

自带预训练模型对手写数字微调

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

# 定义数据转换和加载 MNIST 数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
# 加载 torchvision 中的预训练模型
model = torchvision.models.resnet18(pretrained=True)

#下面两行冻结模型的所有参数;不加这两行就是全量训练
for param in model.parameters():
    param.requires_grad = False
    
#修改模型的最后一层/分类器,用于适应特定任务的类别数量
num_features = model.fc.in_features
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 修改输入通道数为1  (因为resnet18输入图像维度3,MNIST图片维度是1)
model.fc = nn.Linear(num_features, 10)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
##训练
epochs = 10
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        # print(inputs.shape)
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 200 == 199:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0

### 验证
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %.2f %%' % (100 * correct / total))

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/761149.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

动态规划完全背包之518零钱兑换 II

题目: 给你一个整数数组 coins 表示不同面额的硬币,另给一个整数 amount 表示总金额。 请你计算并返回可以凑成总金额的硬币组合数。如果任何硬币组合都无法凑出总金额,返回 0 。 假设每一种面额的硬币有无限个。 题目数据保证结果符合 …

【Kafka中间件】ubuntu 部署kafka,实现Django生产和消费

原文作者:我辈李想 版权声明:文章原创,转载时请务必加上原文超链接、作者信息和本声明。 文章目录 前言一、Kafka安装1.下载并安装Java2.下载和解压 Kafka3.配置 Kafka4.启动 Kafka5.创建主题和生产者/消费者6.发布和订阅消息 二、KafkaDjang…

红黑树底层原理【白话版】

一、红黑树——特殊的平衡二叉搜索树 定义:红黑树是一种特殊的平衡二叉搜索树。我们用它来排列数据,并方便以后快速检索数据。 估计看到这句话,你就崩溃了,因为这话说了等于没说。 先观察这个图。 球要不是黑色,要不…

console的奇妙用法

console的奇妙用法 console.log是调试 JavaScript 代码的最佳方法之一。但是本文将介绍几个与console交互的更好方法。 在vscode或者的其他ide中输入console可以看到里边提供了非常多的方法。 虽然我们通常都是用console.log,但是使用其他可以使调试过程变得更加容…

分布式链路追踪

文章目录 1、背景2、微服务架构下的问题3、链路追踪4、核心概念5、技术选型对比6、zipkin 1、背景 随着互联网业务快速扩展,软件架构也日益变得复杂,为了适应海量用户高并发请求,系统中越来越多的组件开始走向分布式化,如单体架构…

流水灯——FPGA

文章目录 前言一、流水灯介绍二、系统设计1.模块框图2.RTL视图 三、源码四、效果五、总结六、参考资料 前言 环境: 1、Quartus18.0 2、vscode 3、板子型号:EP4CE6F17C8 要求: 每隔0.2s循环亮起LED灯 一、流水灯介绍 从LED0开始亮起到LED3又回…

如何定制自己的应用层协议?|面向字节流|字节流如何解决黏包问题?如何将字节流分成数据报?

前言 那么这里博主先安利一些干货满满的专栏了! 首先是博主的高质量博客的汇总,这个专栏里面的博客,都是博主最最用心写的一部分,干货满满,希望对大家有帮助。 高质量干货博客汇总https://blog.csdn.net/yu_cblog/c…

基于ssm的社区生活超市的设计与实现

博主介绍:专注于Java技术领域和毕业项目实战。专注于计算机毕设开发、定制、文档编写指导等,对软件开发具有浓厚的兴趣,工作之余喜欢钻研技术,关注IT技术的发展趋势,感谢大家的关注与支持。 技术交流和部署相关看文章…

SpringBoot拦截器

一、SpringBoot拦截器介绍 Spring Boot中的拦截器是一种用于在处理请求之前或之后执行特定操作的组件。拦截器通常用于实现对请求进行预处理、日志记录、权限验证等功能。 在Spring Boot中,可以使用HandlerInterceptor接口来定义自己的拦截器,并通过配…

流水灯实现

文章目录 一、流水灯二、代码实现三、引脚分配 一、流水灯 流水灯指的是LED像水流一样点亮,即LED依次点亮但不立刻熄灭,等到4个LED都点亮后,再把所有灯一次性熄灭。 二、代码实现 module horse_led(input wire clk,input wire rst_n,output…

记录管理系统

简单的记录管理系统,适用于保存表格数据,可以用来替代Excel软件保存数据,提供可视化拖动组件用于自定义数据列,数据存到数据库,相比于Excel,更易保存,易搜索。 例如创建合同记录数据&#xff0…

【电子学会】2023年05月图形化四级 -- 计算圆的面积和周长

计算圆的面积和周长 编写程序计算圆的面积和周长。输入圆的半径,程序计算出圆的面积和周长,圆的面积等于3.14*半径*半径;圆的周长等于2*3.14*半径。 1. 准备工作 (1)保留舞台中的小猫角色和白色背景; 2…

MySQL数据表高级操作

一、克隆/复制数据表二、清空表,删除表内的所有数据删除小结 三、创建临时表四、MySQL中6种常见的约束1、外键的定义2、创建外键约束作用3、创建主表test44、创建从表test55、为主表test4添加一个主键约束。主键名建议以"PK_”开头。6、为从表test5表添加外键&…

Html利用Canvas绘制图形

今天接到粉丝私信,询问是否可以通过Canvas绘制一些图形,然后根据粉丝提供的模板图,通过Canvas进行模拟绘制,通过分析发现,图形虽然相对简单,但是如果不借助相应的软件,纯代码绘制还是稍微费些时…

机器学习:self supervised learning- Recent Advances in pre-trained language models

背景 Autoregressive Langeuage Models 不完整的句子,预测剩下的空的词语 sentence completion Transformer-based ALMs Masked language models-MLMs 预训练模型能将输入文本转成hidden feature representation 模型参数最开始是从预训练模型中拿到&#xf…

如何快速制作一个奶茶店小程序商城

如果你是一个奶茶店的老板,你可能会考虑开设一个小程序商城来增加销售渠道和提升品牌形象。那么,如何快速制作一个奶茶店小程序商城呢?下面我们将介绍一个简单的步骤供你参考。 首先,你需要登录乔拓云平台进入商城后台管理页面。在…

数据结构真题

数据结构真题 1. A. Bills of Paradise 线段树并查集四个操作: D x。标记大于等于 x 的第一个未标记的 a i a_i ai​;若没有,则不操作.F x。查询大于等于 x 的第一个未标记的 a i a_i ai​;若没有,则输出 1 0 12…

《UNUX环境高级编程》(9)进程关系

1、前言 2、终端登录 在早期的UNIX系统,用户用哑终端(用硬连接到主机)进行登录,因为连接到主机上的终端设备数是固定的,所以同时登录数也就有了已知的上限。 随着位映射图像终端的出现,开发出了窗口系统&…

数学分析:对偶映射

这个其实就是我们一致讨论的对偶映射,换了个马甲,差点认不出来了。本来是V->R 要变成U->R,就需要一个反向的V*->U*的映射。 注意这个式子,t属于U,phit转到了V,但是坐标也发生了变化,这…

2023西南赛区ciscn -- do you like read

Attack 打开后一个商城页面 在login as admin那里有个登录页面,账号admin,密码爆破即可得到admin123 也可以在book.php?bookisbn1进行sql注入得到密码,这里发现是没有注入waf的 登录进来是一个Book List的管理页面,同时在审计源…