微调及代码

news2024/11/15 19:48:47

一、微调:迁移学习(transfer learning)将从源数据集学到的知识迁移到目标数据集

二、步骤

1、在源数据集(例如ImageNet数据集)上预训练神经网络模型,即源模型

2、创建一个新的神经网络模型,即目标模型。这将复制源模型上的所有模型设计及其参数(输出层除外)。

3、向目标模型添加输出层,其输出数是目标数据集中的类别数。然后随机初始化该层的模型参数。

4、在目标数据集(如椅子数据集)上训练目标模型。输出层将从头开始进行训练,而所有其他层的参数将根据源模型的参数进行微调。

5、目标数据集比源数据集小得多时,微调有助于提高模型的泛化能力。就相当于在别人训练好的基础上训练

三、网络架构:神经网络一般分为两块:特征抽取(将原始像素变成容易线性分割的特征)和线性分类器

四、训练

1、微调是在目标数据集上的正常训练任务,但使用更强的正则化(有更强的lr和更少的epoch)

2、源数据集远复杂于目标数据,通常微调效果会更好

3、源数据集中可能也有目标数据中的部分标号

4、固定底层训练高层

五、总结

1、迁移学习将从源数据集中学到的知识迁移到目标数据集,微调是迁移学习的常见技巧。

2、除输出层外,目标模型从源模型中复制所有模型设计及其参数,并根据目标数据集对这些参数进行微调。但是,目标模型的输出层需要从头开始训练。

3、通常,微调参数使用较小的学习率,而从头开始训练输出层可以使用更大的学习率。

六、代码

1、导入数据

train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
# 使用RGB通道的均值和标准差,以标准化每个通道
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize])

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize])

2、定义和初始化模型

pretrained_net = torchvision.models.resnet18(pretrained=True)

finetune_net = torchvision.models.resnet18(pretrained=True)
#最后一层类别为2
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight);

