PyTorch深度学习实践——线性模型、梯度下降算法、反向传播

news2024/10/7 8:28:09

1、线性回归

参考资料1:https://blog.csdn.net/bit452/article/details/109627469
参考资料2:http://biranda.top/Pytorch%E5%AD%A6%E4%B9%A0%E7%AC%94%E8%AE%B0003%E2%80%94%E2%80%94%E7%BA%BF%E6%80%A7%E6%A8%A1%E5%9E%8B/#%E7%BA%BF%E6%80%A7%E6%A8%A1%E5%9E%8B

1.1 一元线性回归y=wx代码

要求: 实现线性模型(y=wx)并输出loss的图像。
代码:

import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
#前馈计算
def forward(x):
    return x * w
#求loss
def loss(x, y):
    y_pred = forward(x)
    return (y_pred-y)*(y_pred-y)

w_list = []
mse_list = []
#从0.0一直到4.1以0.1为间隔进行w的取样
for w in np.arange(0.0,4.1,0.1):
    print("w=", w)
    l_sum = 0
    for x_val,y_val in zip(x_data,y_data):
        y_pred_val = forward(x_val)
        loss_val = loss(x_val,y_val)
        l_sum += loss_val
        print('\t',x_val,y_val,y_pred_val,loss_val)
    print("MSE=",l_sum/3)
    w_list.append(w)
    mse_list.append(l_sum/3)

#绘图
plt.plot(w_list,mse_list)
plt.ylabel("Loss")
plt.xlabel('w')
plt.show()

结果:
在这里插入图片描述
在这里插入图片描述

1.2 一元线性回归y=wx+b代码

要求: 实现线性模型(y=wx+b)并输出loss的3D图像。
代码:

import numpy as np
import matplotlib.pyplot as plt;
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

#线性模型
def forward(x,w,b):
    return x * w+ b

#损失函数
def loss(x, y,w,b):
    y_pred = forward(x,w,b)
    return (y_pred - y) * (y_pred - y)

def mse(w,b):
    l_sum = 0
    for x_val, y_val in zip(x_data, y_data):
        y_pred_val = forward(x_val,w,b)
        loss_val = loss(x_val, y_val,w,b)
        l_sum += loss_val
        print('\t', x_val, y_val, y_pred_val, loss_val)
    print('MSE=', l_sum / 3)
    return  l_sum/3

#迭代取值,计算每个w取值下的x,y,y_pred,loss_val
mse_list = []

##画图
##定义网格化数据
b_list=np.arange(-30,30,0.1)
w_list=np.arange(-30,30,0.1)

##生成网格化数据
xx, yy = np.meshgrid(b_list, w_list,sparse=False, indexing='ij')

##每个点的对应高度
zz=mse(xx,yy)

fig = plt.figure()#定义图像窗口
ax = Axes3D(fig)#在窗口上添加3D坐标轴
ax.plot_surface(xx, yy, zz,rstride=10, cstride=10, cmap=cm.viridis)#生成曲面,cm.viridis是颜色
# rstride(row)指定行的跨度,cstride(column)指定列的跨度
plt.show()

结果:
在这里插入图片描述

2、梯度下降算法

参考资料1:https://blog.csdn.net/bit452/article/details/109637108
参考资料2:http://biranda.top/Pytorch%E5%AD%A6%E4%B9%A0%E7%AC%94%E8%AE%B0004%E2%80%94%E2%80%94%E6%A2%AF%E5%BA%A6%E4%B8%8B%E9%99%8D%E7%AE%97%E6%B3%95/#%E9%97%AE%E9%A2%98%E8%83%8C%E6%99%AF

2.1 梯度下降算法

代码:

import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
_cost = []
w = 1.0
#前馈计算
def forward(x):
    return x * w
#求MSE
def cost(xs, ys):
    cost = 0
    for x, y in zip(xs,ys):
        y_pred = forward(x)
        cost += (y_pred-y) ** 2
    return cost/len(xs)
#求梯度
def gradient(xs, ys):
    grad = 0
    for x, y in zip(xs,ys):
        temp = forward(x)
        grad += 2*x*(temp-y)
    return grad / len(xs)

for epoch in range(100):
     cost_val = cost(x_data, y_data)
     _cost.append(cost_val)
     grad_val = gradient(x_data, y_data)
     w -= 0.01*grad_val
     print("Epoch: ",epoch, "w = ",w ,"loss = ", cost_val)
print("Predict(after training)",4,forward(4))

#绘图
plt.plot(range(100),_cost)
plt.ylabel("Cost")
plt.xlabel('Epoch')

plt.show()

结果:
在这里插入图片描述
在这里插入图片描述

2.2 随机梯度下降算法

代码:

#随机梯度下降
import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
_cost = []
w = 1.0
#前馈计算
def forward(x):
    return x * w
