PyTorch 深度学习 || 专题二:PyTorch 编程基础

news2024/9/30 11:30:39

PyTorch 编程基础

文章目录

  • PyTorch 编程基础
    • 1. backword 求梯度
    • 2. 常用损失函数
      • 2.1 均方误差损失函数
      • 2.2 L1范数误差损失函数
      • 2.3 交叉熵损失函数
    • 3. 优化器

1. backword 求梯度

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(x, w) 
b = torch.add(w, 1)
y = torch.mul(a, b) # y=(x+w)(w+1)
y.backward() # 分别求出两个自变量的导数

print(w.grad) # (w+1)+ (x+w) = x+2w+1 = 5
print(x.grad) # w+1 = 2

tensor([5.])

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
for i in range(3):
    a = torch.add(x, w)
    b = torch.add(w, 1)
    y = torch.mul(a, b) # y=(x+w)(w+1)
    y.backward() # (w+1)+(x+w) = x+2w+1 = 5
    print(w.grad) # 梯度在循环过程中进行了累加

tensor([5.])
tensor([10.])
tensor([15.])

2. 常用损失函数

2.1 均方误差损失函数

loss ( x , y ) = 1 n ∥ x − y ∥ 2 2 = 1 n ∑ i = 1 n ( x i − y i ) 2 \text{loss}(\boldsymbol{x},\boldsymbol{y})=\frac{1}{n}\Vert\boldsymbol{x}-\boldsymbol{y}\Vert_2^2=\frac{1}{n}\sum_{i=1}^n(x_i-y_i)^2 loss(x,y)=n1xy22=n1i=1n(xiyi)2

import torch

input = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.tensor([4.0, 5.0, 6.0, 7.0])

loss_fn = torch.nn.MSELoss(reduction='mean')
loss = loss_fn(input, target)
print(loss)

tensor(9.)

2.2 L1范数误差损失函数

loss ( x , y ) = 1 n ∥ x − y ∥ 1 = 1 n ∑ i = 1 n ∣ x i − y i ∣ \text{loss}(\boldsymbol{x},\boldsymbol{y})=\frac{1}{n}\Vert\boldsymbol{x}-\boldsymbol{y}\Vert_1=\frac{1}{n}\sum_{i=1}^n\vert x_i-y_i\vert loss(x,y)=n1xy1=n1i=1nxiyi

import torch

loss = torch.nn.L1Loss(reduction='mean')
input = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.tensor([4.0, 5.0, 6.0, 7.0])
output = loss(input, target)
print(output)

tensor(3.)

2.3 交叉熵损失函数

h ( p , q ) = − ∑ x n p ( x ) ∗ log ⁡ q ( x ) h(p,q)=-\sum_{x}^np( x)*\log q(x) h(p,q)=xnp(x)logq(x)

import torch

entroy = torch.nn.CrossEntropyLoss()
input = torch.Tensor([[-0.1181, -0.3682, -0.2209]])
target = torch.tensor([0])

output = entroy(input, target)
print(output)

tensor(0.9862)

3. 优化器

import torch
import torch.nn
import torch.utils.data as Data
import matplotlib
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

matplotlib.rcParams['font.sans-serif'] = ['SimHei']

#准备建模数据
x = torch.unsqueeze(torch.linspace(-1, 1, 500), dim=1)
y = x.pow(3)

#设置超参数
LR = 0.01
batch_size = 15
epoches = 5
torch.manual_seed(10)

#设置数据加载器
dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2)

