Pytorch(一)

news2024/11/18 7:27:09

目录

一、基本操作

二、自动求导机制

 三、线性回归DEMO

3.1模型的读取与保存

3.2利用GPU训练时

四、常见的Tensor形式

五、Hub模块


一、基本操作

操作代码如下:

import torch
import numpy as np

#创建一个矩阵
x1 = torch.empty(5,3)

# 随机值
x2 = torch.rand(5,3)

# 初始化一个全零的矩阵
x3 = torch.zeros(5,3,dtype = torch.long)

# view操作改变矩阵维度
x4 = torch.randn(4,4) #4*4矩阵
y = x4.view(16) #变成一行的矩阵
z = x4.view(-1,8) #变为2*8的矩阵
print(y.size()) #矩阵的尺寸

#与numpy的协同操作
# tensor转array
a = torch.ones(5)
b = a.numpy()

# array转tensor
a1 = np.ones(5)
b1 = torch.from_numpy(a)

二、自动求导机制

案例代码如下:

 

import torch

#计算流程
x = torch.rand(1)
b = torch.rand(1,requires_grad=True)
w = torch.rand(1,requires_grad=True)
y = w * x
z = y + b

# 反向传播计算
z.backward(retain_graph = True)
print(w.grad)
print(b.grad)

 三、线性回归DEMO

 

import numpy as np
import torch
import torch.nn as nn

# 构建线性回归模型
class LinearRegressionModel(nn.Module):
    def __init__(self,input_dim,output_dim):
        super(LinearRegressionModel,self).__init__()
        self.linear = nn.Linear(input_dim,output_dim)

    def forward(self,x):
        out = self.linear(x)
        return out

x_values = [i for i in range(11)]
x_train = np.array(x_values,dtype=np.float32)
x_train = x_train.reshape(-1,1)
print(x_train.shape)

#y = 2x + 1
y_values = [2*i + 1 for i in range(11)]
y_train = np.array(x_values,dtype=np.float32)
y_train = x_train.reshape(-1,1)

# 构建model
input_dim = 1
output_dim = 1

model = LinearRegressionModel(input_dim,output_dim)

# 指定好参数和损失函数
epochs = 1000 #训练次数
learning_rate = 0.01 #学习率
optimizer = torch.optim.SGD(model.parameters(),lr = learning_rate) #优化器
criterion = nn.MSELoss() #损失函数

# 训练模型
for epoch in range(epochs):
    epoch += 1
    #注意转行为tensor
    inputs = torch.from_numpy(x_train)
    labels = torch.from_numpy(y_train)


    #梯度要清零每一次迭代
    optimizer.zero_grad()

    #前向传播
    outputs = model(inputs)

    #计算损失
    loss = criterion(outputs,labels)

    #反向传播
    loss.backward()

    #更新权重参数
    optimizer.step()
    if epoch % 50 ==0:
        print('epoch {},loss {}'.format(epoch,loss.item()))

3.1模型的读取与保存

# 模型的保存与读取
torch.save(model.state.dict(),'model.pkl') #保存
model.load_state_dict(torch.load('model.pkl')) #读取

3.2利用GPU训练时

利用GPU训练时要将数据与模型导入cuda

#注意转行为tensor
inputs = torch.from_numpy(x_train)
labels = torch.from_numpy(y_train)
#利用GPU训练数据时的数据
inputs = torch.from_numpy(x_train).to(device)
labels = torch.from_numpy(y_train).to(device)


model = LinearRegressionModel(input_dim,output_dim)

#使用GPU进行训练
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

四、常见的Tensor形式

  • 1.scalar:通常是指一个数值
  • 2.vector:通常是指一个向量
  • 3.matrix:通常是指一个数据矩阵
  • 4.n-dimensional tensor:高维数据

五、Hub模块

Github地址:https://github.com/pytorch/hub

Hub已有模型:https://pytorch.org/hub/research-models

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/803963.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

grid网格布局看这一篇就够了(接近3w字的总结)

在当今现代Web设计中,如何实现有效的布局一直是一个关键问题。这就是为什么CSS3推出了“grid网格布局”作为一种新的布局方式。使用grid,您可以轻松地设置复杂的网格布局,而无需使用冗长的CSS代码或框架。本文将探讨grid的概念、语法和实际应…

青大数据结构【2021】

一、单选(17!) 根据中序遍历得到降序序列可以知道,每个结点的左子树的结点的值比该结点的值小,因为没有重复的关键字,所以拥有最大值的结点没有左子树。 二、简答 三、分析计算 四、算法分析 3.迪杰斯特拉…

DAY53:动态规划(十八)最长公共子序列+不相交的线+最大子序列和

文章目录 1143.最长公共子序列(注意递推的逻辑)思路DP数组含义递推公式初始化完整版重要:该解法是否保持了元素顺序总结 1035.不相交的线(注意思路)思路完整版 53.最大子序列和思路1:贪心思路1完整版思路2:动态规划DP数…

java商城系统和php商城系统对比

java商城系统和php商城系统是两种常见的电子商务平台,它们都具有一定的优势和劣势。那么,java商城系统和php商城系统又有哪些差异呢? 一、开发难度 Java商城系统和PHP商城系统在开发难度方面存在一定的差异。Java商城系统需要使用Java语言进…

【前端工程化】未使用docker时,前端项目实现线上秒级回滚

