pytorch的学习与总结(第二次组会)

news2025/1/23 3:05:39

pytorch的学习与总结

  • 一、pytorch的基础学习
    • 1.1 dataset与dataloader
    • 1.2 可视化工具(tensorboard)、数据转换工具(transforms)
    • 1.3 卷积、池化、线性层、激活函数
    • 1.4 损失函数、反向传播、优化器
    • 1.5 模型的保存、加载、修改
  • 二、 pytorch分类项目实现
    • 2.1 网络模型
    • 2.2 具体代码

一、pytorch的基础学习

1.1 dataset与dataloader

在这里插入图片描述

1.2 可视化工具(tensorboard)、数据转换工具(transforms)

在这里插入图片描述

1.3 卷积、池化、线性层、激活函数

在这里插入图片描述

1.4 损失函数、反向传播、优化器

在这里插入图片描述

1.5 模型的保存、加载、修改

在这里插入图片描述

二、 pytorch分类项目实现

2.1 网络模型

在这里插入图片描述

2.2 具体代码

  1. model与train
# Seven
import torch
import torchvision         # 导入数据集
from torch.utils.tensorboard import SummaryWriter    # 用于日志



# 准备数据集
from torch import nn
from torch.utils.data import DataLoader

# 通过GPU训练还是通过CPU训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 导入训练数据集合
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(), download=True)
# 导入测试数据集
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(), download=True)


# 查看训练集合和测试集合长度
train_data_size = len(train_data)
test_data_size = len(test_data)
# 输出查看
print("训练数据集合的长度为:{}".format(train_data_size))
print("测试数据集合的长度为:{}".format(test_data_size))


# 利用DataLoader 来加载数据集
train_dataloader = DataLoader(dataset=train_data, batch_size=64)
test_dataloader = DataLoader(dataset=test_data, batch_size=64)

# 创建网络模型
class Seven(nn.Module):
    def __init__(self):
        # 初始化
        super(Seven, self).__init__()

        # 设置网络模型
        self.model = nn.Sequential(
            nn.Conv2d(3,32,5,1,2), # 卷积层
            nn.MaxPool2d(2),       # 池化层
            nn.Conv2d(32,32,5,1,2),  # 卷积层
            nn.MaxPool2d(2),       # 池化层
            nn.Conv2d(32,64,5,1,2),  # 卷积层
            nn.MaxPool2d(2),        # 池化层
            nn.Flatten(),    # 展开
            nn.Linear(64*4*4, 64),  # 线性层
            nn.Linear(64, 10)   # 线性层输出
        )


    def forward(self, x):
        x = self.model(x)
        return x

seven = Seven()

#gpu
seven.to(device)

# 损失函数(交叉熵损失)
loss_fn = nn.CrossEntropyLoss()

# 如果是在gpu上跑
loss_fn.to(device)

# 优化器,设置学习率
learning_rate = 1e-2
optimizer = torch.optim.SGD(seven.parameters(), lr=learning_rate)

# 设置网络模型中的一些参数
total_train_step = 0   # 记录训练的次数
total_test_step = 0    # 记录测试的次数
epochs = 10   # 数据扫描次数

# 添加tensorboard  (用于记录log)
writer = SummaryWriter("./logs_train")

for epoch in range(epochs):
    print("第{}轮测试开始".format(epoch+1))

    # 训练步骤开始
    seven.train()   # 标注模型为训练状态,网络模型中例如dropout类似的层会生效
    for data in train_dataloader:
        # 因为我们封装为dataloader,会设置一个批次
        # 这里就会迭代的时候以一个批次大小,将图片与标签返回
        imgs, targets = data

        # gpu
        imgs = imgs.to(device)
        targets = targets.to(device)

        # 把数据丢进模型
        outputs = seven(imgs)

        # 调用损失函数, 看预测值与目标值之间的差距
        loss = loss_fn(outputs, targets)

        # 优化器梯度清零,因为pytorch不会主动给梯度清零
        optimizer.zero_grad()
        # 调用反向传播(其实就是对损失函数求导)
        loss.backward()
        # 进行参数更新
        optimizer.step()

        # 训练个数+1
        total_train_step += 1
        if total_train_step % 100 == 0:
            # loss.item(): 拿到具体的损失值
            print("训练次数:{},loss:{}".format(total_train_step, loss.item()))
            # 每100次记录一次损失
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # 测试步骤开始
    seven.eval()  # 模型开始测试模式,网络模型中的dropout类似在寻楼层有效的层失效
    total_test_loss = 0  # 记录测试的损失值
    total_accuracy = 0   # 记录预测正确的次数
    with torch.no_grad(): # 下面的内容不进行求梯度与反向传播,
        for data in test_dataloader:
            imgs, targets = data
            # gpu
            imgs = imgs.to(device)
            targets = targets.to(device)


            outputs = seven(imgs)
            loss = loss_fn(outputs, targets)
            # 记录损失值
            total_test_loss = total_test_loss + loss.item()
            # argmax(1) 横向取出预测最大值的下标,因为是分类预测,所以预测的最大概率,就是预测的标签
            # 这里会出现[False,True]这类类型,通过sum()记录这一批量预测值与标签相等的总数
            accuracy = (outputs.argmax(1) == targets).sum()
            # 记录预测成功的总数
            total_accuracy = total_accuracy + accuracy

    print("整体数据集的Loss:{}".format(total_test_loss))
    print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))
    # 记录一次扫描的损失
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    # 记录一次扫描的准确率
    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)

    total_test_step +=1


    # 每次都保存一次模型
    torch.save(seven, "seven_{}.pth".format(epoch))
    print("第{}次模型保存".format(epoch))


