pytorch-迁移学习

news2025/1/16 1:04:46

目录

  • 1. 宝可梦数据集训练的问题
  • 2. 迁移学习
  • 3. 迁移学习实现
  • 4. 完整代码

1. 宝可梦数据集训练的问题

宝可梦数据总共有1000多张,对于resnet18网络来说数据量是不够的,训练时很容易出现过拟合,那么如何解决这个问题呢?

宝可梦数据集训练传送门

2. 迁移学习

下图以⚪和■二分类为例
第一种模式,下图中第一列图,直接使用4个数据训练,有三种可能,中间一种肯定是最好的。
第二种模式,宝可梦数据集和imagenet中1000多种数据有相似甚至包含其中,那么是否可以利用imagenet的模型来提高宝可梦数据集训练模型的性能呢?答案肯定的,如图中第二列图,使用在已有的相似知识基础上训练,大大提升了中间分割线的可能。
可以说利用已有模型权重,训练特定分类任务的操作就叫迁移学习。
在这里插入图片描述

3. 迁移学习实现

见下图,只需替换掉输出层即可。
在这里插入图片描述
针对上文的宝可梦数据集模型训练,如何实现迁移学习呢?
步骤:

  • 加载标准resnet18网络
from    torchvision.models import resnet18
  • 增加预训练标识
trained_model = resnet18(pretrained=True)
  • 取resnet前17层,并增加一个线性输出层
    *list(trained_model.children())[:-1]取前17层,并将数据打散
    代码中增加了一个打平操作,将resnet18第17层输出的[b, 512, 1, 1]转换到 [b, 512]
trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1], #[b, 512, 1, 1]
                          Flatten(), # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 5)
                          ).to(device)

经验证准确率从84%提升到94%,因此迁移学习有助于提升小数据集性能的提升和防止过拟合。

4. 完整代码

train_transfer.py

import  torch
from    torch import optim, nn
import  visdom
import  torchvision
from    torch.utils.data import DataLoader

from    pokemon import Pokemon
# from    resnet import ResNet18
from    torchvision.models import resnet18

from    utils import Flatten

batchsz = 32
lr = 1e-3
epochs = 10

device = torch.device('cuda')
torch.manual_seed(1234)


train_db = Pokemon('pokemon', 224, mode='train')
val_db = Pokemon('pokemon', 224, mode='val')
test_db = Pokemon('pokemon', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)


viz = visdom.Visdom()

def evalute(model, loader):
    model.eval()
    
    correct = 0
    total = len(loader.dataset)

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()

    return correct / total

def main():

    # model = ResNet18(5).to(device)
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1], #[b, 512, 1, 1]
                          Flatten(), # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 5)
                          ).to(device)
    # x = torch.randn(2, 3, 224, 224)
    # print(model(x).shape)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()


    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):

        for step, (x,y) in enumerate(train_loader):

            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            model.train()
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 1 == 0:

            val_acc = evalute(model, val_loader)
            if val_acc> best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')

                viz.line([val_acc], [global_step], win='val_acc', update='append')


    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)





if __name__ == '__main__':
    main()

utils.py

from    matplotlib import pyplot as plt
import  torch
from    torch import nn

class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1958989.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

常见的几种数据标注类型

数据标注是机器学习和人工智能项目中一个至关重要的步骤,它帮助算法理解输入数据中的关键特征。根据不同的应用场景和技术需求,数据标注可以分为多种类型。 以下是一些常见的数据标注类型: 图像标注: 边界框:在物体周…

手撕数据结构---栈和队列的概念以及实现

栈的概念: 栈:⼀种特殊的线性表,其只允许在固定的⼀端进⾏插⼊和删除元素操作。进⾏数据插⼊和删除操作的⼀端称为栈顶,另⼀端称为栈底。栈中的数据元素遵守后进先出LIFO(Last In First Out)的原则。 压栈…

Doris-计算特性

1 全新优化器 1.1 如何开启1.2 统计信息 1.2.1 使用ANALYZE语句手动收集1.2.1 自动收集1.2.3 作业管理1.3 会话变量及配置项调优参数2 Join相关 2.1 支持的Join算子2.2 支持的shuffle方式 2.2.1 Broadcast Join2.2.2 Shuffle Join2.2.3 Bucket Shuffle Join 2.2.3.1 原理2.2.3.…

【CTFWP】ctfshow-web40

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 题目介绍:题目分析:payload:payload解释:payload2:payload2解释:flag 题目介绍: …

python-分享篇-用Python分析文本数据的词频

上次批量提取了上市公司主要业务信息,要分析这些文本数据,就需要做文本词频分析。由于中文不同于英文,词是由一个一个汉字组成的,而英文的词与词之间本身就有空格,所以中文的分词需要单独的库才能够实现,常…

2024年必备技能:小红书笔记评论自动采集,零基础也能学会的方法

摘要: 面对信息爆炸的2024年,小红书作为热门社交平台,其笔记评论成为市场洞察的金矿。本文将手把手教你,即便编程零基础,也能轻松学会利用Python自动化采集小红书笔记评论,解锁营销新策略,提升…

