二 动手学深度学习v2笔记 —— 线性回归 + 基础优化算法

news2025/1/23 9:20:10

二 动手学深度学习v2 —— 线性回归 + 基础优化算法


目录:

  1. 线性回归
  2. 基础优化方法

1. 线性回归
总结

  • 线性回归是对n维输入的加权,外加偏差
  • 使用平方损失来衡量预测值和真实值的差异
  • 线性回归有显示解
  • 线性回归可以看作是单层神经网络

2. 基础优化方法
梯度下降
在这里插入图片描述
小批量随机梯度下降
在这里插入图片描述

3. 总结

  • 梯度下降通过不断沿着反梯度方向更新参数求解
  • 小批量随机梯度下降是深度学习默认的求解算法
  • 两个重要的超参数是批量大小和学习率

4. 线性回归的从零开始实现

import torch
import random
def synthetic_data(w, b, num):
    "生成 y = Xw + b + 噪声"
    ''' 根据带有噪声的线性模型构造一个人造数据集。 我们使用线性模型参数𝐰 = [2,−3.4]⊤
    w=[2,−3.4]⊤、𝑏 = 4.2和噪声项𝜖ϵ生成数据集及其标签:𝐲 = 𝐗𝐰 + 𝑏 + 𝜖'''
    X = torch.normal(0, 1, (num, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape(-1, 1)

true_w = torch.tensor([2, -3.4])
true_b = 4.2
num_examples = 1000
batch_size = 10
features, labels = synthetic_data(true_w, true_b, num_examples)

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        num_indices = torch.tensor(indices[i: min(i +
                                                  batch_size, num_examples)])
        yield features[num_indices], labels[num_indices]

def linreg(X, w, b):
    return torch.matmul(X, w) + b

def squared_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape))**2/2

def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

loss = squared_loss
net = linreg
lr = 0.03
epoches = 3

for epoch in range(epoches):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)
        l.sum().backward()
        sgd([w, b], lr, batch_size)
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch: {epoch + 1}, loss {float(train_l.mean()):f}')

5. 线性回归的简洁实现

import torch
from d2l import torch as d2l
from torch.utils import data

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

def load_array(batch_size, data_array, is_train=True):
    dataset = data.TensorDataset(*data_array)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
data_iter = load_array(batch_size, (features, labels))

from torch import nn

net = nn.Sequential(nn.Linear(2, 1))
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)

loss = nn.MSELoss()

trainer = torch.optim.SGD(net.parameters(), lr = 0.03)

num_epoches = 3

iter_temp = iter(data_iter)

