代码解释【待解决】

news2024/11/28 21:41:22

这里写目录标题

  • 代码解释
    • 数组转化为列表,方便在哪里
    • yeild
    • range()函数还有一些常用的小技巧。在这里我们列举两个常用技巧,以供参考
    • 梯度
      • l.sum().backward()的粗浅理解
      • detatch
        • 文字描述
        • 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值 x.grad.zero_()
      • detach作用
      • tensor.detach_()不同于detach()
    • 遍历而已return [text_labels[int(i)] for i in labels]
    • 待解决

代码解释

class Timer: #@save
"""记录多次运行时间"""
def __init__(self):
self.times = []
self.start()
def start(self):
"""启动计时器"""
self.tik = time.time()
def stop(self):
"""停止计时器并将时间记录在列表中"""
self.times.append(time.time() - self.tik)
return self.times[-1]
def avg(self):
"""返回平均时间"""
return sum(self.times) / len(self.times)
def sum(self):
"""返回时间总和"""
return sum(self.times)
def cumsum(self):
"""返回累计时间"""
return np.array(self.times).cumsum().tolist()
def cumsum(self):
"""返回累计时间"""
return np.array(self.times).cumsum().tolist()

数组转化为列表,方便在哪里

yeild

def data_iter(batch_size, features, labels):
num_examples = len(features)
indices = list(range(num_examples))
# 这些样本是随机读取的,没有特定的顺序
random.shuffle(indices)
for i in range(0, num_examples, batch_size):
batch_indices = torch.tensor(
indices[i: min(i + batch_size, num_examples)])
yield features[batch_indices], labels[batch_indices]

range()函数还有一些常用的小技巧。在这里我们列举两个常用技巧,以供参考

1、生成指定步长的整数序列
通过给range()函数指定步长,可以生成指定步长的整数序列。例如,下面的代码将生成0到9之间的偶数序列。

for i in range(0, 10, 2):
    print(i)

输出结果:

0
2
4
6
8

2、倒序遍历
通过指定开始与结束位置不同的参数,可以遍历一个倒序序列。例如,在一个列表中逆序遍历所有元素,可以使用以下方式:

arr = ['a', 'b', 'c', 'd', 'e']
for i in range(len(arr)-1, -1, -1):
    print(arr[i])

输出结果为:

e
d
c
b
a

梯度

>>> x = torch.ones(2, 2, requires_grad=True) # 2x2全为1的tensor
>>> y = x + 2
>>> z = y * y * 3
>>> out = z.mean()
>>> print(z, out)
tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward0>) tensor(27., grad_fn=<MeanBackward0>)
>>> out.backward()
>>> print(x.grad)
tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])

grad_fn=

x.grad.zero_()
y = x.sum()
y
# y.backward()
# x.grad

tensor(6., grad_fn=)

在这里插入图片描述

l.sum().backward()的粗浅理解

梯度下降的优化方法,
参数发生改变导致的损失函数loss值变化多少,就是梯度
在这里插入图片描述

在这里插入图片描述

梯度下降就是 遵循着公式,在损失函数递减的方向上更新权重和偏置
在这里插入图片描述

对梯度的基础概念了解不到位叭,爆炸!
梯度是 ,loss值对x求导?不准确。比如给定一个batch的样本输入,在每个样本点上loss值对样本输入值求导 组成的 list。一组batch的输入对应的loss值是一个值,可是梯度确是这个loss值对不同的x求导的导数 作为元素,组成 梯度

曾记否,滑滑梯的那张PPT,那个样本点的斜率,只是组成了梯度的一个元素而已
输入样本x是多个,对应的权重数组也是多个元素组成

那么梯度下降,粗略记忆是,loss值对w求导,这个没啥好说的
在这里插入图片描述

主要是,梯度 不是一个元素,是一个向量(但愿可以这样讲),是多维的
在损失函数递减的方向上更新权重和偏置,递减的方向 可不是某一个方向,而是对应的各个输入x的梯度方向

loss值是一个值,但是他是 多维的输入 通过多维的权重和偏置 sum在一起组成的
在代码里面求loss,得到的其实是个多维的向量,多维的输入(x1,x2,……) 操作得到的

# 训练
lr = 0.03
num_epochs = 3 # 迭代次数
batch_size = 10
net = linreg
loss = squred_loss

# 训练模型
for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y) # 计算损失函数
        l.sum().backward() # 计算各个参数的梯度
        sgd([w, b], lr, batch_size) # 更新参数

detatch

import torch


def fun(x):
    return x * x;


a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = fun(a)

out.sum().backward()
print(a.grad)
'''返回:
None
tensor([2., 4., 6.])
'''

