d2l学习_第三章线性回归/欠拟合过拟合/权重衰减

news2024/11/24 5:32:57

x.1 Linear Regression Theory

x.1.1 Model

线性回归的模型如下:

请添加图片描述

我们给定d个特征值 x 1 , x 2 , . . . , x d x_1, x_2, ..., x_d x1,x2,...,xd,最终产生输出yhat,我们产生的yhat要尽量拟合原来的值y,在这一拟合过程中我们通过不断修改 w 1 , . . . , w d 和 b w_1, ..., w_d和b w1,...,wdb来实现。

x.1.2 Stategy or Loss

如何评价这个拟合好不好呢,我们的loss/strategy选择为MSE,对于单个点的损失如下:

请添加图片描述
请添加图片描述

将全部的点都添加至损失,得到,

请添加图片描述

最终我们需要做的就是最小化Loss,如下:

请添加图片描述

x.1.3 Algorithm

我们使用什么algorithm/optimizer来最小化loss呢?这里采用了Minibatch Stochastic Gradient Descent,mini-batch SGD也是深度学习中最常用的方法。

请添加图片描述

x.1.4 Nerual Network

线性回归的过程类似于神经元的表达,多个输入产生一个输出,

请添加图片描述

请添加图片描述

x.2 Experiments

x.2.1 手撕一个Linear Regression*

在下面的内容中,只使用了torch的自动微分来实现Linear Regression,值得反复推敲。

'''
手撕一个线性回归,包括:
1. 构造真实线性回归式子
2. 初始化权重
3. 生成一个迭代器每次取batch_size个数据
4. 构造model线性回归
5. 构造cost funtion-MSE
6. 构造optimizer-SGD
7. 开始每个epoch的训练, 注意梯度何时更新: 
        先loss(model(), y)计算loss来构造计算图; backward()计算梯度参数grad; param-=lr*grad更新梯度; param.zero_()梯度变零; 循环。

线性回归简洁表示:

这其实是一个feature=2, n=1000, label.shape=1的二元线性回归问题y = a * x_1 + b * x_2 + c: 用1000个样本(x_1, x_2)来拟合出a, b, c.
线性回归的简洁实现
'''
import random
import torch


# 生成n = 1000组数据, label 1维, features 2维; => weight [2, 1]
# 初始化 weight 和 bias 的初始值
def synthetic_data(w, b, num_examples): 
    """⽣成y=Xw+b+噪声"""
    X = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(X, w) + b 
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))


true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

print('features:', features[0],'\nlabel:', labels[0])


# 手撕一个DataLoader
def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    # 这些样本是随机读取的,没有特定的顺序
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(
            indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]

# 让我们尝试使用iter取batch_size个data
batch_size = 10
for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break

# 初始化权重,并使用`requires_grad=True`开启其自动微分
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True) 
b = torch.zeros(1, requires_grad=True)

#  model
def linreg(X, w, b):
    """线性回归模型"""
    # return torch.matmul(X, w) + b
    return X.mm(w) + b

# cost function
def squared_loss(y_hat, y):
    """MSE"""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

# optimizer to minimize cost function
def sgd(params, lr, batch_size):
    """mini batchsize SGD"""
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()


# 设置超参数
lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss


# 开始训练
for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y) # X和y的⼩批量损失
        # 因为l形状是(batch_size,1),⽽不是⼀个标量。l中的所有元素被加到⼀起,
        # 并以此计算关于[w,b]的梯度
        l.sum().backward()
        sgd([w, b], lr, batch_size) # 使⽤参数的梯度更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')


print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')

在这里补充一下with torch.no_grad(),这个函数用于在上下文中取消梯度更新。详见https://blog.csdn.net/qq_43369406/article/details/131115578

x.2.2 Concise Implementation of Linear Regression 简明实现

在model部分,由于Linear线性层由于经常需要使用到,故现代Pytorch已经将其封装为了一个API函数即torch.nn.LeazyLinear。这个API只关注输出的全连接层的结点个数。

在Loss部分,用torch.nn.MSELoss代替。

在Optimizer部分,用torch.optim.SGD代替。

import numpy as np
import torch
from torch import nn
from d2l import torch as d2l


