(3)深度学习学习笔记-简单线性模型

news2024/9/22 13:46:01

文章目录

  • 一、线性模型
  • 二、实例
    • 1.pytorch求导功能
    • 2.简单线性模型(人工数据集)
  • 来源


一、线性模型

一个简单模型:假设一个房子的价格由卧室、卫生间、居住面积决定,用x1,x2,x3表示。
那么房价y就可以认为y=w1x1+w2x2+w3x3+b,w为权重,b为偏差。
第一步
在这里插入图片描述
线性模型可以看做是单层(带权重的层是1层)神经网络。
在这里插入图片描述
第二步:
定义loss,衡量预估质量:真实值和预测值的差距
在这里插入图片描述
这里带1/2是方便求导的时候把2消去。
训练数据:收集数据来决定权重和偏差
训练损失:loss=1/n∑[(真实值-预测值(xi和权重的内积-偏差))平方]。目标是找到最小的loss
在这里插入图片描述
第三步:优化
优化方法:梯度下降。先挑选一个初值w0,之后不断更新w0使他接近最优解。更新方法是wt=wt-1 - 学习速率
梯度。
在这里插入图片描述
Learning rate不能太小(到达一个点要走很多步),也不能太大(一直震荡没有真的下降)
在这里插入图片描述
在这里插入图片描述
在整个训练集上梯度下降太贵,跑一次模型可能要数分钟/小时。所以采用小批量随机梯度下降,随机采样b个样本用这b个样本来近似损失。b不能太大也不能太小
在这里插入图片描述

二、实例

1.pytorch求导功能

代码如下:

# 自动求导
import torch
# 假设对函数y=2xT x关于列向量求导
x = torch.arange(4.0)
# 算y关于x的梯度之前,需要一个地方来存储梯度
x.requires_grad_(True)  # 等价于x=torch.arange(4.0,requires_grad=True)
print(x.grad)  # 默认值是None y关于x的导数存在这里

y=2*torch.dot(x,x)
y.backward() # 求导
print(x.grad)
print(x.grad==4*x)

# 在默认情况下,PyTorch会累积梯度,需要清除之前的值
x.grad.zero_()

y = x.sum()
y.backward()
print(x.grad)

x.grad.zero_()
y=x*x
u=y.detach()# 把y当成常数而不是x的函数
z=u*x
z.sum().backward()
print(x.grad==u)

2.简单线性模型(人工数据集)

代码如下:

# 构建人工数据集(好处是知道w和b)
# 根据w=[2,-3.4] b=4.2 和噪声生成数据集和标签 y=Xw+b+噪声
import numpy as np
import torch
from torch import nn
from torch.utils import data

# 生成数据
def synthetic_data(w, b, num_examples):
    """生成y=Xw+b+噪声"""
    X = np.random.normal(0, 1, (num_examples, len(w)))  # 均值为0,方差为1,num_ex个样本,列数=w的个数
    y = np.dot(X, w) + b  # y=Xw+b
    y += np.random.normal(0, 0.01, y.shape)  # 加上随机噪音
    x1 = torch.tensor(X, dtype=torch.float32)  # 把np转化为torch
    y1 = torch.tensor(y, dtype=torch.float32)
    return x1, y1.reshape((-1, 1))  # 列向量反馈


# 读取数据
def load_array(data_arrays, batch_size, is_train=True):
    """构造一个PyTorch数据迭代器"""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)  # shuffle:是否需要随机打乱


true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

batch_size = 10
data_iter = load_array((features, labels), batch_size)
# 使用iter构造Python迭代器,并使用next从迭代器中获取第一项。
print(next(iter(data_iter)))

# 定义模型
# nn是神经网络的缩写
net = nn.Sequential(nn.Linear(2, 1))  # 输入维度是2 输出是1 sequential相当于一个list of layer

# 初始化参数 net[0]访问这个layer
net[0].weight.data.normal_(0, 0.01)  # normal:使用正态分布替换weight的值,均值为0,方差为0.01
net[0].bias.data.fill_(0)  # bias直接设为0