3、微调模型

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,
                      param_group=True):
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")
    if param_group:
        params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
        trainer = torch.optim.SGD([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1920449.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Python实战因果推断】31_双重差分2

目录 Canonical Difference-in-Differences Diff-in-Diff with Outcome Growth Canonical Difference-in-Differences 差分法的基本思想是,通过使用受治疗单位的基线,但应用对照单位的结果(增长)演变,来估算缺失的潜…

PostgreSQL 中如何处理数据的批量更新和事务日志管理?

文章目录 PostgreSQL 中数据的批量更新和事务日志管理 PostgreSQL 中数据的批量更新和事务日志管理 在数据库的世界里,数据的批量更新和事务日志管理就像是一场精心编排的舞蹈,需要精准的步伐和协调的动作。对于 PostgreSQL 而言,这两个方面…

LLM基础模型系列:Fine-Tuning总览

由于对大型语言模型,人工智能从业者经常被问到这样的问题:如何训练自己的数据?回答这个问题远非易事。生成式人工智能的最新进展是由具有许多参数的大规模模型驱动的,而训练这样的模型LLM需要昂贵的硬件(即许多具有大量…

常见 Web漏洞分析与防范研究

前言: 在当今数字化时代,Web应用程序扮演着重要的角色,为我们提供了各种在线服务和功能。然而,这些应用程序往往面临着各种潜在的安全威胁,这些威胁可能会导致敏感信息泄露、系统瘫痪以及其他不良后果。 SQL注入漏洞 …

自主研发接口测试框架

测试任务:将以前完成的所有的脚本统一改写为unitest框架方式 1、需求原型 1.1 框架目录结构 V1.0:一般的设计思路分为配置层、脚本层、数据层、结果层,如下图所示 V 2.0:加入驱动层testdriver 1.2 框架各层需要完成的工作 1、配…

【CT】LeetCode手撕—70. 爬楼梯

目录 题目1- 思路2- 实现⭐70. 爬楼梯——题解思路 3- ACM实现 题目 原题连接&#xff1a;70. 爬楼梯 1- 思路 思路 爬楼梯 ——> 动规五部曲 2- 实现 ⭐70. 爬楼梯——题解思路 class Solution {public int climbStairs(int n) {if(n<1){return 1;}// 1. 定义 dp 数…

html5——CSS基础选择器

目录 标签选择器 类选择器 id选择器 三种选择器优先级 标签指定式选择器 包含选择器 群组选择器 通配符选择器 Emmet语法&#xff08;扩展补充&#xff09; 标签选择器 HTML标签作为标签选择器的名称&#xff1a; <h1>…<h6>、<p>、<img/> 语…

数据平滑处理(部分)

一、 移动平均&#xff08;Moving Average&#xff09; 是一种最简单的数据平滑方法&#xff0c;用于平滑时间序列数据。它通过计算一定窗口内数据点的平均值来减少噪音&#xff0c;同时保留数据的趋势。移动平均包括简单移动平均&#xff08;SMA&#xff09;或指数加权移动平均…

初始网络知识

前言&#x1f440;~ 上一章我们介绍了使用java代码操作文件&#xff0c;今天我们来聊聊网络的一些基础知识点&#xff0c;以便后续更深入的了解网络 网络 局域网&#xff08;LAN&#xff09; 广域网&#xff08;WAN&#xff09; 路由器 交换机 网络通信基础 IP地址 端…

可观察性优势:掌握当代编程技术

反馈循环是我们开发人员工作的关键。它们为我们提供信息&#xff0c;并让我们从用户过去和现在的行为中学习。这意味着我们可以根据过去的反应进行主动开发。 TestComplete 是一款自动化UI测试工具&#xff0c;这款工具目前在全球范围内被广泛应用于进行桌面、移动和Web应用的…

“闭门造车”之多模态思路浅谈:自回归学习与生成

©PaperWeekly 原创 作者 | 苏剑林 单位 | 科学空间 研究方向 | NLP、神经网络 这篇文章我们继续来闭门造车&#xff0c;分享一下笔者最近对多模态学习的一些新理解。 在前文《“闭门造车”之多模态思路浅谈&#xff1a;无损》中&#xff0c;我们强调了无损输入对于理想的…

压缩文件的解析方式

我们常用的压缩文件有两种&#xff1a;后缀为.zip或者.rar&#xff0c;接下来将介绍解析两种压缩文件的代码。需要用到三个jar包&#xff1a;commons-io-2.16.1.jar、junrar-7.5.5.jar、slf4j-api-2.0.13.jar&#xff0c;可以在官网下载&#xff0c;也可以发私信。 这段代码是一…

2.GAP:通用访问协议

GAP的简单理解 GAP这个名字&#xff0c;直接翻译过来不好理解。 简单点可以理解为&#xff1a; 这是蓝牙设备在互联之前&#xff0c;过程中&#xff0c;第一个用于交流的协议。在代码上&#xff0c;会给这个协议实现&#xff0c;连接参数的设置&#xff0c;连接事件的实现&am…

【算法】二叉树-迭代法实现前后中序遍历

递归的实现就是:每一次递归调用都会把函数的局部变量&#xff0c;参数值和返回地址等压入调用栈中&#xff0c;然后递归返回的时候&#xff0c;从栈顶弹出上一次递归的各项参数&#xff0c;这就是递归为什么可以返回上一层位置的原因 可以用栈实现二叉树的前中后序遍历 1. 前序…

【数学趣】拉窗帘模型之求面积引发的6个解法

抖音上推了一个趣题 题 求橙色部分的面积 蓝色部分是2个正方形。大的正方形边长为6。&#xff08;小的正方形一半被一个黄色三角形遮住了一半&#xff09; 答案 18 解法1&#xff1a;拉窗帘 先写一个代号&#xff0c;方便证明&#xff0c;H G 代表正方形。&#xff08;G…

AV1 编码标准中帧内预测技术详细说明

AV1 编码标准帧内预测 AV1&#xff08;AOMedia Video 1&#xff09;是一种开源的视频编码格式&#xff0c;旨在提供比现有标准更高的压缩效率和更好的视频质量。在帧内预测方面&#xff0c;AV1相较于其前身VP9和其他编解码标准&#xff0c;如H.264/AVC和H.265/HEVC&#xff0c;…

暑假第一次作业

第一步&#xff1a;给R1,R2,R3,R4配IP [R1-GigabitEthernet0/0/0]ip address 192.168.1.1 24 [R1-Serial4/0/0]ip address 15.0.0.1 24 [R2-GigabitEthernet0/0/0]ip address 192.168.2.1 24 [R2-Serial4/0/0]ip address 25.0.0.1 24 [R3-GigabitEthernet0/0/0]ip address 192.…

【Mutilism用74ls192和与非门设计3进制24进制加法计数器2荔枝】2022-5-10

缘由【数电 数字逻辑】如何用74ls192和与非门设计任意进制加法计数器&#xff1f;-嵌入式-CSDN问答

Qt学生管理系统(付源码)

Qt学生管理系统 一、前言1.1 项目介绍1.2 项目目标 2、需求说明2.1 功能性说明2.2 非功能性说明 三、UX设计3.1 登录界面3.2 学生数据展示3.3 信息插入和更新 三、架构说明3.1 客户端结构如下3.2 数据流程图3.2.1 数据管理3.2.2 管理员登录 四、 设计说明3.1 数据库设计3.2 结构…

基于Python+Flask+MySQL的新冠疫情可视化系统

基于PythonFlaskMySQL的新冠疫情可视化系统 FlaskMySQL 基于PythonFlaskMySQL的新冠疫情可视化系统 项目主要依赖前端&#xff1a;layui&#xff0c;Echart&#xff0c;后端主要是Flask&#xff0c;系统的主要支持登录注册&#xff0c;Ecahrt构建可视化图&#xff0c;可更换主…