"""
1. define model

"""
class LinearRegression(d2l.Module):  #@save
    """The linear regression model implemented with high-level APIs."""
    def __init__(self, lr):
        super().__init__()
        self.save_hyperparameters()
        self.net = nn.LazyLinear(1) # The latter allows users to only specify the output dimension | Specifying input shapes is inconvenient
        self.net.weight.data.normal_(0, 0.01)
        self.net.bias.data.fill_(0)

# 使用bulit-in func `__call__` 实现forward
@d2l.add_to_class(LinearRegression)  #@save
def forward(self, X):
    return self.net(X)

"""
2. define loss

"""
@d2l.add_to_class(LinearRegression)  #@save
def loss(self, y_hat, y):
    fn = nn.MSELoss()
    return fn(y_hat, y)

"""
3. define optimizer

"""
@d2l.add_to_class(LinearRegression)  #@save
def configure_optimizers(self):
    return torch.optim.SGD(self.parameters(), self.lr)

"""
4. training

"""
model = LinearRegression(lr=0.03)
data = d2l.SyntheticRegressionData(w=torch.tensor([2, -3.4]), b=4.2)
trainer = d2l.Trainer(max_epochs=3)
trainer.fit(model, data)

@d2l.add_to_class(LinearRegression)  #@save
def get_w_b(self):
    return (self.net.weight.data, self.net.bias.data)
w, b = model.get_w_b()

print(f'error in estimating w: {data.w - w.reshape(data.w.shape)}')
print(f'error in estimating b: {data.b - b}')

3.3 欠拟合过拟合

待学习

3.4 权重衰减

待学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/624908.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MySQL】数据库的增删改查、备份、还原等基本操作

【MySQL】数据库的基本操作 一、创建数据库---create1.1 字符集与校验规则1.1.1 查看系统默认字符集以及校验规则1.1.2 默认方式建立数据库1.1.3 指定编码集建立数据库 1.2 建库的本质 二、查看数据库及其相关属性---show2.1 显示所有数据库2.2 显示数据库的创建语句3.2 显示目…

Yarn【多队列实例、任务优先级设置】

前言 我们知道,Hadoop常见的三种调度器:FIFO调度器(几乎不用,因为它是先来先服务)、容量调度器(Apache Hadoop 默认的调度器)、公平调度器(CDH默认调度器)。 其中&…

PyTorch实战7:咖啡豆识别--手动搭建VGG16

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍦 参考文章:365天深度学习训练营-第P7周:咖啡豆识别🍖 原作者:K同学啊|接辅导、项目定制 目录 一、 前期准备1. 设置GPU2. 导入数据3. 划分数…

Zookeeper节点操作

ZooKeeper的节点操作 ZooKeeper的节点类型 ZooKeeper其实也是一个分布式集群,其中维护了一个目录树结构,在这个目录树中,组成的部分是一个个的节点。ZooKeeper的节点可以大致分为两种类型: 短暂类型 和 持久类型 短暂类型ephemeral: 客户端…

