[PyTorch][chapter 52][迁移学习]

news2024/10/6 22:33:59

前言:

     迁移学习(Transfer Learning)是一种机器学习方法,它通过将一个领域中的知识和经验迁移到另一个相关领域中,来加速和改进新领域的学习和解决问题的能力。

      这里面主要结合前面ResNet18 例子,详细讲解一下迁移学习的流程


一  简介

    

迁移学习可以通过以下几种方式实现:

1.1 基于预训练模型的迁移:

      将已经在大规模数据集上预训练好的模型(如BERT、GPT等)作为一个通用的特征提取器,然后在新领域的任务上进行微调。

1.2  网络结构迁移:

将在一个领域中训练好的模型的网络结构应用到另一个领域中,并在此基础上进行微调。

1.3  特征迁移:

     将在一个领域中训练好的某些特征应用到另一个领域中,并在此基础上进行微调。

     word2vec

1.4 参数迁移:

       将在一个领域中训练好的模型的参数应用到另一个领域中,并在此基础上进行微调。

本文主要例子用的是 参数迁移


二  Flatten

    作用:

     输入的向量x [batch, c, w, h]=>[batch, c*w*h]

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 15:11:35 2023

@author: chengxf2
"""

import torch
from torch import optim,nn

class Flatten(nn.Module):
    
    def __init__(self):
        
        super(Flatten,self).__init__()
        
    
    def forward(self, x):
        
        a = torch.tensor(x.shape[1:])
        #dim 中 input 张量的每一行的乘积。
        shape = torch.prod(a).item()
        #print("\n ---new shape--- ",shape)
        return x.view(-1,shape)

三 迁移学习

   torchvision 已经提供好了一些分类器 resnet18,resnet152, 利用其训练好的参数,把最后的分类类型更改掉。

   from torchvision.models import resnet152
  from torchvision.models import resnet18

   注意:

          现有分类器分类的类型 > = 新分类器类型,再做transfer.

才能取得好的效果.

         

分类器分类类型
已有分类器[猫,狗,鸡,鸭】
新分类器[猫,狗]

     

   

 

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 14:56:35 2023

@author: chengxf2
"""

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:38:18 2023