# 定义loss:mseloss类
loss = nn.MSELoss()

# 定义优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.03)  # net.patameters():所有参数  ,lr:learning rate

# 训练
num_epochs = 3
for epoch in range(num_epochs):  # 对所有数据扫一遍
    for X, y in data_iter:  # 拿出一个批量大小的x和y
        l = loss(net(X), y)  # x和y的小批量损失
        trainer.zero_grad()  # 梯度清零
        l.backward()  # 计算梯度
        trainer.step()  # 模型更新
    l = loss(net(features), labels)  # 计算损失
    print(f'epoch {epoch + 1}, loss {l:f}')


来源

b站 跟李沐学AI 动手学深度学习v2 08

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/697613.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大数据分析与机器学习的结合:实现智能决策

章节一:引言 在当今数字化时代,大数据分析和机器学习已经成为推动技术创新和业务发展的关键要素。大数据的快速增长和复杂性使得传统的数据处理方法变得不再有效。而机器学习作为一种自动化的数据分析方法,能够从海量数据中挖掘出有价值的信…

QSS样式设置及语法规则

QSS(Qt Style Sheets)是Qt的一个功能强大的样式表语言。它类似于CSS(Cascading Style Sheets),可以用于定义和控制应用程序的外观和样式。QSS可以应用于Qt部件(Widgets)和绘制元素,以…

集合专题----Map篇

1、Map 接口和常用方法 (1)Map 接口实现类的特点 ① Map与Collection并列存在(即平行关系)。Map用于保存具有映射关系的数据:Key-Value; ② Map 中的 key 和 value 可以是任何引用类型的数据,…

先平移再旋转和先旋转再平移的区别

对于一个刚体,以汽车为例,先旋转再平移和先平移再旋转有没有区别要看这个平移旋转是以什么坐标系为基准 如果平移和旋转都以小车坐标系为基准,二者是有区别的 如果平移旋转以世界坐标系为基准,二者是没有区别的 看图就明白了 所…

v8-tc39-ecma262:concat,不只是合并数组

如上图,解释如下: 如果是对象o,转换为对象新建数组A设n0,用于最后赋值给A,确保A的长度正确预先把值设置到items(这里不知何意?)循环items,设置元素为E E是否可展开如果可展开 有len下标,则获取…

LLM - 第2版 GLM 中文对话模型 ChatGLM2-6B 服务配置 (2)

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/131445696 ChatGLM-6B v1 工程:基于 ChatGLM-6B 模型搭建 ChatGPT 中文在线聊天 (1)ChatGLM2-6B v2 工程:第2版 …

react umi中使用svg线上图片不加载问题

参考链接&#xff1a; https://www.jianshu.com/p/c927122a6e82 前言&#xff1a; 在react项目中&#xff0c;我们本地通过img标签的src使用svg图片是可以加载的&#xff0c;但是发布到线上图片加载不出来。 import stopImg from /images/stop.svg; <img src{stopImg }/&…

Transformer时间序列:PatchTST引领时间序列预测进

Transformer时间序列&#xff1a;PatchTST引领时间序列预测进 引言为什么transformer框架可以应用到时间序列呢统计学模型深度学习模型 PatchTSTPatchTST模型架构原理。通道独立性Patchingpatching的优点Transformer编码器 利用表示学习改进PatchTST使用PatchTST模型进行预测初…

深入理解 Golang: 聚合、引用和接口类型的底层数据结构

Go 中有基础类型、聚合类型、引用类型和接口类型。基础类型包括整数、浮点数、布尔值、字符串&#xff1b;聚合类型包括数组、结构体&#xff1b;引用类型包括指针、切片、map、function、channel。在本文中&#xff0c;介绍部分聚合类型、引用类型和接口类型的底层表示及原理。…

如何在Microsoft Excel中快速筛选数据

你通常如何在 Excel 中进行筛选?在大多数情况下,通过使用自动筛选,以及在更复杂的场景中使用高级过滤器。 使用自动筛选或 Excel 中的内置比较运算符(如“大于”和“前10项”)来显示所需数据并隐藏其余数据。筛选单元格或表范围中的数据后,可以重新应用筛选器以获取最新…