import torch


def fun(x):
    return x * x;


a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad) #None
out = fun(a)
print(out) # tensor([1., 4., 9.], grad_fn=<MulBackward0>)

# 添加detach(),c的requires_grad为False
c = out.detach()
print(c) # tensor([1., 4., 9.]),detatch之后断绝了c和a的关系,out和a是有关系的

# 这时候没有对c进行更改,所以并不会影响backward(),c更改也会影响out.sum().backward()
out.sum().backward()
print(a.grad) #tensor([2., 4., 6.])

# a.grad.zero_()
# c.sum().backward()
# print(a.grad) 会报错,c是被detach出了关系树,无求得c对a的导数
'''返回:
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0.1966, 0.1050, 0.0452])
'''
文字描述

1、返回一个新的tensor,从当前计算图中分离下来。但是仍指向原变量的存放位置,不同之处只是requirse_grad为false.得到的这个tensir永远不需要计算器梯度,不具有grad.

在这里插入图片描述
从上可见tensor c是由out分离得到的,但是我也没有去改变这个c,这个时候依然对原来的out求导是不会有错误的,即

c,out之间的区别是c是没有梯度的,out是有梯度的

当使用detach()分离tensor,然后用这个分离出来的tensor去求导数,会影响backward(),会出现错误

2、使用detach返回的tensor和原始的tensor共同一个内存,即一个修改另一个也会跟着改变

当使用detach()分离tensor并且更改这个tensor时,即使再对原来的out求导数,会影响backward(),会出现错误

如果此时对c进行了更改,这个更改会被autograd追踪,在对out.sum()进行backward()时也会报错,因为此时的值进行backward()得到的梯度是错误的:

import torch
 
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)
 
#添加detach(),c的requires_grad为False
c = out.detach()
print(c)
c.zero_() #使用in place函数对其进行修改
 
#会发现c的修改同时会影响out的值
print(c)
print(out)
 
#这时候对c进行更改,所以会影响backward(),这时候就不能进行backward(),会报错
out.sum().backward()
print(a.grad)
 
'''返回:
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
Traceback (most recent call last):
  File "test.py", line 16, in <module>
    out.sum().backward()
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified 
by an inplace operation
'''

1、返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad
2、在新的计算中,不会再对其值的梯度进行修改
3、如果对返回张量进行backwark()计算,会出现错误

在默认情况下,PyTorch会累积梯度,我们需要清除之前的值 x.grad.zero_()

要清除就输入一次:x.grad.zero_()

detach作用

当我们再训练网络的时候可能希望保持一部分的网络参数不变,只对其中一部分的参数进行调整;或者值训练部分分支网络,并不让其梯度对主网络的梯度造成影响,这时候我们就需要使用detach()函数来切断一些分支的反向传播
假设有模型A和模型B,我们需要将A的输出作为B的输入,但训练时我们只训练模型B. 那么可以这样做:

input_B = output_A.detach()

它可以使两个计算图的梯度传递断开,从而实现我们所需的功能。
内容来源

tensor.detach_()不同于detach()

将一个tensor从创建它的图中分离,并把它设置成叶子tensor

其实就相当于变量之间的关系本来是x -> m -> y,这里的叶子tensor是x,但是这个时候对m进行了m.detach_()操作,其实就是进行了两个操作:

  • 将m的grad_fn的值设置为None,这样m就不会再与前一个节点x关联,这里的关系就会变成x, m -> y,此时的m就变成了叶子结点
  • 然后会将m的requires_grad设置为False,这样对y进行backward()时就不会求m的梯度

总结:其实detach()和detach_()很像,两个的区别就是detach_()是对本身的更改,detach()则是生成了一个新的tensor

比如x -> m -> y中如果对m进行detach(),后面如果反悔想还是对原来的计算图进行操作还是可以的

但是如果是进行了detach_(),那么原来的计算图也发生了变化,就不能反悔了

遍历而已return [text_labels[int(i)] for i in labels]

def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

get_fashion_mnist_labels((0,2,3))
# get_fashion_mnist_labels([1,2,3]) 都一样
['t-shirt', 'pullover', 'dress']

或者简单地理解
Y=XW+b
X和W都是多维的,b或许也是?或许不是
.sum()函数主要有两个作用,一个是用来求和,一个是用来降维。而在这里是用到了降维的作用。
在这里插入图片描述

X = X.reshape((1, 1, 6, 8))
Y = Y.reshape((1, 1, 6, 7))
lr = 3e-2  # Learning rate