目录 一. 前言 二. 思路 三. 实践 3.1 准备单页应用项目 3.2 保存历史构建index.html内容 3.3 模拟服务端托管前端应用 3.4 快速回滚node服务端代码开发 3.5 快速回滚前端可视化页面开发 3.6 快速回滚测试 四. 总结 一. 前言 项目快速回滚是前端工程化中很重要的一环&…

【项目】轻量级HTTP服务器

文章目录 一、项目介绍二、前置知识2.1 URI、URL、URN2.2 CGI2.2.1 CGI的概念2.2.2 CGI模式的实现2.2.3 CGI的意义 三、项目设计3.1 日志的编写3.2 套接字编写3.3 HTTP服务器实现3.4 HTTP请求与响应结构3.5 EndPoint类的实现3.5.1 EndPoint的基本逻辑3.5.2 读取请求3.5.3 构建响…

yolov5 onnx模型 转为 rknn模型

1、转换为rknn模型环境搭建 onnx模型需要转换为rknn模型才能在rv1126开发板上运行,所以需要先搭建转换环境 模型转换工具 模型转换相关文件下载: 网盘下载链接:百度网盘 请输入提取码 提取码:teuc 将其移动到虚拟机中&#xf…

用于提取数据的三个开源NLP工具

开发人员和数据科学家使用生成式AI和大语言模型(LLM)来查询大量文档和非结构化数据。开源LLM包括Dolly 2.0、EleutherAI Pythia、Meta AI LLaMa和StabilityLM等,它们都是尝试人工智能的起点,可以接受自然语言提示,生成…

3d动画用云渲染靠谱吗?有什么不同?

3d动画是一种利用计算机技术制作的动画形式,它可以模拟真实世界的物体和场景,创造出各种惊人的效果和视觉体验。3d动画广泛应用于影视、游戏、广告、教育等领域,成为当今最流行的艺术表现形式之一。据统计,2019年全球3d动画市场规…

[STL]list使用介绍

[STL]list使用 注:本文测试环境是visual studio2019。 文章目录 [STL]list使用1. list介绍2. 构造函数3. 迭代器相关函数begin函数和end函数rbegin函数和rend函数 4. 容量相关函数empty函数size函数 5. 数据修改函数push_back函数和pop_back函数push_front函数和pop…

软件兼容性测试的重要性以及一些常用的测试方法

随着软件应用的不断发展,不同操作系统、浏览器、设备和平台的广泛应用,软件兼容性变得越来越重要。在开发和发布软件之前进行兼容性测试是确保软件在多个环境下正常运行的关键步骤。本文将介绍软件兼容性测试的重要性以及一些常用的测试方法。 首先&…

JMeter常用内置对象:vars、ctx、prev

在前文 Beanshell Sampler 与 Beanshell 断言 中,初步阐述了JMeter beanshell的使用,接下来归集整理了JMeter beanshell 中常用的内置对象及其使用。 注:示例使用JMeter版本为5.1 1. vars 如 API 文档 所言,这是定义变量的类&a…

SpringBoot版本升级引起的FileNotFoundException——WebMvcConfigurerAdapter.class

缘起 最近公司项目要求JDK从8升到17,SpringBoot版本从2.x升级到3.x,期间遇到了一个诡异的FileNotFoundException异常,日志如下(敏感信息使用xxx脱敏) org.springframework.beans.factory.BeanDefinitionStoreExcepti…

安科瑞智能型BA系列电流传感器

安科瑞虞佳豪 壹捌柒陆壹伍玖玖零玖叁 选型

微信小程序——同一控件的点击与长按事件共存的解决方案

✅作者简介:2022年博客新星 第八。热爱国学的Java后端开发者,修心和技术同步精进。 🍎个人主页:Java Fans的博客 🍊个人信条:不迁怒,不贰过。小知识,大智慧。 💞当前专栏…

一份 GitHub star 过万的 1121 页图解算法让“他”成功杀进字节跳动

前两天收到读者喜报,说是进字节了,和他交流了一下他的学习心得,发现他看的资料也是我之前推荐过的算法进阶指南,这里推荐给大家,github star 可是过万哦!质量非常高! 这份算法笔记与其他的不同&…

使用andlua+写一个获取VSCode最新版本号的安卓软件

点击加号 选择Defalut模板 名称改为vscv 包名改为com.b.vscv 编辑main.lua require "import" import "android.app.*" import "android.os.*" import "android.widget.*" import "android.view.*" import "layout&qu…

微信小程序开发总结

架构分析 软件应用架构包括: 数据层、业务逻辑层、服务处、控制层、展示层、用户,小程序属于展示层,通常还需要其他层次提供支持 主体文件: app.js,app.json,app.wxss,前两者是必须存在再根目录下,app.wxs…

【网络云盘客户端】——上传文件的功能的实现

目录 上传文件功能的实现 uploadtask的设计 设置上传的槽函数 uploadFileAction接口 uploadFile接口 定时上传文件 进度条的设计 上传文件功能的实现 上传文件功能实现 1.双击 ”上传文件 “的 QListWidgetItem 或者 点击 “上传” 菜单项 都会弹出一个文件对话框 2.在文…

关于Java中的Lambda变量捕获

博主简介:想进大厂的打工人博主主页:xyk:所属专栏: JavaEE进阶 目录 一、Lambda表达式语法 二、Lambda中变量捕获 一、Lambda表达式语法 基本语法: (parameters) -> expression 或 (parameters) ->{ statements; } Lambda表达式由三部分组成&a…