pytorch再次学习

news2024/11/18 3:36:08

目录

  • 数据可视化
  • 切换设备device
  • 定义类
  • 打印每层的参数大小
  • 自动微分
    • 计算梯度
    • 禁用梯度追踪
    • 优化模型参数
  • 模型保存
  • 模型加载

数据可视化

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data1",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data1",
    train=False,
    download=True,
    transform=ToTensor()
)
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item() #用于随机取出一个training_data
    img, label = training_data[sample_idx]
    plt.subplot(3,3,i) #此处i必须是1开始
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

在这里插入图片描述

切换设备device

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

定义类

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

打印每层的参数大小

print(f"Model structure: {model}\n\n")

for name, param in model.named_parameters():
    print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")

自动微分

详见文章Variable
需要优化的参数需要加requires_grad=True,会计算这些参数对于loss的梯度

import torch

x = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

计算梯度

计算导数

loss.backward()
print(w.grad)
print(b.grad)

禁用梯度追踪

训练好后进行测试,也就是不要更新参数时使用

z = torch.matmul(x, w)+b
print(z.requires_grad)

with torch.no_grad():
    z = torch.matmul(x, w)+b
print(z.requires_grad)

优化模型参数

  1. 调用optimizer.zero_grad()来重置模型参数的梯度。梯度会默认累加,为了防止重复计算(梯度),我们在每次遍历中显式的清空(梯度累加值)。
  2. 调用loss.backward()来反向传播预测误差。PyTorch对每个参数分别存储损失梯度。
  3. 我们获取到梯度后,调用optimizer.step()来根据反向传播中收集的梯度来调整参数。
optmizer.zero_grad()
loss.backward()
optmizer.step()

模型保存

import torch
import torchvision.models as models

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')

模型加载

model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/984655.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Nginx中实现自签名SSL证书生成与配置

文章目录 一.相关介绍1.生成步骤2.相关名词介绍 二.Nginx中实现自签名SSL证书生成与配置1.私钥生成2.公钥生成3.生成解密的私钥key4.签名生成证书5.配置证书并验证6.登录 一.相关介绍 1.生成步骤 (1)生成私钥(Private Key)&…

elementUI——el-table自带排序使用问题

