动手学深度学习v2笔记 —— 线性回归 + 基础优化算法

news2024/9/22 15:34:32

二 动手学深度学习v2 —— 线性回归 + 基础优化算法


目录:

  1. 线性回归
  2. 基础优化方法

1. 线性回归
总结

  • 线性回归是对n维输入的加权,外加偏差
  • 使用平方损失来衡量预测值和真实值的差异
  • 线性回归有显示解
  • 线性回归可以看作是单层神经网络

2. 基础优化方法
梯度下降
在这里插入图片描述
小批量随机梯度下降
在这里插入图片描述

3. 总结

  • 梯度下降通过不断沿着反梯度方向更新参数求解
  • 小批量随机梯度下降是深度学习默认的求解算法
  • 两个重要的超参数是批量大小和学习率

4. 线性回归的从零开始实现

import torch
import random
def synthetic_data(w, b, num):
    "生成 y = Xw + b + 噪声"
    ''' 根据带有噪声的线性模型构造一个人造数据集。 我们使用线性模型参数𝐰 = [2,−3.4]⊤
    w=[2,−3.4]⊤、𝑏 = 4.2和噪声项𝜖ϵ生成数据集及其标签:𝐲 = 𝐗𝐰 + 𝑏 + 𝜖'''
    X = torch.normal(0, 1, (num, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape(-1, 1)

true_w = torch.tensor([2, -3.4])
true_b = 4.2
num_examples = 1000
batch_size = 10
features, labels = synthetic_data(true_w, true_b, num_examples)

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        num_indices = torch.tensor(indices[i: min(i +
                                                  batch_size, num_examples)])
        yield features[num_indices], labels[num_indices]

def linreg(X, w, b):
    return torch.matmul(X, w) + b

def squared_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape))**2/2

def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

loss = squared_loss
net = linreg
lr = 0.03
epoches = 3

for epoch in range(epoches):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)
        l.sum().backward()
        sgd([w, b], lr, batch_size)
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch: {epoch + 1}, loss {float(train_l.mean()):f}')

5. 线性回归的简洁实现

import torch
from d2l import torch as d2l
from torch.utils import data

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

def load_array(batch_size, data_array, is_train=True):
    dataset = data.TensorDataset(*data_array)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
data_iter = load_array(batch_size, (features, labels))

from torch import nn

net = nn.Sequential(nn.Linear(2, 1))
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)

loss = nn.MSELoss()

trainer = torch.optim.SGD(net.parameters(), lr = 0.03)

num_epoches = 3

iter_temp = iter(data_iter)