for i in range(10):
    Y_hat = conv2d(X)
    l = (Y_hat - Y) ** 2
    conv2d.zero_grad()
    l.sum().backward()
    # Update the kernel
    conv2d.weight.data[:] -= lr * conv2d.weight.grad
    if (i + 1) % 2 == 0:
        print(f'epoch {i + 1}, loss {l.sum():.3f}')

print(conv2d.weight.data.reshape((1, 2)))  

待解决

Softmax回归的训练

def train_epoch_ch3(net, train_iter, loss, updater):  
    """训练模型一个迭代周期(定义见第3章)。"""
    if isinstance(net, torch.nn.Module): #如果是nn模具
        net.train() #开启训练模式
    metric = Accumulator(3) #长度为3的迭代器 来累积需要信息
    for X, y in train_iter: #扫描数据
        y_hat = net(X) #计算y_hat
        l = loss(y_hat, y) #损失函数计算l 
        if isinstance(updater, torch.optim.Optimizer):  #如果updater是pytorch的一个买者
            updater.zero_grad() #梯度设为0
            l.backward() #计算梯度
            updater.step() #更新参数
            metric.add( #样本数 累加数 正确的分类数 放到累加器里面
                float(l) * len(y), accuracy(y_hat, y), 
                y.size().numel())
        else:  #如果从头开始实现
            l.sum().backward() #l是一个向量 求和算梯度
            updater(X.shape[0]) 
            metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    return metric[0] / metric[2], metric[1] / metric[2] 
#返回结果:损失/样本总数,所有分类正确的样本数/ 总样本数

训练函数

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  
    """训练模型(定义见第3章)。"""
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
                        legend=['train loss', 'train acc', 'test acc']) #可视化的animator(可忽略)
    for epoch in range(num_epochs): #扫描n遍数据
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater) #训练一次
        test_acc = evaluate_accuracy(net, test_iter) #在测试数据集上评估精度
        animator.add(epoch + 1, train_metrics + (test_acc,)) #显示
    train_loss, train_acc = train_metrics
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1187836.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Go cobra简介

当你需要为你的 Go 项目创建一个强大的命令行工具时&#xff0c;你可能会遇到许多挑战&#xff0c;比如如何定义命令、标志和参数&#xff0c;如何生成详细的帮助文档&#xff0c;如何支持子命令等等。为了解决这些问题&#xff0c;github.com/spf13/cobra 就可以派上用场。 g…

ESP32 C3 smartconfig一键配网报错

AP配网 在调试我的esp32c3的智能配网过程中&#xff0c;发现ap配网使用云智能App是可以正常配置的。 切记用户如果在menu菜单里使能AP配网&#xff0c;默认SSID名字为adh_PK值_MAC后6位。用户可以修改这个apssid的键值&#xff0c;但是要使用云智能app则这个名字的开头必须为ad…

asp.net外卖网站系统VS开发mysql数据库web结构c#编程Microsoft Visual Studio

一、源码特点 asp.net外卖网站系统 是一套完善的web设计管理系统&#xff0c;系统采用mvc模式&#xff08;BLLDALENTITY&#xff09;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为vs2010&#xff0c;数据库为mysql&#xff0c;使用c#语…

不同访问修饰符的访问数据权限的区别

在Java中&#xff0c;有四种访问修饰符&#xff1a;public、private、protected和默认修饰符。它们的作用是控制类、变量和方法的可见性&#xff0c;也就是说它们控制了哪些代码可以访问某个类、变量或方法的数据成员。 public&#xff1a;可以被任何类访问&#xff0c;对外部…

【C++】手写堆

手写堆&#xff08;小顶堆&#xff09; 堆使用数组存储&#xff0c;下标从1开始&#xff08;下标从0开始也可以&#xff09;。 下标为u的节点&#xff1a; 左子节点下标为&#xff1a;2 * u&#xff08;下标从0开始&#xff0c;左子节点则为2 * i 1&#xff09;右子节点下标…

最大似然估计直观理解

目的 由于直接估计类条件概率密度函数很困难。 解决的办法&#xff0c;把估计完全未知的概率密度转化为估计参数。这里就将概率密度估计问题转化为参数估计问题&#xff0c; 极大似然估计就是一种参数估计方法。当然了&#xff0c;概率密度函数的选取很重要&#xff0c;模型正…

在代码中忽略特定的编译告警

在移植别人的代码时&#xff0c;有些告警看着不爽&#xff0c;但又不想去改动原来的代码。可以在头文件中加一句&#xff1a; #pragma diag_suppress 111 即可忽略特定的编译告警。 其中&#xff0c;111是告警代码。 #pragma diag_suppress 111 比如&#xff0c;原始代码的…

【网络】UDP协议