writer.close()
  1. test
import torch
import torchvision
from PIL import Image
from torch import nn

# 设置要预测图片的路径
image_path = "./imgs/feiji.jpg"
# 打开图片
image = Image.open(image_path)
print(image)

# 换大小并转换格式为tensor类型
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),
                                            torchvision.transforms.ToTensor()])

image = transform(image)
print(image.shape)


# 搭建神经网络
class Seven(nn.Module):
    def __init__(self):
        # 初始化
        super(Seven, self).__init__()

        # 设置网络模型
        self.model = nn.Sequential(
            nn.Conv2d(3,32,5,1,2), # 卷积层
            nn.MaxPool2d(2),       # 池化层
            nn.Conv2d(32,32,5,1,2),  # 卷积层
            nn.MaxPool2d(2),       # 池化层
            nn.Conv2d(32,64,5,1,2),  # 卷积层
            nn.MaxPool2d(2),        # 池化层
            nn.Flatten(),    # 展开
            nn.Linear(64*4*4, 64),  # 线性层
            nn.Linear(64, 10)   # 线性层输出
        )


    def forward(self, x):
        x = self.model(x)
        return x


# 读取训练的模型
# 当训练时以gpu训练时,但是测试用cpu时需要声明一下读取的模型
model = torch.load("seven_9.pth", map_location=torch.device('cpu'))

print(model)
image = torch.reshape(image, (1,3,32,32))

# 测试模型
model.eval()
with torch.no_grad(): # 声明不计算梯度
    output = model(image)

print(output)
# 选择横向最大的概率的下标
print(output.argmax(1))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/555763.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

新星计划2023【《计算之魂》读书会】学习方向报名入口!

前排提醒:这里是新星计划2023【《计算之魂》读书会】学习方向的报名入口,一经报名,不可更换。 ↓↓↓报名方式:(下滑到本页面底部) 一、关于本学习方向导师 博客昵称:异步社区博客主页&#x…

AI大模型时代,云从科技携“从容大模型”入场如何“从容”?

5月18日,在“AI赋能数字中国产业论坛暨2023云从科技人机协同发布会”上,云从科技自研“从容大模型”正式亮相。 根据发布会信息,“从容大模型”具备问答、阅读理解、文学创作以及解题方面的能力。受发布会消息影响,5月18日午间休盘…

【libdatachannel】cmake+vs2022 构建

libdatachannel libdatachannel 是基于c++17实现的构建 OpenSSL 找不到 Selecting Windows SDK version 10.0.22000.0 to target Windows 10.0.22621. The CXX compiler identification is MSVC 19.35.32217.1 Detecting CXX compiler ABI info Detecting CXX compiler ABI inf…

利用GPIO线进行板间通信-23-5-22

本项目基于VU9P(xcvu9pflga2105)板卡以及ZYNQ(xc7z015clg485) 简单结构流程介绍: 1.上位机通过千兆网将指令下发到ZYNQ,ZYNQ进行解帧,将数据解析出来后存储到RAM中,RAM将数据不断输送给GPIO模块,GPIO模块根据对应地址输出数据是…

新来的00后实习生太牛了,已经被取代了.....

前几天有个朋友向我哭诉,说她在公司工作(软件测试)了7年了,却被一个00后实习生代替了,该何去何从? 这是一个值得深思的问题,作为职场人员,我们确实该思考,我们的工作会被…

1718_Linux命令模式下查看日历

全部学习汇总: GreyZhang/bash_basic: my learning note about bash shell. (github.com) 前面发布了一份学习笔记,涉嫌过渡宣传,虽然我也没搞懂为什么。有一系列修改建议,我觉得直接放弃了。还是发一份新的吧! Linux命…

【数据结构】哈希底层结构

目录 一、哈希概念 二、哈希实现 1、闭散列 1.1、线性探测 1.2、二次探测 2、开散列 2.1、开散列的概念 2.2、开散列的结构 2.3、开散列的查找 2.4、开散列的插入 2.5、开散列的删除 3、性能分析 一、哈希概念 顺序结构以及平衡树中,元素关键码与其存储位…

如何用Postman做接口自动化测试?

