【动手学深度学习Pytorch】1. 线性回归代码

news2024/11/22 14:37:37

零实现

        导入所需要的包:

# %matplotlib inline
import random
import torch
from d2l import torch as d2l
import matplotlib.pyplot as plt
import matplotlib
import os

        构造人造数据集:假设w=[2, -3.4],b=4.2,存在随机噪音(均值为0,方差为0.001的正态分布噪声),函数拟合为y = w^{T}X + b + n。在构造数据集的过程中,首先X为正态分布(均值为0,方差为1,样本数/行数为num_examples,列数为len(w))

torch.normal(mean, std, *, generator=None, out=None):生成指定输出尺寸的正态分布随机数张量

torch.mv():矩阵和向量的乘积,此处X为矩阵,w为向量

def synthetic_data(w, b, num_examples):
    X = torch.normal(0, 1, (num_examples, len(w))) #均值为0方差为1的随机数,样本数,列数
    y = torch.mv(X, w) + b #y关于x的公式
    y += torch.normal(0, 0.001, y.shape) # 加入噪声项
    return X, y.reshape((-1,1)) #做成列向量返回
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

        查看数据集样本分布:

matplotlib.pyplot.scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, *, edgecolors=None, plotnonfinite=False, data=None, **kwargs):

        x,y:长度相同的数组,也就是我们即将绘制散点图的数据点,输入数据。

        s:点的大小,默认 20,也可以是个数组,数组每个参数为对应点的大小。

        c:点的颜色,默认蓝色 'b',也可以是个 RGB 或 RGBA 二维行数组。

        marker:点的样式,默认小圆圈 'o'。

        cmap:Colormap,默认 None,标量或者是一个 colormap 的名字,只有 c 是一个浮点数数组的时才使用。如果没有申明就是 image.cmap。

        norm:Normalize,默认 None,数据亮度在 0-1 之间,只有 c 是一个浮点数的数组的时才使用。

        vmin,vmax:亮度设置,在 norm 参数存在时会忽略。

        alpha:透明度设置,0-1 之间,默认 None,即不透明。

        linewidths:标记点的长度。

        edgecolors:颜色或颜色序列,默认为 'face',可选值有 'face', 'none', None。

        plotnonfinite:布尔值,设置是否使用非限定的 c ( inf, -inf 或 nan) 绘制点。

        **kwargs:其他参数。

detach():允许我们从计算图中分离出张量。当对一个张量调用detach()方法时,它会创建一个新的张量,这个新张量与原始张量共享数据,但它不再参与计算图的任何操作,对分离后的张量进行的任何操作都不会影响原始张量,也不会在计算图中留下任何痕迹。

plt.scatter(features[:,(1)].detach().numpy(),labels.detach().numpy(),1);
plt.show()

        遍历数据集,输出数据集内容:

len(): 返回对象(字符、列表、元组等)长度或项目个数(此处是张量的行数)

list(): 将元组转换为列表

range():创建一个整数列表

shuffle(): 随机打乱列表

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples)) #生成样本索引
    random.shuffle(indices) #样本随机读取没有特定顺序
    # 进行batch划分
    for i in range(0, num_examples, batch_size): #从i开始到i+batchsize
        batch_indices =  torch.tensor(indices[i:min(i + batch_size, num_examples)])
        # 截取切片:开始位置为i,结束位置为min函数的返回值
        # 返回值为i+batch_size和num_examples的值比较小的那个
        yield features[batch_indices], labels[batch_indices] #产生随机顺序的特征&标号

batch_size = 10

for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break

 

      定义参数、模型、损失函数以及优化算法:

torch.mutual():矩阵相乘

with torch.no_grad():所有计算得出的tensor的requires_grad都自动设置为False,不会进行自动求导

grad.zero_():将梯度置零(不然会发生累计的情况)

# 定义初始化模型参数
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
# 定义模型
def linreg(X, w, b):
    return torch.matmul(X, w) + b
# 定义损失函数
def squared_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape))**2/2
# 定义优化算法
def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

        定义训练过程:

# 训练过程
lr = 0.01
num_epochs = 10
net = linreg
loss = squared_loss

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)
        l.sum().backward()
        sgd([w,b], lr, batch_size)
    with torch.no_grad():
        train_1= loss(net(features, w, b), labels)
        print(f'epoch{epoch + 1}, loss{float(train_1.mean()):f}')