UDP协议 一、传输层1、再谈端口号2、两个命令 二、UDP协议1、UDP协议格式2、UDP的解包和分用3、UDP的特点4、UDP使用注意事项5、基于UDP的应用层协议 一、传输层 我们以前在学习HTTP等应用层协议时&#xff0c;为了便于理解&#xff0c;简单的认为HTTP协议是将请求和响应直接发…

AI:75-基于生成对抗网络的虚拟现实场景增强

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

如何将系统盘MBR转GPT?无损教程分享!

什么是MBR和GPT&#xff1f; MBR和GPT是磁盘的两种分区形式&#xff1a;MBR&#xff08;主引导记录&#xff09;和GPT&#xff08;GUID分区表&#xff09;。 新硬盘不能直接用来保存数据。使用前应将其初始化为MBR或GPT分区形式。但是&#xff0c;如果您在MBR时需…

微服务-网关设计

文章目录 引言I 网关部署java启动jar包II 其他服务部署细节2.1 服务端api 版本号III 网关常规设置3.1 外部请求系统服务都需要通过网关访问3.2 第三方平台回调校验文件的配置IV 微服务日志跟踪4.1 打印线程ID4.2 封装线程池任务执行器4.3 将自身MDC中的数据复制给子线程4.4 微服…

「我在淘天做技术」音视频技术及其在淘宝内容业务中的应用

作者&#xff1a;李凯 一、前言 近年来&#xff0c;内容电商似乎已经充分融入到人们的生活中&#xff1a;在闲暇时间&#xff0c;我们已经习惯于拿出手机&#xff0c;从电商平台的直播间、或者短视频链接下单自己心仪的商品。 尽管优质的货品、实惠的价格、精致的布景、有趣的…

03-React事件处理 生命周期 Diffing算法

React事件处理 背景 1.通过onXxx属性指定事件处理函数(注意大小写) React使用的是自定义(合成)事件, 而不是使用的原生DOM事件 比如原生onclick的事件在React中变成了onClick&#xff0c;这么搞是为了更好的兼容性React中的事件是通过事件委托方式处理的(委托给组件最外层的…

MUYUCMS v2.1:一款开源、轻量级的内容管理系统基于Thinkphp开发

MuYuCMS&#xff1a;一款基于Thinkphp开发的轻量级开源内容管理系统&#xff0c;为企业、个人站长提供快速建站解决方案。它具有以下的环境要求&#xff1a; 支持系统&#xff1a;Windows/Linux/Mac WEB服务器&#xff1a;Apache/Nginx/ISS PHP版本&#xff1a;php > 5.6 (…

超级简单的springboot整合springsecurity oauth2第三方登录

前言 springboot整合springsecurity oauth2进行第三方登录&#xff0c;例如qq、微信、微博。网上一堆教程&#xff0c;并且很多都是旧版本的&#xff0c;篇幅又长&#xff0c;哔哩吧啦一大堆&#xff0c;就算你搞下来了&#xff0c;等下次版本升级或变更一下&#xff0c;你又不…

5分频【FPGA】

所以数据对齐晶振。 从第一个晶振开始&#xff1a; 5分频&#xff1a; 2.5晶振高电平&#xff0c;2.5晶振低电平 clk1是 32 clk2是23 需要 clk2下降沿【拉低】clk1上升沿【拉高】 clk_out clk1 & clk2; 推荐5分频&#xff1a;

一文带你速通Seata的XA模式

目录 XA规范协议 基本介绍 分布式事务处理模型角色 两阶段提交 Seata的XA的模式 基本介绍 具体使用 小结 XA规范协议 基本介绍 在讲解Seate中的XA模式之前我们先来了解了解什么是XA规范。XA 规范 是 X/Open 组织定义的分布式事务处理&#xff08;DTP&#xff0c;Distr…

计算机毕业设计项目选题推荐(免费领源码)java+Springboot+Mysql邻家优选超市线上线下购物系统小程序92713

摘 要 21世纪的今天&#xff0c;随着社会的不断发展与进步&#xff0c;人们对于信息科学化的认识&#xff0c;已由低层次向高层次发展&#xff0c;由原来的感性认识向理性认识提高&#xff0c;管理工作的重要性已逐渐被人们所认识&#xff0c;科学化的管理&#xff0c;使信息存…

运维那些事儿|2023年,运维还有出路吗?

作为一名运维&#xff0c;不知道你有没有这样的感受。 觉得自己的工作没什么成长空间。每天装个系统、跑个机房、跑个脚本&#xff0c;忙来忙去也没忙出来什么名堂&#xff0c;含金量低不说&#xff0c;薪资也一直没见涨&#xff0c;所以你开始陷入迷茫&#xff0c;会疑惑&…