pmp学习交流组队~

首先,来看看什么是PMP PMP指的是项目管理专业人士资格认证。它是由美国项目管理协会(Project Management Institute(PMI)发起的,严格评估项目管理人员知识技能是否具有高品质的资格认证考试。 pmp备考攻略本人推荐的参考资料比较多&#xff0…

MySQL 9 安装第1辑-版本选择和安装包获取

一、MySQL 9 版本选择 在准备安装MySQL时,选择合适的版本和分发格式至关重要。首先,需要决定是安装长期支持(LTS)系列版本还是创新系列版本。长期支持版本(如MySQL 8.x LTS)专注于稳定性、性能优化和安全性…

RocketMQ知识总结(基本原理)

文章收录在网站:http://hardyfish.top/ 文章收录在网站:http://hardyfish.top/ 文章收录在网站:http://hardyfish.top/ 文章收录在网站:http://hardyfish.top/ 基本原理 总体架构图 零拷贝 零拷贝技术是一个思想,指…

蓝屏事件一些想法

影响全球的蓝屏事件 2024年7月19日发生了大量windows操作系统电脑蓝屏的事情,造成了全球级别的影响。其中国外的影响最大,甚至像医院、银行、航班等与人民生活密切相关的行业都受到了本次影响。导致全球数千架次航班被取消,数万架次航班延误…

Java | Leetcode Java题解之第303题区域和检索-数组不可变

题目&#xff1a; 题解&#xff1a; class NumArray {int[] sums;public NumArray(int[] nums) {int n nums.length;sums new int[n 1];for (int i 0; i < n; i) {sums[i 1] sums[i] nums[i];}}public int sumRange(int i, int j) {return sums[j 1] - sums[i];} }…

【Web开发手礼】探索Web开发的秘密(十三)-Vue(3)好友列表、登录

前言 主要介绍了好友列表、登录界面所涉及的vue知识点&#xff01;&#xff01;&#xff01; 好友列表 从云端API读取数据信息 地址 https://app165.acapp.acwing.com.cn/myspace/userlist/方法&#xff1a;GET是否验证jwt&#xff1a;否输入参数&#xff1a;无返回结果&…

C++自定义接口类设计器

关键代码 #include "mainwindow.h" #include "ui_mainwindow.h" #include <QStringListModel>MainWindow::MainWindow(QWidget *parent): QMainWindow(parent), ui(new Ui::MainWindow) {ui->setupUi(this);// C基础数据类型列表QStringList typ…

求.netcore 按模板导出pdf免费插件,来谈谈。

&#x1f3c6;本文收录于《CSDN问答解惑-专业版》专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收…

图书管理系统设计

设计一个图书管理系统时&#xff0c;我们需要考虑系统的基本功能、用户需求、技术选型以及数据的安全性和完整性。下面是一个基本的图书管理系统的设计概览&#xff1a; 1. 系统目标 管理图书信息&#xff1a;添加、删除、修改图书信息。借阅管理&#xff1a;处理借书、还书流…

低代码应用版本管理能力探讨

低代码平台为开发者提供易用的可视化、定制化开发能力&#xff0c;无需编写原生代码或者只有少量代码编写就能实现需求&#xff0c;从而带来开发门槛的降低&#xff0c;开发效率的提升。低代码作为提升应用研发生产力的关键型技术&#xff0c;成为企业数字化转型的助推器。 低代…

StarRock3.3 安装部署(存算分离、存算一体保姆式教程)

服务器前置要求&#xff1a; 1、内存>32GB 2、JDK 8 is not supported, please use JDK 11 or 17 1、安装 wget https://releases.starrocks.io/starrocks/StarRocks-3.3.0.tar.gz tar zxvf StarRocks-3.3.0.tar.gz 2、FE服务启动 2.1 配置FE节点(默认配置&#xff0c;…

C#知识|文件与目录操作:文本读写操作

哈喽,你好啊,我是雷工! 今天学习文件与目录的操作,以下为文本读写操作的学习笔记。 01 文件操作说明 1.1、数据的存取方式 数据库:适合存取大量且关系复杂并有序的数据; 文件:适合存取大量但数据关系简单的数据,像系统的日志文件; 1.2、文件存取的优点 ①:读取操…

根据ip地址能查询出具体地址吗?

在数字化时代&#xff0c;互联网已成为我们日常生活不可或缺的一部分&#xff0c;而IP地址作为网络世界的“身份证”&#xff0c;承载着每一台设备在网络中的唯一标识。你是否曾经好奇&#xff0c;通过一串看似无意义的数字组合——IP地址&#xff0c;我们究竟能否揭开其背后的…

springboot校园失物招领系统-计算机毕业设计源码17082

目 录 摘要 1 绪论 1.1 研究背景 1.2 研究意义 1.3论文结构与章节安排 2 相关技术介绍 2.1 B/S结构 2.2 Spring Boot框架 2.3 MySQL数据库 3系统分析 3.1 可行性分析 3.2 系统流程分析 3.2.1 数据新增流程 3.2.2 数据删除流程 3.3 系统功能分析 3.3.1 功能性分…