本文适合已经掌握 Postman 基本用法的读者,即对接口相关概念有一定了解、已经会使用 Postman 进行模拟请求等基本操作。 工作环境与版本: Window 7(64位)Postman (Chrome App v5.5.3) P.S. 不同版本页面 U…

JAVA—实验4 继承、接口与多态

一、实验目的 掌握类的继承机制掌握接口的定义方法熟悉成员方法或构造方法多态性 二、实验内容 1.卖车-接口与多态编程 【问题描述】 (1) 汽车接口(Car):有两个方法, getName()、getPrice()(接口源文件可以自己写,也…

2024总统大选,成为“关乎比特币未来的公投”?背后是怎样的政治抱负?

在今年的迈阿密比特币大会上,Robert F.Kennedy Jr和Vivek Ramaswamy相继发布声明表示,他们将在2024年初选前接受比特币(BTC)的捐款。 RFK Jr作为美国前总统约翰肯尼迪的侄子,是第一个公开接受Crypto的总统候选人&#…

chatgpt赋能Python-pythons_9_98_987

用Python计算s998987的方法及重要性分析 介绍 Python是一种开源的高级编程语言,它被广泛应用于数据处理、web开发和人工智能等领域。它的简洁、易读易写的语法使得很多程序员喜爱使用它来完成各种工作。本文将介绍如何用Python计算一个简单的数学表达式&#xff1…

微服务基础环境搭建--和创建公用模块

目录 微服务基础环境搭建 创建父工程,用于聚合其它微服务模块 创建父项目, 作为聚合其它微服务模块 项目设置​编辑 ​编辑 删除src, 保留一个纯净环境​编辑 1. 配置父工程pom.xml, 作为聚合其它模块 2、修改e-commerce-center\pom.xml,删除不需要的配置节…

Java.lang.NoClassDefFoundError: org/apache/logging/log4j/util/ReflectionUtil

具体问题描述如下: SLF4J: Class path contains multiple SLF4J bindings. SLF4J: Found binding in [jar:file:/D:/maven/repository/org/apache/logging/log4j/log4j-slf4j-impl/2.6.2/log4j-slf4j-impl-2.6.2.jar!/org/slf4j/impl/StaticLoggerBinder.class] SL…

【Spring - beans】 BeanDefinition 源码

目录 1. BeanDefinition 1.1 AbstractBeanDefinition 1.2 RootBeanDefinition 1.3 ChildBeanDefinition 1.4 GenericBeanDefinition 2. BeanDefinitionReader 2.1 AbstractBeanDefinitionReader 2.2 XmlBeanDefinitionReader 2.3 GroovyBeanDefinitionReader 2.4 Pro…

(跨模态)AI作画——使用stable-diffusion生成图片

AI作画——使用stable-diffusion生成图片 0. 简介1. 注册并登录huggingface2. 下载模型3. 生成 0. 简介 自从DallE问世以来,AI绘画越来越收到关注,从最初只能画出某些特征,到越来越逼近真实图片,并且可以利用prompt来指导生成图片…

软件测试面试题——数据库知识

1、要查询每个商品的入库数量,可以使用以下SQL语句: SELECT 商品编号, SUM(入库数量) AS 入库数量 FROM Stock GROUP BY 商品编号;这将从Stock表中选择每个商品的入库数量,并使用SUM函数对入库数量进行求和。结果将按照商品编号进行分组&…

数据宝藏与精灵法师:探秘Elf擦除魔法的奇幻故事

在数字领域的奇幻王国中,大家视数据为宝藏。作为奇幻王国的国王,在他的宝库中,自然是有着无数的数据宝藏。这么多的数据宝藏,却让国王发难了。因为宝库有限,放不下这么多数据宝藏。因此,国王广招天下的精灵…

【备战秋招】每日一题:3月18日美团春招第三题:题面+题目思路 + C++/python/js/Go/java带注释

2023大厂笔试模拟练习网站(含题解) www.codefun2000.com 最近我们一直在将收集到的各种大厂笔试的解题思路还原成题目并制作数据,挂载到我们的OJ上,供大家学习交流,体会笔试难度。现已录入200道互联网大厂模拟练习题&…

简易someip服务发现SD报文演示

环境 $ cat /etc/os-release PRETTY_NAME"Ubuntu 22.04.1 LTS" NAME"Ubuntu" VERSION_ID"22.04" VERSION"22.04.1 LTS (Jammy Jellyfish)" VERSION_CODENAMEjammy IDubuntu ID_LIKEdebian HOME_URL"https://www.ubuntu.com/"…

chatgpt赋能Python-pythonsum

Pythonsum:优秀的Python算法包介绍 Pythonsum是Python语言的一个优秀的算法包,具有很高的可重用性和性能,支持大规模数据处理和复杂算法实现。本文将为大家介绍Pythonsum的基本功能和优势。 Pythonsum的基本功能 Pythonsum提供了一系列丰富…