问题 排序表格默认第一列按降序排(状态1),当点击其他列后(状态2),改变日期,触发表格数据更新,发现列的排序还点亮在之前的操作上,没有按照默认来(回到状态1&a…

运筹系列85:求解大规模tsp问题的julia代码

1. 大规模tsp问题的挑战 数学模型和精确解法见《运筹系列65:TSP问题的精确求解法概述》和《运筹系列80:使用Julia精确求解tsp问题》: variable(m, x[1:n,1:n], Bin,Symmetric) # 0-1约束 objective(model, Min, sum(x.*distmat)/2) constraint(model, …

Linux——线程详解(一)

索引 初识线程1.inux下的线程2.再谈进程3.理解页表4. 再次理解虚拟到物理的转化 线程的控制1.线程的创建2.线程异常3.验证pthread_join 的第二个参数4.线程的退出方式5. 线程的公有和私有6.pthread_t 与线程独立栈7.线程的局部性存储8.线程分离 初识线程 1.inux下的线程 之前了…

通过RTSP协议接入RTSP流媒体服务器EasyNVR视频监控汇聚平台的设备显示离线是什么原因?

EasyNVR安防视频云服务是基于RTSP/Onvif协议接入的视频平台,可支持将接入的视频流进行全平台、全终端的分发,分发的视频流包括RTSP、RTMP、HTTP-FLV、WS-FLV、HLS、WebRTC等。平台丰富灵活的视频能力,可应用在智慧校园、智慧工厂、智慧水利等…

028:vue上传解析excel文件,列表中输出内容

第028个 查看专栏目录: VUE ------ element UI 专栏目标 在vue和element UI联合技术栈的操控下,本专栏提供行之有效的源代码示例和信息点介绍,做到灵活运用。 (1)提供vue2的一些基本操作:安装、引用,模板使…

静态路由 网络实验

静态路由 网络实验 拓扑图初步配置R1 ip 配置R2 ip 配置R3 ip 配置查看当前的路由表信息查看路由表信息配置静态路由测试 拓扑图 需求:实现 ip 192.168.1.1 到 192.168.2.1 的通信。 初步配置 R1 ip 配置 system-view sysname R1 undo info-center enable # 忽略…

超图聚类论文阅读1:Kumar算法

超图聚类论文阅读1:Kumar算法 《超图中模块化的新度量:有效聚类的理论见解和启示》 《A New Measure of Modularity in Hypergraphs: Theoretical Insights and Implications for Effective Clustering》 COMPLEX NETWORKS 2020, SCI 3区 具体实现源码见…

【SWT】 Button 处理 Checkbox 按钮的选中与反选事件

介绍: 在使用 Java SWT(Standard Widget Toolkit)创建图形用户界面时,经常需要处理按钮的选中和反选事件。本文将介绍如何通过添加 SelectionListener 监听器来实现按钮选中与反选事件的处理,并相应地修改相关变量的值…

2023国赛数学建模B题思路分析 - 多波束测线问题

# 1 赛题 B 题 多波束测线问题 单波束测深是利用声波在水中的传播特性来测量水体深度的技术。声波在均匀介质中作匀 速直线传播, 在不同界面上产生反射, 利用这一原理,从测量船换能器垂直向海底发射声波信 号,并记录从声波发射到…

【MySQL系列】MySQL的事务管理的学习(一)_ 事务概念 | 事务操作方式 | 事务隔离级别

「前言」文章内容大致是MySQL事务管理。 「归属专栏」MySQL 「主页链接」个人主页 「笔者」枫叶先生(fy) 目录 一、事务概念二、事务的版本支持三、事务提交方式四、事务常见的操作方式4.1 事务正常操作4.2 事务异常验证 五、事务隔离级别5.1 查看与设置隔离性5.2 读未提交&…

flutter报错-cmdline-tools component is missing

安装完androidsdk和android studio后,打开控制台,出现错误 解决办法 找到自己安装android sdk的位置,然后安装上,并将下面的勾选上 再次运行 flutter doctor 不报错,出现以下画面 Doctor summary (to see all det…

视频融合平台EasyCVR综合管理平台加密机授权报错invalid character是什么原因

视频融合平台EasyCVR综合管理平台具备视频融合汇聚能力,作为安防视频监控综合管理平台,它支持多协议接入、多格式视频流分发,可支持的主流标准协议有国标GB28181、RTSP/Onvif、RTMP等,以及支持厂家私有协议与SDK接入,包…

Java版 招投标系统简介 招投标系统源码 java招投标系统 招投标系统功能设计

项目说明 随着公司的快速发展,企业人员和经营规模不断壮大,公司对内部招采管理的提升提出了更高的要求。在企业里建立一个公平、公开、公正的采购环境,最大限度控制采购成本至关重要。符合国家电子招投标法律法规及相关规范,以及…

【pytorch】数据加载dataset和dataloader的使用

1、dataset加载数据集 dataset_tranform torchvision.transforms.Compose([torchvision.transforms.ToTensor(),])train_set torchvision.datasets.CIFAR10(root"./train_dataset",trainTrue,transformdataset_tranform,downloadTrue) test_set torchvision.data…

高德地图,绘制矢量图形并获取经纬度

效果如图 我用的是AMapLoader这个地图插件,会省去很多配置的步骤,非常方便 首先下载插件,然后在局部引入 import AMapLoader from "amap/amap-jsapi-loader";然后在methods里面使用 // 打开地图弹窗mapShow() {this.innerVisible true;this.$nextTick(() > {…

祝贺!Databend Cloud 入驻 AWS 云市场

关于 Databend Cloud Databend Cloud 是基于开源云原生数仓项目 Databend 打造的一款易用、低成本、高性能的新一代大数据分析平台,提供一站式 SaaS 服务,免运维、开箱即用。 Databend Cloud 架构如下: 存储层完全面向对象存储而设计。 计算…

2023年海外推广怎么做?

答案是:2023海外推广可以选择谷歌SEO谷歌Ads双向运营。 理解当地文化 成功的海外推广首先是建立在对当地文化的深入了解和尊重的基础上。 本土化策略 为了更好地与当地用户互动,你的品牌、产品或服务需要与他们的文化和生活方式紧密相连。 例如&…

Linux/Windows中根据端口号关闭进程及关闭Java进程

目录 Linux 根据端口号关闭进程 关闭Java服务进程 Windows 根据端口号关闭进程 Linux 根据端口号关闭进程 第一步:根据端口号查询进程PID,可使用如下命令 netstat -anp | grep 8088(以8088端口号为例) 第二步:…

【大数据之Kafka】九、Kafka Broker之文件存储及高效读写数据

1 文件存储 1.1 文件存储机制 Topic是逻辑上的概念,而partition是物理上的概念,每个partition对应于一个log文件,该log文件中存储的是Producer生产的数据。 Producer生产的数据会被不断追加到该log文件末端,为防止log文件过大导致…