PyTorch深度学习(一)【线性模型、梯度下降、随机梯度下降】

news2024/11/14 3:14:42

这个系列是实战(刘二大人讲的pytorch)

建议把代码copy下来放在编译器查看(因为很多备注在注释里面)

线性模型(Linear Model):

import numpy as npimport matplotlib.pyplot as plt  #绘图的包x_data = [1.0, 2.0, 3.0]  #这两行代表数据集,一般x_data,y_data是要把它分开保存的,x表示输入样本y_data = [2.0, 4.0, 6.0]  #相同的索引表示一组样本,就是(1.0,2.0)表示一对样本,(2.0,4.0)表示一对样本。def forward(x):   #定义模型,取名“前馈模型”return x * w  #用x与w相乘返回(Linear Model)def loss(x, y):   #定义损失函数    y_pred = forward(x)return (y_pred - y) ** 2# 穷举法w_list = []       #权重值mse_list = []     #对应权重损失值for w in np.arange(0.0, 4.1, 0.1):  #权重间隔为0.1,从0.0开始取,4.0结束。[0.0,0.1,...,4.0]    print("w=", w)    l_sum = 0for x_val, y_val in zip(x_data, y_data):  #把x_data,y_data这两个列表里边的数据拿出来用zip拼成真实数据的x,y值        y_pred_val = forward(x_val)           #首先计算预测(可以不计算,主要只是打印一下结果看一下);y_pred_val是预测值,loss函数计算会用到        loss_val = loss(x_val, y_val)         #计算损失        l_sum += loss_val                     #损失求和        print('\t', x_val, y_val, y_pred_val, loss_val)    print('MSE=', l_sum / 3)    w_list.append(w)    mse_list.append(l_sum / 3)                #除以3,转成mse(均方误差)plt.plot(w_list, mse_list)        #绘图plt.ylabel('Loss')plt.xlabel('w')plt.show()

使用到的损失函数如下:

"y_pred"就是在求

zip函数:

【番外:可视化常用到的工具---visdom】

练习:实现线性模型(y=wx+b)并输出loss的3D图像。

import numpy as npimport matplotlib.pyplot as pltfrom mpl_toolkits.mplot3d import Axes3D#这里设函数为y=3x+2x_data = [1.0,2.0,3.0]y_data = [5.0,8.0,11.0]def forward(x):return x * w + bdef loss(x,y):y_pred = forward(x)return (y_pred-y)*(y_pred-y)mse_list = []W=np.arange(0.0,4.1,0.1)B=np.arange(0.0,4.1,0.1)[w,b]=np.meshgrid(W,B)l_sum = 0for x_val, y_val in zip(x_data, y_data):y_pred_val = forward(x_val)print(y_pred_val)loss_val = loss(x_val, y_val)l_sum += loss_valfig = plt.figure()ax = Axes3D(fig)fig.add_axes(ax)ax.plot_surface(w, b, l_sum/3)plt.show(block=True)

梯度下降

(鞍点:梯度为0。陷入鞍点没办法迭代)

cost公式:(cost function是对所有的样本)

代码:

import matplotlib.pyplot as plt# 准备训练集数据x_data = [1.0, 2.0, 3.0]  #两个列表,分别表示x和y的值,(1.0,2.0)表示第一条数据样本y_data = [2.0, 4.0, 6.0]  #(2.0,4.0)表示第二条数据样本# initial guess of weightw = 1.0   #初始权重猜测# define the model linear model y = w*xdef forward(x):   #定义前馈计算return x * w    #y^# define the cost function MSEdef cost(xs, ys):   #把所有的数据都拿进来    cost = 0for x, y in zip(xs, ys):        y_pred = forward(x)    #算y^        cost += (y_pred - y) ** 2return cost / len(xs)      #MSE(平均损失的计算)# define the gradient function  gddef gradient(xs, ys):   #求梯度    grad = 0for x, y in zip(xs, ys):        grad += 2 * x * (x * w - y)return grad / len(xs)epoch_list = []cost_list = []print('predict (before training)', 4, forward(4))for epoch in range(100):    #训练过程(100轮)    cost_val = cost(x_data, y_data)  #计算当前损失值,也就是cost    grad_val = gradient(x_data, y_data)   #求梯度    w -= 0.01 * grad_val  # 0.01 learning rate    #学习率*梯度    print('epoch:', epoch, 'w=', w, 'loss=', cost_val)    epoch_list.append(epoch)    cost_list.append(cost_val)print('predict (after training)', 4, forward(4))plt.plot(epoch_list, cost_list)plt.ylabel('cost')plt.xlabel('epoch')plt.show()

【番外:“指数加权均值”方法能够将cost变得更平滑】