Creating Add-in Hooks (C#)

本文介绍如何使一个文件在添加、检入、检出到库时,让add-in 程序在SOLIDWORKS PDM Professional 中通知到你。 注意: 因为 SOLIDWORKS PDM Professional 无法强制重新加载Add-in程序 ,必须重新启动所有客户端计算机,以确保使用最…

电力综合自动化系统在煤矿领域的设计与应用

安科瑞虞佳豪 持续的高温,给能源保供带来严峻的考验。针对南方部分地区电力供应紧张的局面,煤炭资源大省山西,在确保安全生产的基础上,积极协调增产保供。 这几天,南方多地持续高温,用电量达到高峰。在山西…

深入理解深度学习——注意力机制(Attention Mechanism):Bahdanau注意力

分类目录:《深入理解深度学习》总目录 之前我们探讨了机器翻译问题: 通过设计一个基于两个循环神经网络的编码器—解码器架构, 用于序列到序列学习。 具体来说,循环神经网络编码器将长度可变的序列转换为固定形状的上下文变量&…

抖音seo矩阵系统源码搭建步骤分享

目录 账号矩阵系统源码搭建包括以下步骤: 二、代码实现 三、 代码展示 四、 服务交付 故障级别定义 服务响应时间 账号矩阵系统源码搭建包括以下步骤: 1. 准备服务器和域名 准备一台服务器,例如阿里云、腾讯云等。并在网站上购买一个域…

C++:类型转换

目录 一. C语言的类型转换 二. C类型转换 2.1 static_cast 2.2 reinterpret_cast 2.3 const_cast 2.4 dynamic_cast 三. 运行时类型识别 -- RTTI 四. 总结 一. C语言的类型转换 C语言的类型转换分为隐式类型转换和强制类型转换,隐式类型转换发生在相近的类…

WEB测试环境搭建和测试方法大全

一、WEB测试环境搭建 WEB测试时搭建测试环境所需的软硬件包括:电脑一台、JDK1.6、Tomcat7.0、mysql、IE浏览器、Firefox浏览器、Chrome浏览器、SVN客户端 通过SVN客户端导出最新的Web工程部署到Tomcat7.0下的webapps中,另外重要的一点就是修改数据库连…

31、js - Promise

一、Promise要点 -> js中,只有Promise对象才可以使用.then().catch()方法。 -> axios可以使用.then().catch(),完全是因为调用axios(),返回的是一个Promise对象。 -> new Promise() 里面的代码是同步代码,一旦调用promis…

这个API Hub太厉害了,太适合接口测试了,收录了钉钉企业微信等开放Api的利器

目录 前言: 01API Hub的项目 02API Hub 03调试 04 API 调试 05 API mock 06 针对开放项目功提供者 08 下载 前言: API Hub 的优势在于它提供了完整的 API 管理解决方案,包括API的设计、接口调试、测试和文档管理等。通过集中管理API…

火热报名中 | KCD 北京精彩抢“鲜”看

​ 仲夏已至,风云再起,Kubernetes Community Days 北京站英雄帖一经发出,云原生的各路英雄豪杰纷纷响应。经典招式的升级亮相,最新技巧的惊喜面世,且看各路门派京城聚首,掀起一场云原生的武林论道。各大议…

深入解析Cloudflare五秒盾与爬虫绕过技巧

最近一个朋友发现一个比较有趣的网站,他说正常构造一个HTTP请求居然拿不到网站页面的信息,网站页面如下: 别看它只是一个普普通通的小说网站。随后我在本地环境验证了一下,果不其然得到了以下信息: 从上面反馈的信息…

Yakit: 集成化单兵安全能力平台使用教程·进阶篇

Yakit: 集成化单兵安全能力平台使用教程进阶篇 1.数据处理数据对比Codec2.插件仓库1.数据处理 数据对比 该功能主要提供一个可视化的差异比对工具,用于分析两次数据之间的区别。使用场景可能包括:枚举用户名时比较登录成功和失败时服务器端反馈结果的差异、使用 Web Fuzzer…

【css3实现华为充电】那些你没想到的CSS效果之华为充电效果(附源码下载)

【写在前面】今天是高考的第二天,在这里我也祝各位学子能够旗开得胜,进入自己理想的大学,借着今天这个吉日我就和大家介绍一下如何用css实现华为充电效果。 涉及知识点:CSS3特效,华为充电特效实现,CSS属性f…

部署DR模式 LVS负载均衡群集

部署DR模式 LVS负载均衡群集 一、LVS-DR数据包流向分析二、DR模型的特点三、DR模式 LVS负载均衡群集部署 一、LVS-DR数据包流向分析 (1)客户端发送请求到 Director Server(负载均衡器),请求的数据报文(源 …

SYSU程设c++(第十五周)

vector容器 1.要开vector库 2.vector<T> 是动态的连续数组&#xff0c;可以列表初始化 vector<int> ivec(10, 2); //创建10个值为2的元素 3.可以靠[ ]、at(int)、front、back、迭代器访问其中元素&#xff0c;其中at会自动检查下标越界&#xff0c;抛出异常 4.迭…

【资料分享】ESD防护设计-常见ESD保护电路图

ESD防护设计 ESD防护设计的目的是&#xff0c;当集成电路任意两个输入/输出引脚之间发生ESD事件时&#xff0c;集成电路内部的ESD防护系统能及时开启来泄放掉大量的瞬时电流/电压,使内部电路免遭破坏。此外&#xff0c;在集成电路正常工作时&#xff0c;即未发生ESD事件时&…

Jmeter Suite安装中influx一直处于pending状态

目录 【前言】 【背景说明】 【问题表现】 【排查思路】 简单重试 深入分析 直面本质 【小结】 【写在最后】 完整版文档下载方式&#xff1a; 【前言】 今天要和大家聊聊一个关于Jmeter Suite安装的问题——“influx一直处于pending状态”。 作为一名老测试&#…