for epoch in range(num_epoches):
    for X, y in next(iter_temp):
        l = loss(net(X), y)
        l.backward()
        trainer.step()
        trainer.zero_grad()
    l = loss(net(features), labels)
    print(f'epoch {epoch + 1}, loss {l:f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/817514.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

4通道高速数据采集卡推荐哪些呢

FMC141是一款基于VITA57.4标准的4通道2.8GSPS/2.5GSPS/1.6GSPS采样率16位DA播放FMC子卡,该板卡为FMC标准,符合VITA57.4与VITA57.1规范,16通道的JESD204B接口通过FMC连接器连接至FPGA的高速串行端口。 该板卡采用TI公司的DAC39J84芯片&#x…

【玩转Linux】Linux输入子系统简介

(꒪ꇴ꒪ ),hello我是祐言博客主页:C语言基础,Linux基础,软件配置领域博主🌍快上🚘,一起学习!送给读者的一句鸡汤🤔:集中起来的意志可以击穿顽石!作者水平很有限,如果发现错误&#x…

LeetCode每日一题Day1——买卖股票的最佳时机

✨博主:命运之光 🦄专栏:算法修炼之练气篇(C\C版) 🍓专栏:算法修炼之筑基篇(C\C版) 🐳专栏:算法修炼之练气篇(Python版) ✨…

青大数据结构【2016】

一、单选 二、简答 3.简述遍历二叉树的含义及常见的方法。

shiro550反序列化漏洞原理与漏洞复现(基于vulhub,保姆级的详细教程)

漏洞原理 本文所有使用的脚本和工具都会在文末给出链接,希望读者可以耐心看到最后。 啥是shiro? Shiro是Apache的一个强大且易用的Java安全框架,用于执行身份验证、授权、密码和会话管理。使用 Shiro 易于理解的 API,可以快速轻松地对应用程序进行保…

代码随想录算法训练营之JAVA|第十七天| 654. 最大二叉树

今天是第17天刷leetcode,立个flag,打卡60天。 算法挑战链接 654. 最大二叉树https://leetcode.cn/problems/maximum-binary-tree/description/ 第一想法 错误的想法,就不说了。 看完代码随想录之后的想法 用递归模拟真实的过程 如果我…

【iOS】通知原理

我们可以通过看通知的实现机制来了解通知中心是怎么实现对观察者的引用的。由于苹果对Foundation源码是不开源的,我们具体就参考一下GNUStep的源码实现。GNUStep的源码地址为:GNUStep源码GitHub下载地址, 具体源码可以进行查看。 通知的主要流程 通知全…

AD21 PCB设计的高级应用(二)PCB常见走线等长设计

(二)PCB常见走线等长设计 1.蛇形线的等长设计2.DDR的等长分组3.等长的拓扑结构3.1 点对点连接3.2 T型拓扑结构3.3 菊花链拓扑结构 1.蛇形线的等长设计 在 PCB 设计中,网络等长调节目的就是为了尽可能地降低信号在 PCB上传输延迟的差异。在 Altium Desig…

C语言第十三课--------初阶指针的认识--------重要部分

作者前言 🎂 ✨✨✨✨✨✨🍧🍧🍧🍧🍧🍧🍧🎂 🎂 作者介绍: 🎂🎂 🎂…

JavaSE类和对象(重点:this引用、构造方法)

目录 一、类的定义方式以及实例化 1.面向对象 Java是一门纯面向对象的语言(Object Oriented Program,简称OOP),在Java的世界里一切皆为对象。 2.类的定义和使用 1.在java中定义类时需要用到class关键字 3.类的实例化 4.类实例化的使用 二、this引用 …

面试中常聊 AMS,你是否又真的了解?

在面试的时候,经常会被问到这些问题: 对Activity的启动流程了解吗?AMS在Android起到什么作用,简单分析下Android的源码system_server为什么要在Zygote中启动,而不是由init直接启动呢?为什么要专门使用Zygote进程去孵…

有点慌,新公司项目构建用的Gradle

入职新公司,构建项目的工具用的gradle,以前没用过,看到一个build.gradle,点进去,心里一句我曹,这写的都是些什么玩意,方得一批,赶紧去补了下课。 好吧,先学点语法&#…

根据选择内容自动生成正则表达式

地址: https://regex.ai/ 如何使用? 比如我这里有个需求: 提取图片路径中的文件名 https://wx4.sinaimg.cn/mw1024/0040jbadly1hg5nk0l3gtj62c0340kjl02.jpg 提取出0040jbadly1hg5nk0l3gtj62c0340kjl02.jpg 数据越多越准确 左边提供数据, 右边给出需要提取的数据, 点击run&…

地产变革中,物业等风来

2023年7月,也许是中国房地产行业变局中的一个大拐点。 中信建投研报表示,政治局会议指出当前我国房地产形势已发生重大变化,要适时调整优化政策,为行业形势定调……当前房地产行业β已至。 不久前,国家统计局公布了2…

Mag-Fluo-4 AM,镁离子荧光探针,是一种有用的细胞内镁离子指示剂

资料编辑|陕西新研博美生物科技有限公司小编MISSwu​ PART1----产品描述: 镁离子荧光探针Mag-Fluo-4 AM,具细胞膜渗透性,对镁离子(Mg2) 和钙离子(Ca2)的 Kd 值分别是 4.7mM 和 22mM&#xff0c…

运维:Multipass软件让你的虚拟机管理更简单高效

一、Multipass是什么? 官网:https://multipass.run/ 一提到虚拟机大家一般都会想到VMvare和Virtual Box这两个的虚拟机软件,这两个软件一个比较麻烦的地方是安装完虚拟机以后还需要下载操作系统镜像。小编偶然间发现了Multipass。这款轻量级的…

Flowable-子流程-嵌套子流程

目录 定义图形标记XML内容使用示例视频讲解 定义 内嵌子流程又叫嵌入式子流程,它是一个可以包含其它活动、分支、事件,等的活动。我们通 常意义上说的子流程通常就是指的内嵌子流程,它表现为将一个流程(子流程)定义在…

【C语言初阶(20)】调试练习题

文章目录 前言实例1实例2 前言 在我们开始调试之前,应该有个明确的思路;程序是如何完成工作的、变量到达某个步骤时的值应该是什么、出现的问题大概会在什么位置。这些东西在调试之前都需要先确认下来,不然自己都不知道自己在调试个什么东西…

IT服务管理学习笔记<一>

### IT服务管理知识整理 ITSM 的核心思想是,IT 组织,不管它是企业内部的还是外部的,都是 IT 服务提供者,其 主要工作就是提供低成本、高质量的 IT 服务。 ITSM 的核心思想是,IT 组织,不管它是企业内部的还…

中国农村大学生学习了这个【React教程】迎娶导师女儿,出任CEO走上人生巅峰

注:最后有面试挑战,看看自己掌握了吗 文章目录 React创建一个简单的 JSX 元素创建一个复杂的 JSX 元素在 JSX 中添加注释渲染 HTML 元素为 DOM 树 🌸I could be bounded in a nutshell and count myself a king of infinite space. 特别鸣谢…