在大多数的情况下,得到的loss图像形状趋势都是如上图所示,如果出现右边有又上去了的情况,则说明训练发散了,这次训练失败了。

训练失败的情况有很多,其中最常见的是:学习率取得太大。(可以将学习率调小再看看效果)

随机梯度下降

(只用一个样本,即使陷入了鞍点,也也有可能跨过这个鞍点向前推进找最优点)

公式:(单个样本的损失函数对权重求导,然后进行更新)

代码:

import matplotlib.pyplot as pltx_data = [1.0, 2.0, 3.0]y_data = [2.0, 4.0, 6.0]w = 1.0def forward(x):return x * w# calculate loss functiondef loss(x, y):    y_pred = forward(x)   #y^return (y_pred - y) ** 2    #loss# define the gradient function  sgddef gradient(x, y):return 2 * x * (x * w - y)      #梯度epoch_list = []loss_list = []print('predict (before training)', 4, forward(4))for epoch in range(100):for x, y in zip(x_data, y_data):        grad = gradient(x, y)    #对每一个样本求梯度,loss对w求梯度        w = w - 0.01 * grad  # update weight by every grad of sample of training set    更新        print("\tgrad:", x, y, grad)        l = loss(x, y)    #计算现在的损失    print("progress:", epoch, "w=", w, "loss=", l)    epoch_list.append(epoch)    loss_list.append(l)print('predict (after training)', 4, forward(4))plt.plot(epoch_list, loss_list)plt.ylabel('loss')plt.xlabel('epoch')plt.show()

性能好,但时间复杂度太高,没有并行性。

【番外:Batch。(性能和时间复杂度上取折中)批量的随机梯度下降法。

就是说如果你全都丢到一起,性能不好;全都分开呢,时间复杂度不好。

因此可以若干个分为一组,每次用这一租样本去求相应的梯度,然后进行更新。这个就叫做Batch。】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1013027.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Cesium 地球网格构造

Cesium 地球网格构造 Cesium原理篇:3最长的一帧之地形(2:高度图) HeightmapTessellator 用于从高程图像创建网格。提供了一个函数 computeVertices,可以根据高程图像创建顶点数组。 该函数的参数包括高程图像、高度数据的结构、网格宽高、…

Gradle的简介、下载、安装、配置及使用流程

Gradle的简介、下载、安装、配置及使用流程 1.Gradle的简介 Gradle是一个基于Apache Ant和Apache Maven概念的项目自动化构建开源工具。它使用一种基于Groovy的特定领域语言(DSL)来声明项目设置,也增加了基于Kotlin语言的kotlin-based DSL,抛弃了基于X…

AI项目六:基于YOLOV5的CPU版本部署openvino

若该文为原创文章,转载请注明原文出处。 一、CPU版本DEMO测试 1、创建一个新的虚拟环境 conda create -n course_torch_openvino python3.8 2、激活环境 conda activate course_torch_openvino 3、安装pytorch cpu版本 pip install torch torchvision torchau…

vcruntime140_1.dll修复方法分享,教你安全靠谱的修复手段

在使用Windows操作系统的过程中,我们有时会遇到vcruntime140_1.dll文件丢失或损坏的情况。本文将详细介绍vcruntime140_1.dll的作用,以及多种解决方法和修复该文件时需要注意的问题,希望能帮助读者更好地处理这一问题。 一.vcruntime140_1.dl…

数据结构——【堆】

一、堆的相关概念 1.1、堆的概念 1、堆在逻辑上是一颗完全二叉树(类似于一颗满二叉树只缺了右下角)。 2、堆的实现利用的是数组,我们通常会利用动态数组来存放元素,这样可以快速拓容也不会很浪费空间,我们是将这颗完…

【Java】SpringData JPA快速上手,关联查询,JPQL语句书写

JPA框架 文章目录 JPA框架认识SpringData JPA使用JPA快速上手方法名称拼接自定义SQL关联查询JPQL自定义SQL语句 ​ 在我们之前编写的项目中,我们不难发现,实际上大部分的数据库交互操作,到最后都只会做一个事情,那就是把数据库中的…

电容 stm32

看到stm32电源部分都会和电容配套使用,所以对电容的作用产生了疑惑 电源 负电荷才能在导体内部自由移动,电池内部的化学能驱使着电源正电附近的电子移动向电源负极区域。 电容 将电容接上电池,电容的两端一段被抽走电子,一端蓄积…

【STL容器】vector

文章目录 前言vector1.1 vector的定义1.2 vector的迭代器1.3 vector的元素操作1.3.1 Member function1.3.2 capacity1.3.3 modify 1.4 vector的优缺点 前言 vector是STL的容器,它提供了动态数组的功能。 注:文章出现的代码并非STL库里的源码&#xff0c…

