PyTorch预训练和微调:以VGG16为例

news2024/11/22 9:59:07

文章目录

    • 预训练和微调代码
    • 测试结果
    • 参考来源

预训练和微调代码

数据集:CIFAR10
CIFAR-10数据集由10类32x32的彩色图片组成,一共包含60000张图片,每一类包含6000图片。其中50000张图片作为训练集,10000张图片作为测试集。数据集介绍来自:CIFAR10

在这里插入图片描述
图片来源:https://paperswithcode.com/dataset/cifar-10

预训练模型: vgg16

代码

# Imports
import torch
import torchvision
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torch.nn.functional as F  # All functions that don't have any parameters
from torch.utils.data import (
    DataLoader,
)  # Gives easier dataset managment and creates mini batches
import torchvision.datasets as datasets  # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms  # Transformations we can perform on our dataset
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_classes = 10
learning_rate = 1e-3
batch_size = 1024
num_epochs = 2


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

# Load pretrain model & modify it

model = torchvision.models.vgg16(weights='DEFAULT')
# # If you want to do finetuning then set requires_grad = False
# # Remove these two lines if you want to train entire model,
# # and only want to load the pretrain weights.
# for param in model.parameters():
#     param.requires_grad = False
for param in model.parameters():
    param.requires_grad = False

model.avgpool = Identity() # 站位层,使得该层啥事不做
model.classifier = nn.Sequential(nn.Linear(512, 100),
                                 nn.ReLU(),
                                 nn.Linear(100, 10)) # 修改原模型的后几层
model.to(device)


# Load Data
train_dataset = datasets.CIFAR10(
    root="dataset/", 
    train=True, 
    transform=transforms.ToTensor(), 
    download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train Network
for epoch in range(num_epochs):
    losses = []

    for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):
        # Get data to cuda if possible
        data = data.to(device=device)
        targets = targets.to(device=device)

        # forward
        scores = model(data)
        loss = criterion(scores, targets)

        losses.append(loss.item())
        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient descent or adam step
        optimizer.step()

    print(f"Cost at epoch {epoch} is {sum(losses)/len(losses):.5f}")

# Check accuracy on training & test to see how good our model


def check_accuracy(loader, model):
    if loader.dataset.train:
        print("Checking accuracy on training data")
    else:
        print("Checking accuracy on test data")

    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)

            scores = model(x)
            _, predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)

        print(
            f"Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%"
        )

    model.train()


check_accuracy(train_loader, model)

测试结果

Checking accuracy on training data
Got 29449 / 50000 with accuracy 58.90%

可以看到本次预训练模型的导入,测试结果并不理想。但并不妨碍我们对Pytorch预训练和微调的学习。

参考来源

【1】 https://www.youtube.com/watch?v=qaDe0qQZ5AQ&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=8
【2】https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/Basics/pytorch_pretrain_finetune.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/746161.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

16. 替换空格

链接: 链接 题目: 请实现一个函数,把字符串中的每个空格替换成"%20"。 数据范围 0≤0≤ 输入字符串的长度 ≤1000≤1000。 注意输出字符串的长度可能大于 10001000。 样例 输入:"We are happy."输出&#xff…

python简易版的飞机大战(图片资源请自找)

# 引入pygame工具包 import pygame from pygame.locals import * import time import random import sys # 初始化pygame pygame.init() # 创建一个宽480高650的一个画布canvas canvas pygame.display.set_mode((480, 650)) # 加工图片资源 bg pygame.image.load(bg.png)# 背…

Vue3之app.config.globalProperties(定义全局变量)

使用之因 一般我们在vue开发中,常用的功能,接口等等我们都会封装起来,如何每次创建一个组件,想要使用这些封装起来的功能、接口等等都需要先引入,再通过层层调用才可以得到结果,如果我现在一遍需要调用后端…