简介实现

        导入所需要的包:

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
import matplotlib.pyplot as plt

        创建人造数据集:

data.TensorDataset():将数据进行封装

data.DataLoader():将数据分批次处理

iter():获取列表的迭代器

next():获取下一个值

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b,1000)
def load_array(data_arrays, batch_size, is_train=True):
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
data_iter = load_array((features, labels), batch_size)

next(iter(data_iter))

初始化模型、模型参数、loss: 

nn.Sequential():实现模型层结构的简单排序

torch.optim.SGD():定义优化算法

torch.optim.SGD().step():进行模型的更新

# 使用框架的预定义好的层
from torch import nn
net = nn.Sequential(nn.Linear(2,1))
# 初始化模型参数
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)
# 计算均方误差使用的是MSELoss类
loss = nn.MSELoss()
trainer = torch.optim.SGD(net.parameters(),lr=0.01)

        定义训练过程:

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = loss(net(X), y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
    l = loss(net(features), labels)
    print(f'epoch{epoch + 1}, loss{1:f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2245376.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Keil基于ARM Compiler 5的工程迁移为ARM Compiler 6的工程

环境: keil版本为5.38,版本务必高于5.30 STM32F4的pack包版本要高于2.9 软件包下载地址:https://zhuanlan.zhihu.com/p/262507061 一、更改Keil中编译器 更改后编译,会报很多错,先不管。 二、更改头文件依赖 观察…

数据集-目标检测系列- 花卉 玫瑰 检测数据集 rose >> DataBall

数据集-目标检测系列- 花卉 玫瑰 检测数据集 rose >> DataBall DataBall 助力快速掌握数据集的信息和使用方式,会员享有 百种数据集,持续增加中。 贵在坚持! 数据样例项目地址: * 相关项目 1)数据集可视化项…

Linux驱动编程 - kmalloc、vmalloc区别

目录 前言: 1、区别 2、使用差异 一、kmalloc、kzalloc、kfree 1、动态申请 1.1 kmalloc() 1.2 kzalloc() 2、内存释放 3、示例 二、vmalloc、vzalloc、vfree 1、动态申请 1.1 vmalloc() 1.2 vzalloc() 2、内存释放 3、示例 前言: Linux内…

使用低成本的蓝牙HID硬件模拟鼠标和键盘来实现自动化脚本

做过自动化脚本的都知道,现在很多传统的自动化脚本方案几乎都可以被检测,比如基于root,adb等方案。用外置的带有鼠标和键盘功能集的蓝牙HID硬件来直接点击和滑动是非常靠谱的方案,也是未来的趋势所在。 一、使用蓝牙HID硬件的优势…

VideoCrafter模型部署教程

一、介绍 VideoCrafter是一个功能强大的AI视频编辑和生成工具,它结合了深度学习和机器学习技术,为用户提供了便捷的视频制作和编辑体验。 系统:Ubuntu22.04系统,显卡:4090,显存:24G 二、基础…

#渗透测试#SRC漏洞挖掘#Python自动化脚本的编写05之多线程与多进程

免责声明 本教程仅为合法的教学目的而准备,严禁用于任何形式的违法犯罪活动及其他商业行为,在使用本教程前,您应确保该行为符合当地的法律法规,继续阅读即表示您需自行承担所有操作的后果,如有异议,请立即停…

C++多继承:一个子类继承多个父类的情况

C的类继承大家还算比较了解。它主要包括单继承、多继承、虚继承这几方面。 单继承就是一个子类只继承一个父类,多继承就是一个子类继承多个父类。 其实在C中,一个子类继承多个父类的情况还是比较常见的。比如,一个子类需要同时继承两个父类…

在windows电脑上安装docker服务

以下是在 Windows 电脑上安装 Docker 服务的详细步骤: 一、下载 Docker Desktop for Windows 系统要求:Windows 操作系统需要是 Windows 10(64 位)专业版、企业版或教育版,或者是 Windows 11。并且系统要开启了硬件虚…

单片机UART协议相关知识

概念 UART(Universal Asynchronous Receiver/Transmitter,通用异步收发传输器) 是一种 异步 串行 全双工 通信协议,用于设备一对一进行数据传输,只需要两根线(TX,RX)。 异步&…

XXL-JOB执行任务的SpringBoot程序无法注册到调度中心

文章目录 1. 问题呈现2. 问题产生的原因2.1 原因一:执行器和调度中心部署在不同的机器上2.2 原因二:调度中心部署在云服务器上 3. 解决方法3.1 方法一:将执行器和调度中心部署在同一台机器上3.2 方法二:手动指定执行器的ip地址&am…

Ettus USRP X410

总线连接器: 以太网 RF频率范围: 1 MHz 至 7.2 GHz GPSDO: 是 输出通道数量: 4 RF收发仪瞬时带宽: 400 MHz 输入通道数量: 4 FPGA: Zynq US RFSoC (ZU28DR) 1 MHz to 7.2 GHz,400 MHz带宽,GPS驯服OCXO,USRP软件无线电设备 Ettus USRP X410集…

哋它亢SEO技术分析:如何提升网站在搜索引擎中的可见性

文章目录 哋它亢SEO技术分析:如何提升网站在搜索引擎中的可见性网站的基本情况SEO优化分析与建议1. 元数据优化2. 关键词优化3. URL结构4. 图像优化5. 移动端优化6. 网站速度7. 结构化数据(Schema Markup)8. 内链与外链9. 社交分享 哋它亢SEO…

将网站地址改成https地址需要哪些材料

HTTPS(安全超文本传输协议)是HTTP协议的扩展。它大大降低了个人数据(用户名、密码、银行卡号等)被拦截的风险,还有助于防止加载网站时的内容替换,包括广告替换。 在发送数据之前,信息会使用SSL…

mongodb多表查询,五个表查询

需求是这样的,而数据是从mysql导入进来的,由于mysql不支持数组类型的数据,所以有很多关联表。药剂里找药物,需要药剂与药物的关联表,然后再找药物表。从药物表里再找药物与成分关联表,最后再找成分表。 这里…

端到端的专线管理与运维:实时掌握专线的运行状态

在当今高度信息化的时代,专线服务已成为企业数据传输的重要组成部分。为了确保专线服务的高效、稳定运行,我们采用了先进的端到端管理模式,对专线的运行状态和质量进行全面监控。本文将从专线管理的必要性、端到端管理模式的优势、实施步骤以…

SpringBoot(8)-任务

目录 一、异步任务 二、定时任务 三、邮件任务 一、异步任务 使用场景:后端发送邮件需要时间,前端若响应不动会导致体验感不佳,一般会采用多线程的方式去处理这些任务,但每次都需要自己去手动编写多线程来实现 1、编写servic…

PostgreSQL常用字符串函数与示例说明

文章目录 coalesce字符串位置(position strpos)字符串长度与大小写转换去掉空格(trim ltrim rtrim)字符串连接(concat)字符串替换简单替换(replace)替换指定位置长度(overlay)正则替换(regexp_replace) 字符串匹配字符串拆分split_part(拆分数组取指定位置的值)string_to_array…

深入剖析Java内存管理:机制、优化与最佳实践

🚀 作者 :“码上有前” 🚀 文章简介 :Java 🚀 欢迎小伙伴们 点赞👍、收藏⭐、留言💬 深入剖析Java内存管理:机制、优化与最佳实践 一、Java内存模型概述 1. Java内存模型的定义与作…

【论文速读】| RobustKV:通过键值对驱逐防御大语言模型免受越狱攻击

基本信息 原文标题:ROBUSTKV: DEFENDING LARGE LANGUAGE MODELS AGAINST JAILBREAK ATTACKS VIA KV EVICTION 原文作者:Tanqiu Jiang, Zian Wang, Jiacheng Liang, Changjiang Li, Yuhui Wang, Ting Wang 作者单位:Stony Brook University…

css使用弹性盒,让每个子元素平均等分父元素的4/1大小

css使用弹性盒,让每个子元素平均等分父元素的4/1大小 原本: ul {padding: 0;width: 100%;background-color: rgb(74, 80, 62);display: flex;justify-content: space-between;flex-wrap: wrap;li {/* 每个占4/1 */overflow: hidden;background-color: r…