数据结构与算法基础-学习-25-图之MST(最小代价生成树)之Prim(普利姆)算法

一、生成树概念 1、所有顶点均由边连接在一起&#xff0c;但不存在回路的图。 2、一个图可以有许多棵不同的生成树。 二、生成树特点 1、生成树的顶点个数与图的顶点个数相同。 2、生成树是图的极小连通子图&#xff0c;去掉一条边则非连通。 3、一个有n个顶点的连通图的生…

stm32f103c8t6移植U8g2

U8g2代码下载&#xff1a; https://github.com/olikraus/u8g2 1&#xff0c;准备一个正常运行的KEIL5 MDK模板 2&#xff0c;下载u8g2的源码和 u8g2的STM32实例模板 源码: https://github.com/olikraus/u8g2 STM32实例模板: https://github.com/nikola-v/u8g2_template_stm32f…

100天精通Golang(基础入门篇)——第11天:深入解析Go语言中的切片(Slice)及常用函数应用

&#x1f337; 博主 libin9iOak带您 Go to Golang Language.✨ &#x1f984; 个人主页——libin9iOak的博客&#x1f390; &#x1f433; 《面试题大全》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &#x1f30a; 《I…

期望最大化注意力网络 EMANet

论文&#xff1a;Expectation-Maximization Attention Networks for Semantic Segmentation Github&#xff1a;https://github.com/XiaLiPKU/EMANet ICCV2019 oral 论文提出的期望最大化注意力机制Expectation- Maximization Attention (EMA)&#xff0c;摒弃了在全图上计算注…

再述时序约束

再述时序约束 一、为什么要加时序约束&#xff1f;二、时序分析是什么&#xff1f;三、时序分析的一些基本概念三、 时序分析的一些基本公式 一、为什么要加时序约束&#xff1f; 一次笔者在调试HDMI输出彩条&#xff0c;出现彩条时有时无现象&#xff0c;笔者视频输出芯片的驱…

leecode-数组多数-摩尔投票法

题目 题目 分析 最开始思路&#xff1a;排序&#xff0c;然后取nums[n/2]&#xff0c;但是时间复杂度不过关。 摩尔投票法&#xff0c;学到了&#xff01; 代码 class Solution { public:int majorityElement(vector<int>& nums) {//摩尔投票int cnt0;int targ…

计算机二级c语言题库

计算机C语言二级考试&#xff08;60道程序设计&#xff09; 第1道 请编写一个函数fun,它的功能是:将ss所指字符串中所有下标为奇数位置上的字母转换成大写&#xff08;若该位置上不是字母&#xff0c;则不转换&#xff09;。 例如&#xff0c;若输入"abc4EFG"&…

OpenCV学习笔记 | ROI区域选择提取 | Python

摘要 ROI区域是指图像中我们感兴趣的特定区域&#xff0c;OpenCV提供了一些函数来选择和提取ROI区域&#xff0c;我们可以使用OpenCV的鼠标事件绑定函数&#xff0c;然后通过鼠标操作在图像上绘制一个矩形框&#xff0c;该矩形框即为ROI区域。本文将介绍代码的实现以及四个主要…

opencv编译

文章目录 一、编译前工作二、编译安装1、Windows2、Linux 一、编译前工作 进入下载页面https://github.com/opencv/opencv&#xff0c;下载指定.tar.gz源码包&#xff0c;例如&#xff1a;opencv-4.7.0.tar.gz。解压到指定目录。 二、编译安装 opencv构建时&#xff0c;需要…

使用docker搭建hadoop集群

1.下载安装docker 2.启动docker 3.配置docker镜像 4.获取hadoop镜像 5.拉取hadoop镜像 6.运行容器 7.进入容器 8.配置免密 9.格式化节点 10.启动节点 11.查看节点信息 (img-CBr9VbGk-1687962511910)] 11.查看节点信息