for epoch in range(num_epoches):
    for X, y in next(iter_temp):
        l = loss(net(X), y)
        l.backward()
        trainer.step()
        trainer.zero_grad()
    l = loss(net(features), labels)
    print(f'epoch {epoch + 1}, loss {l:f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/811586.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Qt/C++音视频开发50-不同ffmpeg版本之间的差异处理

一、前言 ffmpeg的版本众多,从2010年开始计算的项目的话,基本上还在使用的有ffmpeg2/3/4/5/6,最近几年版本彪的比较厉害,直接4/5/6,大版本之间接口有一些变化,特别是一些废弃接口被彻底删除了,…

Arcis中三维面转二维面

1、如何查看面是三维面 打开面属性表,查看SHAPE字段,是带“ZM”的就是三维面 不带”ZM“的就是二维面 2、三维面转二维面 在转换的过程中,通过设置环境下的参数,可以转换

(杭电多校)2023“钉耙编程”中国大学生算法设计超级联赛(3)

1005 Out of Control 先将序列a升序,然后离散化 比如说序列a为1000 1000 500 200 10,然后升序后为10 200 500 1000 1000,映射到从1开始的数,为1 2 3 4 4,此即为前缀最大值序列,比如说5 3 4 6 7的前缀最大值序列为5 5 5 6 7 动态规划 f[i][j]表示长度为i的前缀最大值序列中,j为…

leetcode 1005. K 次取反后最大化的数组和

2023.7.30 本题思路如下: 按绝对值大小将数组进行从大到小的排序。遍历数组,若当前元素为负数则修改其符号。遍历完之后,判断k是否为奇数,若为奇数,则还需要修改一次符号,此时修改绝对值最小的那个数的符号…

opencv顺时针,逆时针旋转视频并保存视频

原视频 代码 import cv2# 打开视频文件 video cv2.VideoCapture(inference/video/lianzhang.mp4)# 获取原视频的宽度和高度 width int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) height int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))# 创建视频编写器并设置输出视频参数 fourcc …

无涯教程-jQuery - Dropable移动函数

Drop-able 功能可与JqueryUI中的交互一起使用。此功能可在任何DOM元素上启用可放置功能。 Drop able - 语法 $( "#droppable" ).droppable(); Drop able - 示例 以下是一个简单的示例&#xff0c;显示了drop-able的用法- <html><head><title>…

【数理知识】刚体 rigid body 及刚体的运动

文章目录 1 刚体2 刚体一般运动1 平移运动2 旋转运动 Ref 1 刚体 刚体是指在运动中和受力作用后&#xff0c;形状和大小不变&#xff0c;而且内部各点的相对位置不变的物体。绝对刚体实际上是不存在的&#xff0c;只是一种理想模型&#xff0c;因为任何物体在受力作用后&#…

Java ~ Collection/Executor ~ DelayQueue【总结】

前言 文章 相关系列&#xff1a;《Java ~ Collection【目录】》&#xff08;持续更新&#xff09;相关系列&#xff1a;《Java ~ Executor【目录】》&#xff08;持续更新&#xff09;相关系列&#xff1a;《Java ~ Collection/Executor ~ DelayQueue【源码】》&#xff08;学…

【玩转Python系列】【小白必看】使用Python爬取双色球历史数据并可视化分析

文章目录 前言导入库![在这里插入图片描述](https://img-blog.csdnimg.cn/05ab496a2ac045e6ad0b175292462fac.png)发送请求给指定网址伪装自己发送请求并获取响应设置编码解析HTML并获取结果创建CSV文件并写入数据打印输出结果加载自定义字体绘制折线图完整代码 结束语 前言 本…

(学习笔记)matplotlib.pyplot模块下基本画图函数的整理

matplotlib版本&#xff1a;3.7.1 python版本&#xff1a;3.10.12 基本函数 matplotlib版本&#xff1a;3.7.1python版本&#xff1a;3.10.12 1. plt.plot()函数1.1 plt.plot(x, y)1.2 plt.plot(x, y, **kwargs) 2. plt.xlable(), plt.ylable()3. plt.title()4. plt.show()5.p…

Pytorch(二)

一、分类任务 构建分类网络模型 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数无需写反向传播函数&#xff0c;nn.Module能够利用autograd自动实现反向传播Module中的可学习参数可以通过named_parameters()返回迭代器 from torch import nn import torch.nn.f…

如何注册Ddns域名?用快解析新手也可以轻松搞定!

对于每一个上网的朋友来说&#xff0c;如果平时经常需要访问外网&#xff0c;就需要用到Ddns域名了&#xff0c;不过这个域名的注册比较麻烦&#xff0c;也没有那么容易&#xff0c;因此很多朋友对此也有很多的疑惑。那么&#xff0c;Ddns域名注册怎么操作呢&#xff1f;其实利…

直流电机的系统辨识——LZW

前言 本文采用基于最小二乘的线性辨识方法和基于Nonlinear ARX模型的非线性辨识方法对图1所示的直流电机进行系统辨识&#xff0c;并分别设计H∞控制器&#xff0c;分析比较控制效果。 图1 实验器材 目录 前言一、数据采集二、系统辨识1.最小二乘法&#xff08;线性辨识&#…

认识自动化测试

目录 简述 手动测试 自动化测试 测试类型 单元测试 集成测试 端到端测试&#xff08;E2E&#xff09; 快照测试 测试覆盖率 代码覆盖率 需求覆盖率 总结 自动化测试有以下几个概念&#xff1a; 单元测试集成测试E2E 测试快照测试测试覆盖率TDD 以及 BDD 等 简述 …

【消息中间件】原生PHP对接Uni H5、APP、微信小程序实时通讯消息服务

文章目录 视频演示效果前言一、分析二、全局注入MQTT连接1.引入库2.写入全局连接代码 二、PHP环境建立总结 视频演示效果 【uniapp】实现买定离手小游戏 前言 Mqtt不同环境问题太多&#xff0c;新手可以看下 《【MQTT】Esp32数据上传采集&#xff1a;最新mqtt插件&#xff08;支…

GAMES101 笔记 Lecture13 光线追踪1

目录 Why Ray Tracing?(为什么需要光线追踪&#xff1f;)Basic Ray Tracing Algorithm(基础的光线追踪算法)Ray Casting(光线的投射)Generating Eye Rays(生成Eye Rays) Recursive(Whitted-Styled) Ray Tracing Ray-Surface Intersection(光线和平面的交点)Ray Rquation(射线方…

盘点:查快递教程

在“寄快递”成为常态的当下&#xff0c;如何快速进行物流信息查询&#xff0c;是收寄人所关心的问题。在回答这个问题之前&#xff0c;首先我们要知道&#xff0c;物流信息查询&#xff0c;有哪些方法&#xff1f; 1、官网单号查询 知道物流公司和单号的情况下&#xff0c;直…

管理类联考——写作——论说文——实战篇——标题篇

角度3——4种材料类型、4个立意对象、5种写作态度 经过审题立意后&#xff0c;我们要根据我们的立意&#xff0c;确定一个主题&#xff0c;这个主题必须通过文章的标题直接表达出来。 标题的基本要求 主题清晰&#xff0c;态度明确 第一&#xff0c;阅卷人看到一篇论说文的标…

【动态规划part13】| 300.最长递增子序列、674.最长连续递增序列、718.最长重复数组

目录 &#x1f388;LeetCode 300.最长递增子序列 &#x1f388;LeetCode 674. 最长连续递增序列 &#x1f388;LeetCode 718. 最长重复子数组 &#x1f388;LeetCode 300.最长递增子序列 链接&#xff1a;300.最长递增子序列 给你一个整数数组 nums &#xff0c;找到其…

camund——2、cancelActivityInstance()与多实例下getActiveActivityIds()获取不到当前任务的节点。

在多实例&#xff08;会签或者并行网关&#xff09;时如果使用以下代码来进行驳回时&#xff0c;使用 **List activeActivityIds runtimeService.getActiveActivityIds(instanceId);**来获取当前活动的节点会出现获取不到情况。 runtimeService.createProcessInstanceModific…