C++ PrimerPlus 复习 第三章 处理数据

第一章 命令编译链接文件 make文件 第二章 进入c 第三章 处理数据 文章目录 C变量的命名规则;C内置的整型——unsigned long、long、unsigned int、int、unsigned short、short、char、unsigned char、signed char和bool;如何知道自己计算机类型宽度获…

Jenkins Maven pom jar打包未拉取最新包解决办法,亲测可行

Jenkins Maven pom jar打包未拉取最新包解决办法,亲测可行 1. 发布新版的snapshots版本的jar包,默认Jenkins打包不拉取snapshots包2. 设置了snapshot拉取后,部分包还未更新,需要把包版本以snapshot结尾3. IDEA无法更新snapshots包…

超炫的开关效果

超炫的开关动画 代码如下 <!DOCTYPE html> <html> <head><meta charset"UTF-8"><title>Switch</title><style>/* 谁家好人 一个按钮写三百多行样式代码 &#x1f622;&#x1f622;*/*,*:after,*:before {box-sizing: bor…

代码随想录--栈与队列-用队列实现栈

使用队列实现栈的下列操作&#xff1a; push(x) -- 元素 x 入栈pop() -- 移除栈顶元素top() -- 获取栈顶元素empty() -- 返回栈是否为空 &#xff08;这里要强调是单向队列&#xff09; 用两个队列que1和que2实现队列的功能&#xff0c;que2其实完全就是一个备份的作用 impo…

产教融合 | 力软联合重庆科技学院开展低代码应用开发培训

近日&#xff0c;力软与重庆科技学院联合推出了为期两周的低代码应用开发培训课程&#xff0c;来自重庆科技学院相关专业的近百名师生参加了此次培训。 融合研学与实践&#xff0c;方能成为当代数字英才。本次培训全程采用线下模式&#xff0c;以“力软低代码平台”为软件开发…

聚焦真实用例,重仓亚洲,孙宇晨畅谈全球加密新格局下的主动变革

「如果能全面监管&#xff0c;加密行业仍有非常大的增长空间。目前加密用户只有 1 亿左右&#xff0c;如果监管明朗&#xff0c;我们可以在 3-5 年内获得 20-30 亿用户。」—— 孙宇晨 9 月 14 日&#xff0c;波场 TRON 创始人、火币 HTX 全球顾问委员会成员孙宇晨受邀出席 20…

RHCSA的一些简单操作命令

目录 1、查看Linux版本信息 2、ssh远程登陆 3、解析[zxlocalhost ~]$ 与 [rootlocalhost ~]# 4、退出命令exit 5、su——switch user 6、打印用户所处的路径信息pwd 7、修改路径 8、输出文件\目录信息 9、重置root账号密码 10、修改主机名称 1&#xff09;临时修改…

SkyWalking入门之Agent原理初步分析

一、简介 当前稍微上点体量的互联网公司已经逐渐采用微服务的开发模式&#xff0c;将之前早期的单体架构系统拆分为很多的子系统&#xff0c;子系统封装为微服务&#xff0c;彼此间通过HTTP协议RESET API的方式进行相互调用或者gRPC协议进行数据协作。 早期微服务只有几个的情况…

三种方式部署单机版Minio,10行命令干就完了~

必要步骤&#xff1a;安装MinIO 拉取MinIO镜像 docker pull quay.io/minio/minio 创建文件挂载点 mkdir /home/docker/MinIO/data &#xff08;文件挂载点映射&#xff0c;默认是/mydata/minio/data&#xff0c;修改为/home/docker/MinIO&#xff0c;文件存储位置自行修改&…

随笔-嗨,中奖了

好久没有动笔了&#xff0c;都懒惰了。 前段时间&#xff0c;老妹凑着暑假带着双胞胎的一个和老妈来了北京&#xff0c;听着小家伙叫舅舅&#xff0c;还是挺稀奇的。周末带着他们去了北戴河&#xff0c;全家人都是第一次见大海&#xff0c;感觉&#xff0c;&#xff0c;&#…

qiankun 乾坤主应用访问微应用css静态图片资源报404

发现static前没有加我指定的前缀 只有加了后才会出来 解决方案: env定义前缀 .env.development文件中 # static前缀 VUE_APP_PUBLIC_PREFIX"" .env.production文件中 # static前缀 VUE_APP_PUBLIC_PREFIX"/szgl" settings文件是封了一下src\settings…

测试平台前端部署

这里写目录标题 一、前端代码打包1、打包命令2、打包完成后,将dist文件夹拷贝到nginx文件夹中3、重新编写default.conf4、将之前启动的容器进行停止并且删除,再重新创建容器5、制作Dockerfile二、编写Dockerfile一、前端代码打包 1、打包命令 npm run build2、打包完成后,…