#求单个loss
def loss(x, y):
    y_pred = forward(x)
    return (y_pred-y) ** 2
#求梯度 随机梯度下降的 loss是计算一个训练数据的损失
def gradient(x, y):
    return 2*x*(x*w-y)
print("Predict(after training)",4,forward(4))

for epoch in range(100):
    for x, y in zip(x_data,y_data):
        grad=gradient(x,y)
        w -= 0.01*grad
        print("\tgrad:  ",x,y,grad)
        l = loss(x,y)
    print("progress: ",epoch,"w=",w,"loss=",l)
print("Predict(after training)",4,forward(4))

结果:
在这里插入图片描述

2.3 区别:

随机梯度下降法和梯度下降法的主要区别在于:

1、损失函数由cost()更改为loss()。cost是计算所有训练数据的损失,loss是计算一个训练数据的损失。对应于源代码则是少了两个for循环。

2、梯度函数gradient()由计算所有训练数据的梯度更改为计算一个训练数据的梯度。

3、本算法中的随机梯度主要是指,每次拿一个训练数据来训练,然后更新梯度参数。本算法中梯度总共更新100(epoch)x3 = 300次。梯度下降法中梯度总共更新100(epoch)次。

3、反向传播

参考资料1:https://blog.csdn.net/bit452/article/details/109643481
参考资料2:http://biranda.top/Pytorch%E5%AD%A6%E4%B9%A0%E7%AC%94%E8%AE%B0005%E2%80%94%E2%80%94%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%E7%AE%97%E6%B3%95/#%E9%97%AE%E9%A2%98%E6%8F%90%E5%87%BA

代码:

import torch

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

#赋予tensor中的data
w = torch.Tensor([1.0])#w初值为1
#设定w需要计算梯度grad
w.requires_grad = True

#模型y=x*w 建立计算图
def forward(x):
    '''
    w为Tensor类型
    x强制转换为Tensor类型
    通过这样的方式建立计算图
    '''
    return x * w

def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

print ("predict  (before training)", 4, forward(4).item())

for epoch in range(100):
    for x,y in zip(x_data,y_data):
        #创建新的计算图
        l = loss(x,y)
        #进行反馈计算,此时才开始求梯度,此后计算图进行释放
        l.backward()
        #grad.item()取grad中的值变成标量
        print('\tgrad:',x, y, w.grad.item())
        #单纯的数值计算要利用data,而不能用张量,否则会在内部创建新的计算图
        w.data = w.data - 0.01 * w.grad.data
        #把权重梯度里的数据清零
        w.grad.data.zero_()
    print("progress:",epoch, l.item())

print("predict (after training)", 4, forward(4).item())

结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/30280.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PC_多处理器

文章目录多处理器单指令单数据流SISD结构单指令流多数据流SIMD结构向量处理器多指令流单数据流MISD结构多指令多数据流MIMD结构小结硬件多线程细粒度多线程粗粒度多线程同时多线程多核处理器共享内存多处理器多处理器 常规的单处理器属于SISD常规多处理器属于MIMD 单指令单数…

腾格尔十月天传媒联手《巴林塔娜》,2255万粉丝多少买票支持

曾几何时,木桶原理非常流行,意思就是一个木桶能够盛多少水,取决于最短一块板的长度。可是随着社会的发展,木桶原理已经被淘汰,只要你拥有了团队合作,就可以统协作取长补短。 就拿有着“草原歌神”之称的腾格…

你的知识库能提高工作效率的7个原因

知识就是力量。但到目前为止,光有知识是不够的——你使用这些信息的方式让你领先于竞争对手。如果使用正确,知识库软件可以帮助您提供更好的服务,培训您的员工,并成为您的行业权威。拥有一个有效的知识库不仅会影响你在内部开展业…

Android assets

1.应用程序资源管理器assets assets就是apk工程中的一个普通目录,在每个工程的根目录下都可以发现(或者可以自己创建)一个assets目录。 assets目录用于专门保存各种外部文件,比如图像、音视频、配置文件、字体、自带数据库等。它之所以适合用来管理这些…

数据库mysql操作语言, DDL,DML,DQL

文章目录一. 数据库1. 数据库基本概念2. 数据库管理系统3. 数据库与表的概念二. 连接数据库的方式三. 如何操作DBMSSQL语句分类1. DDL 数据定义语言查看DBMS中已有的数据库数据库相关操作新建一个数据库查看数据库信息删除数据库使用一个数据库(切换一个数据库)表相关操作创建表…

HOOPS/MVO技术概述

更多参见:HOOPS学习笔记 MVO 1.引言 HOOPS/MVO是一个C类库,位于HOOPS 3D图形系统(HOOPS/3DGS)之上。它有一个模型/视图/操作员架构,封装了各种HOOPS/3DGS数据结构和概念,并提供了一系列通用应用程序级逻辑…

【无人机】基于粒子群优化干扰受限下无人机群辅助网络附matlab代码

✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信。 🍎个人主页:Matlab科研工作室 🍊个人信条:格物致知。 更多Matlab仿真内容点击👇 智能优化算法 …

Xception --tensorflow2.x

简介 Xception和SqueezeNet一样,是一种降低参数量的轻量级神经网络,它主要使用了 深度分离卷积(Depthwise separable convolution)结构,该结构替换了原来的Inception中的多尺寸卷积结构。这里需要弄清深度分离卷积(D…

【创建型设计模式-单例模式】一文搞懂单例模式的使用场景及代码实现的7种方式

1.什么是单例模式 在了解单例模式前,我们先来看一下它的定义: 确保一个类只有一个实例,而且自行实例化并且自行向整个系统提供这个实例,这个类称为单例类,它提供全局访问的方法, 单例模式是一种对象的创建型…

微型计算机原理速通期末复习

文章目录微机基础原码、反码、补码、移码溢出实数型功能结构8086/8088内部结构80286内部结构80386/80486内部结构标志寄存器FLAGS寄存器阵列段寄存器寻址标志寄存器EFLAGS分段结构数据寻址方式立即寻址直接寻址寄存器寻址寄存器间接寻址寄存器相对寻址基址-变址寻址基址-变址-相…

Solidity vs. Vyper:不同的智能合约语言的优缺点

本文探讨以下问题:哪种智能合约语言更有优势,Solidity 还是 Vyper?最近,关于哪种是“最好的”智能合约语言存在很多争论,当然了,每一种语言都有它的支持者。 这篇文章是为了回答这场辩论最根本的问题&…

磨金石教育摄影技能干货分享|中国风摄影大师——郎静山

说到中国风摄影,你想到的画面是什么样子的?故宫、长城、苏州园林、大红灯笼高高挂,反正离不开传承了千八百年的古建筑。仿佛没有了这些历史古董的元素就没有中国味道似的。 其实中国风,其内核应该是传统的审美观念和哲学思想。中…

【雷丰阳-谷粒商城 】课程概述

持续学习&持续更新中… 学习态度:守破离 【雷丰阳-谷粒商城 】课程概述该电商项目与其它项目的区别项目简介项目背景电商模式谷粒商城项目技术&特色项目前置要求谷粒商城-微服务架构图谷粒商城-微服务划分图参考该电商项目与其它项目的区别 互联网大型项目…

深入linux内核架构--内存管理

【推荐阅读】 代码大佬的【Linux内核开发笔记】分享,前人栽树后人乘凉! 一文了解Linux内核的Oops 一篇长文叙述Linux内核虚拟地址空间的基本概括 路由选择协议——RIP协议 深入理解Intel CPU体系结构【值得收藏!】 内存体系结构 1. UM…

银行测试人员谈测试需求

今天呢,想用故事说话,先看看啥叫用户需求挖掘。其实看完故事之后,我自己颇为震撼,请看。 故事一: 100多年前,福特公司的创始人亨利福特先生到处跑去问客户:“您需要一个什么样的更好的交通工具…

loganalyzer 展示数据库中的日志

1 实验目标: 利用rsyslog日志服务,将收集的日志记录于MySQL中,通过loganalyzer 展示数据库中的日志 2 环境准备 三台主机: 一台日志服务器,利用上一个案例实现,IP:192.168.100.100一台数据库…

【Java八股文总结】之数据结构

文章目录数据结构一、概念1、时间复杂度与空间复杂度2、常见算法时间复杂度3、Comparable二、常见的排序算法1、直接插入排序2、希尔排序3、选择排序4、堆排序5、冒泡排序6、快速排序7、归并排序8、二分查找算法Q:什么时候需要结束呢?三、线性表1、概念2…

使用 Footprint Analytics, 快速搭建区块链数据应用

Nov 2022, danielfootprint.network 如果你有一个处理 NFTs 或区块链的网站或应用程序,你可以在你的平台上直接向用户展示数据,以保持他们在网站或者应用内的参与,而不是链接以及跳出到其他网站。 对于任何区块链应用或者媒体、信息网站来说…

秦皇岛科学选育新品种 国稻种芯·中国水稻节:河北秸秆变肥料

秦皇岛科学选育新品种 国稻种芯中国水稻节:河北秸秆变肥料 秦皇岛新闻网 记者李妍 冀时客户端报道(河北台 张志刚 米弘钊 赵永鑫 通讯员 赵力楠) 新闻中国采编网 中国新闻采编网 谋定研究中国智库网 中国农民丰收节国际贸易促进会 国稻种芯…

无线通信技术概览

电生磁,磁生电 电场和磁场的关系,简而言之就是:变化的电场产生磁场,变化的磁场产生电场。 电荷的定向移动产生电流,电荷本身产生电场。电流是移动的电场。静止的电荷产生静止的电场,运动的电荷产生运动的电…