#搭建神经网络
class Net(torch.nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden_layer = torch.nn.Linear(n_input, n_hidden)
        self.output_layer = torch.nn.Linear(n_hidden, n_output)

    def forward(self, input):
        x = torch.relu(self.hidden_layer(input))
        output = self.output_layer(x)
        return output

#训练模型并输出折线图
def train():
    net_SGD = Net(1, 10, 1)
    net_Momentum = Net(1, 10, 1)
    net_AdaGrad = Net(1, 10, 1)
    net_RMSprop = Net(1, 10, 1)
    net_Adam = Net(1, 10, 1)
    nets = [net_SGD, net_Momentum, net_AdaGrad, net_RMSprop, net_Adam]

    #定义优化器
    optimizer_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
    optimizer_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.6)
    optimizer_AdaGrad = torch.optim.Adagrad(net_AdaGrad.parameters(), lr=LR, lr_decay=0)
    optimizer_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
    optimizer_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
    optimizers = [optimizer_SGD, optimizer_Momentum, optimizer_AdaGrad, optimizer_RMSprop, optimizer_Adam]

    #定义损失函数
    loss_function = torch.nn.MSELoss()
    losses = [[], [], [], [], []]

    for epoch in range(epoches):
        for step, (batch_x, batch_y) in enumerate(loader):
            for net, optimizer, loss_list in zip(nets, optimizers, losses):
                pred_y = net(batch_x)
                loss = loss_function(pred_y, batch_y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_list.append(loss.data.numpy())

    plt.figure(figsize=(12,7))
    labels = ['SGD', 'Momentum', 'AdaGrad', 'RMSprop', 'Adam']
    for i, loss in enumerate(losses):
        plt.plot(loss, label=labels[i])
    plt.legend(loc='upper right',fontsize=15)
    plt.tick_params(labelsize=13)
    plt.xlabel('Train Step',size=15)
    plt.ylabel('Loss',size=15)
    plt.ylim((0, 0.3))
    plt.show()

if __name__ == "__main__":
    train()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/585680.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

R实践——【rgplates】安装、介绍、入门

【rgplates】安装、介绍、入门 1. rgplates 安装1.1 easy way1.2 备案方法 2. rgplates 介绍3. rgplates 在线方法入门3.1 加载rgplates3.2 板块重建3.3 独立的地点坐标3.3.1 单个现存坐标点3.3.2 单个点的古坐标3.3.3 多个点的古坐标 3.4 现今的海岸线3.5 其他的重建模型3.6 在…

JMeter 性能测试基本过程及示例

jmeter 为性能测试提供了一下特色: 2023年最新出炉性能测试教程,真实企业性能压测全流程项目实战训练大合集!_哔哩哔哩_bilibili2023年最新出炉性能测试教程,真实企业性能压测全流程项目实战训练大合集!共计11条视频&…

javascript获取对象的键名列表、键值列表

Object.keys&#xff1a;获取对象的键名列表 Object.values&#xff1a;获取对象的键值列表 <script>var obj {name: 1,age: 2,order: 3}const klist Object.keys(obj)const vals Object.values(obj)console.log(obj, obj)console.log(键名列表, klist)console.log(键…

STM32F4_位带操作

目录 1. 位带简介 2. 别名区地址的计算 2.1 合并计算 3. 位带操作访问ODR和IDR寄存器 4. GPIOB->MODER&~(3<<(9*2));GPIOB->MODER|0<<9*2 / GPIOB->MODER&~(3<<(9*2));GPIOB->MODER|1<<9*2 位带操作在写单片机程序时&#xf…

springboot+vue 刘老师

课程内容 前端&#xff1a;vue elementui 后端&#xff1a;springboot mybatisplus 公共云部署 ------boot-------- 热部署 不用devtools&#xff0c;交给jrebel工具 RequestMapping ​ 参数 value 路径 method 方法consumes 请求媒体类型 如 application/jsonproduces …

DJ5-7 缓冲区管理

目录 5.7.1 缓冲的引入 5.7.2 单缓冲和双缓冲 1、单缓冲&#xff08;Single Buffer&#xff09; 2、双缓冲&#xff08;Double Buffer&#xff09; 3、双机通信时缓冲区的设置 5.7.3 循环缓冲 1、循环缓冲的组成 2、循环缓冲区的使用 3、进程同步 5.7.4 缓冲池 …

Spring Security源码剖析从入门到精通.跟学尚硅谷

1.1 概要 Spring 是非常流行和成功的 Java 应用开发框架&#xff0c;Spring Security 正是 Spring 家族中的成员。Spring Security 基于 Spring 框架&#xff0c;提供了一套 Web 应用安全性的完整解决方案。 正如你可能知道的关于安全方面的两个主要区域是“认证”和“授权”…

Mediapipe人体识别库

一、简介 官网&#xff1a;MediaPipe | Google for Developershttps://developers.google.cn/mediapipe Mediapipe 是2012年起开始公司内部使用&#xff0c;2019年google的一个开源项目&#xff0c;可以提供开源的、跨平台的常用机器学习(machine learning)方案。Mediapipe…

python-sqlite3使用指南

python下sqlite3使用指南 文章目录 python下sqlite3使用指南开发环境sqlite3常用APICRUD实例参考 开发环境 vscode ​ 开发语言&#xff1a; python vscode SQLite插件使用方法&#xff1a; 之后在这里就可以发现可视化数据&#xff1a; sqlite3常用API Python 2.5.x 以上…

信息安全实践1.3(HTTPS)

前言 做这个实验对Tomcat的版本有要求&#xff0c;最好是使用Tomcat8。因为我之前使用Tomcat10&#xff0c;然后一直做不出来。 要求 部署Web服务器端HTTPS功能&#xff0c;通过网络嗅探分析HTTPS通过SSL实施安全保护的效果 关键步骤 首先要给tomcat配置https&#xff0c;也…

设计模式之美-实战一(上):业务开发常用的基于贫血模型的MVC架构违背OOP吗?

领域驱动设计&#xff08;Domain Driven Design&#xff0c;简称DDD&#xff09;盛行之后&#xff0c;这种基于贫血模型的传统的开发模式就更加被人诟病。而基于充血模型的DDD开发模式越来越被人提倡。所以&#xff0c;我打算用两节课的时间&#xff0c;结合一个虚拟钱包系统的…

超低功耗三通道低频无线唤醒ASK接收 125k soc芯片UM2082F08

UM2082F08 是基于单周期 8051 内核的超低功耗 8 位、、具有三通道低频无线唤醒 ASK 接收功能的 SOC 芯片。芯片可检测 30KHz~300KHz 范围的 LF&#xff08;低频&#xff09;载波频率数据并触发唤醒信号&#xff0c;同时可以调节接收灵敏度&#xff0c;确保在各种应用环境下实现…

代码随想录算法训练营15期 Day 6 | 242.有效的字母异位词 、349. 两个数组的交集 、202. 快乐数、1. 两数之和

由于昨天是周日&#xff0c;周日是休息日&#xff0c;所以就是什么也没有写啦。今天是day06天&#xff0c;继续加油。 哈希表理论基础 建议&#xff1a;大家要了解哈希表的内部实现原理&#xff0c;哈希函数&#xff0c;哈希碰撞&#xff0c;以及常见哈希表的区别&#xff0c;…

Toolkit.getDefaultToolkit()获得的java.awt.Toolkit是不是同一个? 是否为单例设计模式?答案是**是**

Toolkit.getDefaultToolkit()获得的java.awt.Toolkit是不是同一个? 是否为单例设计模式? 答案是是 反复调用Toolkit.getDefaultToolkit()获得的 java.awt.Toolkit 是同一个 import java.awt.Toolkit;public class GetDefaultToolkit是不是获得单例Toolkit {static public …

【P43】JMeter 吞吐量控制器(Throughput Controller)

文章目录 一、吞吐量控制器&#xff08;Throughput Controller&#xff09;参数说明二、测试计划设计2.1、Total Executions2.2、Percent Executions2.3、Per User 一、吞吐量控制器&#xff08;Throughput Controller&#xff09;参数说明 允许用户控制后代元素的执行的次数。…

中级软件设计师考试总结

目录 前言考前学习宏观什么是软考涉及的知识范围软考整体导图总结 微观我的分享——希尔排序学习过程结构化做题 考试阶段确定不确定 考后总结 前言 作为一名中级软件设计师&#xff0c;考试是衡量自己技能和水平的一项重要指标。在备考和考试过程中&#xff0c;我通过总结经验…

【TI毫米波雷达笔记】IWR6843AOPEVM-G的DCA1000EVM模式配置及避坑(官方手册有误)

【TI毫米波雷达笔记】IWR6843AOPEVM-G的DCA1000EVM模式配置及避坑&#xff08;官方手册有误&#xff09; IWR6843AOPEVM-G版本可以直接与DCA1000EVM连接 进行数据获取 不需要连接MMWAVEICBOOST版 直接使用 DCA1000mmWave Studio 软件进行数据采集 在官方手册中 User’s Guide…

linux环境下安装gitlab

前几天跟朋友聊天时说到gitlab版本控制。其实&#xff0c;之前也对它只是知道有这个东西&#xff0c;也会用。只是对于它的安装和配置&#xff0c;那我还是没整过。这两天&#xff0c;我找了一下网上的资料&#xff0c;还是写下吧。 一安装&#xff1a; 按网上所说&#xff0c;…

2023年上半年信息系统项目管理师下午真题及答案解析

试题一(25分) 为实现空气质量的精细化治理&#xff0c;某市规划了智慧环保项目。该项目涉及网格化监测、应急管理、执法系统等多个子系统。作为总集成商&#xff0c;A公司非常重视&#xff0c;委派李经理任项目经理&#xff0c;对公司内研发部门与项目相关的各产品线研发人员及…

带你开发一个远程控制项目---->STM32+标准库+阿里云平台+传感器模块+远程显示-------之 阿里云平台项目建造。

第一篇章&#xff1a; (13条消息) 带你开发一个远程控制项目----&#xff1e;STM32标准库阿里云平台传感器模块远程显示。_海口飞鹏岛科技有限公司的博客-CSDN博客 本次文章是指引开发者进行开发阿里云平台建造设备项目&#xff0c;可观看UP主教程&#xff0c;完成如下&#x…