多旋翼物流无人机节能轨迹规划(Python代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

海外品牌推广:谷歌没收录?这些常见错误你可别犯!

你们是否曾经遇到过这样的情况:你在海外市场努力推广你的品牌,但是发现谷歌搜索结果中竟然找不到你的网站或品牌?别担心,你可能犯了一些常见的错误,让谷歌把你的品牌忽略掉了。让我们来看看这些错误,确保你…

3dsmax图纸怎么加密?

对设计行业来说,公司重要的设计图纸是一个企业非常重要的核心数据是命脉,那么具有如此重要性的3Dmax设计文件怎么才能确保文件的安全,避免竞争对手骗方案或者内部人员有意无意的泄密呢? 相信有很多老板或许都会遇到这样的问题。很…

❤️创意网页:有趣的文字冒险游戏(可以无限拓展)

✨博主:命运之光 🌸专栏:Python星辰秘典 🐳专栏:web开发(简单好用又好看) ❤️专栏:Java经典程序设计 ☀️博主的其他文章:点击进入博主的主页 前言:欢迎踏入…

图片转pdf怎么在线转?看看这几种在线转方法

图片转pdf怎么在线转?图片转PDF是一个非常常见的需求,因为在很多情况下,我们需要将一些图片文件转换为PDF文件格式,以便于传输、打印或者共享。如果你想在线转换图片为PDF文件,下面就给大家推荐几种简单实用的转换方法…

Hadoop集群运行Spark应用程序

启动Spark集群 先启动hadoop,再启动Spark,具体参考链接 对Linux系统对Spark开发环境配置_Matrix70的博客-CSDN博客 运行Spark安装好以后自带的样例程序SparkPi spark-submit --class org.apache.spark.examples.SparkPi --master spark://master:7077 examples/jars/spark…

路径规划算法:基于蛇优化优化的路径规划算法- 附代码

路径规划算法:基于蛇优化优化的路径规划算法- 附代码 文章目录 路径规划算法:基于蛇优化优化的路径规划算法- 附代码1.算法原理1.1 环境设定1.2 约束条件1.3 适应度函数 2.算法结果3.MATLAB代码4.参考文献 摘要:本文主要介绍利用智能优化算法…

嵌入式软件测试笔记10 | 嵌入式软件测试中如何进行安全性分析?

10 | 嵌入式软件测试中如何进行安全性分析? 1 简介2 故障模型及后果分析(FMEA)2.1 三个步骤2.2 带来的结果优势2.3 FMEA分析过程2.3.1 描述系统及其功能2.3.2 识别潜在的故障模式2.3.3 故障模式对功能的影响2.3.4 风险导致后果的原因2.3.5 风…

Prompt本质解密及Evaluation实战与源码解析(三)

9.5 Evaluation for QA源码解析 如图9-4所示,我们看一下LangChain框架对问答评估的(Evaluation for QA)的源代码。 图9- 5 LangChain的evaluation qa目录 在eval_prompt.py文件里面,主要定义了三个类 PromptTemplate,它们都是用于生成题目的模板。 Gavin大咖微信:NLP_Mat…

跨端技术栈综合考察:深入剖析 UniApp、Flutter、Taro 和 React Native 的优势与限制

文章目录 📈UniApp⚡概念⚡优势⚡限制 📈Flutter⚡概念⚡优势⚡限制 📈Taro⚡概念⚡优势⚡限制 📈React Native⚡概念⚡优势⚡限制 📈跨端技术栈对比附录:「简历必备」前后端实战项目(推荐&…

强化学习快速复习笔记--待更新

目录 蒙特卡洛方法动态规划算法策略迭代 时序差分方法Sarsa算法Q-learning算法如何区分在线学习和离线学习DQN深度强化Q学习概念介绍代码解析 DQN改进算法Double DQN网络 蒙特卡洛方法 求解价值函数和状态价值函数,可以使用蒙特卡洛方法和动态规划。首先介绍一下蒙…

25-分布式事务----Seate

1、seate 官网:Seata Seata 是一款开源的分布式事务解决方案,致力于提供高性能和简单易用的分布式事务服务。Seata 将为用户提供了 AT、TCC、SAGA 和 XA 事务模式,为用户打造一站式的分布式解决方案。 1.1、Seata术语 TC (Transaction Coordinator) - 事务协调者…

mysql 执行sql开启事务

SHOW VARIABLES LIKE autocommit;SET autocommit 0; INSERT INTO sugar.realmauctiondatum(Id, Name) VALUES (3, A); INSERT INTO sugar.realmauctiondatum(Id, Name) VALUES (1, A); COMMIT;如果没有调用COMMIT;退出session时会执行回滚

python 面向对象之继承

文章目录 前言继承的概念单继承多继承子类重写父类的同名方法和属性子类调用父类同名的方法和属性多层继承私有权限 前言 前面我们已经学习了 python 面向对象的类和对象,那么今天我将为大家分享面向对象的三大特性之一:继承。 继承具有以下特性&#…

怎么使用文件高速传输,推荐镭速高速文件传输解决方案

​​随着互联网的发展,文件传输越来越频繁,如何实现文件高速传输已经越来越成为企业发展过程中需要解决的问题,在当今的业务中,随着与客户和供应商以及内部系统的所有通信的数据量不断增加,对 高速文件传输解决方案的需…

全网最新项目:会说话的汤姆猫直播搭建教程(附教学流程)

今天为大家分享一个 汤姆猫直播搭建项目 ,这个项目最近可以说在圈内爆火,我相信很多朋友以前应该都玩过,或者说给自己家小孩子玩过。 -------------------------------------------------------------------- 课程获取:www.yn521.cn/160852…

RabbitMQ【笔记整理+代码案例】

1. 消息队列 1.1. MQ 的相关概念 1.1.1. 什么是 MQ MQ(message queue),从字面意思上看,本质是个队列,FIFO 先入先出,只不过队列中存放的内容是message 而已,还是一种跨进程的通信机制,用于上下游传递消息…