@author: chengxf2
"""

import torch
from torch import optim,nn
import visdom
from torch.utils.data import DataLoader
from PokeDataset import Pokemon
from torchvision.models import resnet152
from torchvision.models import resnet18

from util import Flatten

batchNum = 32
lr = 1e-3
epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

root ='pokemon'
resize =224

csvfile ='data.csv'
train_db = Pokemon(root, resize, 'train',csvfile)
val_db = Pokemon(root, resize, 'val',csvfile)
test_db = Pokemon(root, resize, 'test',csvfile)

train_loader = DataLoader(train_db, batch_size =batchNum,shuffle= True,num_workers=4)
val_loader = DataLoader(val_db, batch_size =batchNum,shuffle= True,num_workers=2)
test_loader = DataLoader(test_db, batch_size =batchNum,shuffle= True,num_workers=2)
viz = visdom.Visdom()

def evalute(model, loader):
    
    total =len(loader.dataset)
    correct =0
    for x,y in loader:
        
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    
    acc = correct/total
    
    return acc   
        
        

def main():
    
    trained_model = resnet152(pretrained=True)
    
    model = nn.Sequential(*list(trained_model.children())[:-1],
        Flatten(),
        nn.Linear(in_features=2048, out_features=5))
    
   
    
    optimizer = optim.Adam(model.parameters(),lr =lr) 
    criteon = nn.CrossEntropyLoss()
    
    best_epoch=0,
    best_acc=0
    viz.line([0],[-1],win='train_loss',opts =dict(title='train loss'))
    viz.line([0],[-1],win='val_loss',  opts =dict(title='val_acc'))
    global_step =0
    
    
  
    for epoch in range(epochs):
        print("\n --main---: ",epoch)
        for step, (x,y) in enumerate(train_loader):
            #x:[b,3,224,224] y:[b]

             x = x.to(device)
             y = y.to(device)
             #print("\n --x---: ",x.shape)
             
             logits =model(x)
             loss = criteon(logits, y)
             #print("\n --loss---: ",loss.shape)
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
             
             viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
             global_step +=1
             
        if epoch %2 ==0:
            
             val_acc = evalute(model, val_loader)
             
             if val_acc>best_acc:
                 best_acc = val_acc
                 best_epoch =epoch
                 torch.save(model.state_dict(),'best.mdl')
             print("\n val_acc ",val_acc)
             viz.line([val_acc],[global_step],win='val_loss',update='append')
             
    print('\n best acc',best_acc, "best_epoch: ",best_epoch)
    
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt')
    
    test_acc = evalute(model, test_loader)
    print('\n test acc',test_acc)
                 

if __name__ == "__main__":
    
    main()


参考:
https://blog.csdn.net/qq_44089890/article/details/130460700

课时107 迁移学习实战_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/885565.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何用chatGPT回答女朋友的死亡问题

引言 在爱情的迷雾中,女友的死结问题常常令人犯难。然而,借助ChatGPT的智慧,或许能够开辟一片全新的天地。其独到的见解和无限的可能性,或将为我们的情感困惑带来意想不到的解答。正如科技在塑造我们的生活,也或许能为…

家居行业,瞄准软文营销

对于很多家居品牌来说,传统营销形式越来越无法打动他们,在以渠道为王的环境下,家居品牌更需要思考地是:要带给消费者新的价值和体验究竟是什么? 因为在目前的大环境下,在内容信息上展现的生动性、直接性上、…

Spark SQL优化:NOT IN子查询优化解决

背景 有如下的数据查询场景。 SELECT a,b,c,d,e,f FROM xxx.BBBB WHERE dt ${zdt.addDay(0).format(yyyy-MM-dd)} AND predict_type not IN ( SELECT distinct a FROM xxx.AAAAAWHERE dt ${zdt.addDay(0).format(yyyy-MM-dd)} ) 分析 通过查看SQL语句的执行计划基本…

全面管控企业资产,这个小妙招做到了!

随着全球金融市场的复杂性增加以及资产多样化的趋势,确保资产的最大价值和最小风险已经成为一个迫切的需求。资产管理系统能够更好地掌握自己的财务状况,做出明智的决策。 从个人投资者到大型企业,都可以从中受益,无论是实时监控投…

腾讯云GPU服务器GN7实例NVIDIA T4 GPU卡

腾讯云GPU服务器GN7实例搭载1颗 NVIDIA T4 GPU,8核32G配置,系统盘为100G 高性能云硬盘,自带5M公网带宽,系统镜像可选Linux和Windows,地域可选广州/上海/北京/新加坡/南京/重庆/成都/首尔/中国香港/德国/东京/曼谷/硅谷…

C#如何打包EXE程序生成setup安装文件

项目结束之后,有需要将winForm程序打包成.exe文件提供给用户。 这里记录一下打包过程。 1:首先获取打包插件,如果你的VS已经安装,忽略此步骤。 点击 工具->扩展和更新,选择联机,搜索installer&#x…

AntPro 模版代码

1 ProTable 实现选择和反显 页面1 页面2 选择页面2选择之后反显到页面1 &#xff0c;且支持跨页选择。同时下次进来页面1展示的数据要反显到页面2被选中。 页面2代码 <ProTablerowKey"id"columns{columns}request{(params: any) > {const newParams {pageI…

androidstudio引入jar包

如图&#xff0c;选择project 然后在src下创建一个lib文件夹&#xff0c;将要添加到项目中的jar包粘贴lib里面&#xff0c;如图 接着选中jar包&#xff0c;右击&#xff0c;会出一个弹窗&#xff0c;选择Add As Library…&#xff0c;如图 会出现一个弹窗&#xff0c;点击OK…

安防监控视频云存储平台EasyCVRH.265转码功能更新:新增分辨率配置

安防视频集中存储EasyCVR视频监控综合管理平台可以根据不同的场景需求&#xff0c;让平台在内网、专网、VPN、广域网、互联网等各种环境下进行音视频的采集、接入与多端分发。在视频能力上&#xff0c;视频云存储平台EasyCVR可实现视频实时直播、云端录像、视频云存储、视频存储…

<kernel>kernel 6.4 USB-之-hub_event()分析

&#xff1c;kernel&#xff1e;kernel 6.4 USB-之-hub_event()分析 本文是基于linux kernel 6.4版本内核分析&#xff1b;源码下载路径&#xff1a;linux kernel 本文主要分析hub_event()函数的内容&#xff1b;hub_event()函数是Linux内核USB子系统中的一个函数&#xff0c…

面试之ReentrantLock

一&#xff0c;ReentrantLock 1.ReentrantLock是什么&#xff1f; ReentrantLock实现了Lock接口&#xff0c;是一个可重入且独占式的锁&#xff0c;和Synchronized关键字类似&#xff0c;不过ReentrantLock更灵活&#xff0c;更强大&#xff0c;增加了轮询、超时、中断、公平锁…

侯捷 C++ part2 兼谈对象模型笔记——7 reference、const、new/delete

7 reference、const、new/delete 7.1 reference x 是整数&#xff0c;占4字节&#xff1b;p 是指针占4字节&#xff08;32位&#xff09;&#xff1b;r 代表x&#xff0c;那么r也是整数&#xff0c;占4字节 int x 0; int* p &x; // 地址和指针是互通的 int& r x;…

windows电脑简单实时tts语音播报wsay;python pyttsx3语言实时播报text-to-speech;微软edge-tts 音色自然离线不实时

1、wsay 参考&#xff1a; https://github.com/p-groarke/wsay 下载安装&#xff1a; https://github.com/p-groarke/wsay/releases/tag/v1.5.0 下载exe文件&#xff0c;并把加入环境变量就可 使用 # Say something. wsay "Hello there."wsay "你好"…

图书馆管理系统、学生管理系统、交通管理系统(C语言、数据结构、java、Javaweb)

图书馆管理系统作为一个经典的项目&#xff0c;在国家、学校、等每个地方或者作为期末作品都用的非常广泛&#xff1a; C语言程序设计&#xff1a;图书馆管理系统含说明文档。 大一时C综合设计&#xff0c;当时得了96。代码纯原创&#xff0c;可直接运行&#xff0c;包含详细注…

springboot多数据源配置,看这一篇就够了

springboot下多数据源配置实现 不管是两个mysql&#xff0c;还是一个mysql一个oracle&#xff0c;都是一样的操作 目录 springboot下多数据源配置实现配置application.yml文件数据源配置类创建mapper接口创建mapper的xml配置文件 你可能会遇到的问题 配置application.yml文件 …

无涯教程-Perl - study函数

描述 此功能需要花费额外的时间来研究EXPR,以改善在EXPR上执行的正则表达式的性能。如果省略EXPR,则使用$_。实际的速度增益可能非常小,具体取决于您希望搜索字符串的次数。 您一次只能学习一种表达式或标量。 语法 以下是此函数的简单语法- study EXPRstudy返回值 此函数…

Scala 如何调试隐式转换--隐式转换代码的显示展示

方法1 在需要隐式转换的地方&#xff0c;把需要的参数显示的写出。 略方法2&#xff0c;查看编译代码 在terminal中 利用 scalac -Xprint:typer xxx.scala方法打印添加了隐式值的代码示例。 对于复杂的工程来说&#xff0c;直接跑到terminal执行 scalac -Xprint:typer xxx.…

薅羊毛!我用这款AI工具,免费拿下12张漫画头像

今天l1m0_将为大家分享一款AI生图工具&#xff0c;并介绍如何通过Pixso AI&#xff0c;用自己的照片&#xff0c;免费一键生成AI漫画头像&#xff0c;一起来看看吧。 这里我用Pixso资源社区的一组用户头像资源&#xff0c;为大家演示&#xff0c;如何快速生成AI漫画头像。 首先…

小程序开发:如何选择合适的开发工具和平台?

小程序是一种基于微信平台的轻量级应用程序&#xff0c;具有操作简便、体验流畅等优点。然而&#xff0c;对于许多中小企业来说&#xff0c;三五万的开发成本可能过高&#xff0c;让人感到犹豫。 首先&#xff0c;三五万的成本包括了开发人员的费用、服务器费用、推广费用等。对…

JS大纲简介

1 HTML中的JavaScript js引用文件可以放在两个位置&#xff0c;一种是html中的head中&#xff0c;一种是html中的body中&#xff1b;放置在这两个位置&#xff0c;有何区别呢&#xff1f; 1.1 使用<script>元素的方式 1.1.1 放置在 head 中 引用example.js文件&#…