001.从0开始实现线性回归(pytorch)

news2024/9/21 6:33:17

000动手从0实现线性回归

0. 背景介绍

我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。
设训练数据集样本数为1000,输入个数(特征数)为2。给定随机生成的批量样本特征 X∈R1000×2
X∈R 1000×2 ,我们使用线性回归模型真实权重 w=[2,−3.4]⊤ 和偏差 b=4.2以及一个随机噪声项 ϵϵ 来生成标签
在这里插入图片描述

# 需要导入的包
import numpy as np
import torch
import random
from d2l import torch as d2l
from IPython import display
from matplotlib import pyplot as plt

1. 生成数据集合(待拟合)

使用python生成待拟合的数据

num_input = 2
num_example = 1000
w_true = [2,-3.4]
b_true = 4.2
features = torch.randn(num_example,num_input)
print('features.shape = '+ str(features.shape) )
labels =  w_true[0] * features[:,0] + w_true[1] * features[:,1] + b_true
labels += torch.tensor(np.random.normal(0,0.01 , size = labels.size() ),dtype = torch.float32)
print(features[0],labels[0])

2.数据的分批量处理

def data_iter(batch_size, features, labels):
    num_example = len(labels)
    indices = list(range(num_example))
    random.shuffle(indices)
    for i in range(0, num_example, batch_size):
        j = torch.tensor( indices[i:min(i+ batch_size,num_example)])
        yield features.index_select(0,j) ,labels.index_select(0,j)

3. 模型构建及训练

3.1 定义模型:

def linreg(X, w, b):
    return torch.mm(X,w)+b

3.2 定义损失函数

def square_loss(y, y_hat):
    return (y_hat - y.view(y_hat.size()))**2/2

3.3 定义优化算法

def sgd(params , lr ,batch_size):
    for param in params:
        param.data  -= lr * param.grad / batch_size

3.4 模型训练

# 设置超参数
lr = 0.03
num_epochs =5
net = linreg
loss = square_loss
batch_size = 10
for epoch in range(num_epochs):
    for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):
        l = loss(net(X,w,b),y).sum()
        l.backward()
        sgd([w,b],lr,batch_size=batch_size)
        #梯度清零避免梯度累加
        w.grad.data.zero_()
        b.grad.data.zero_()
    train_l = loss(net(features,w,b),labels)
    print('epoch %d, loss %f' %(epoch +1 ,train_l.mean().item()))

epoch 1, loss 0.032550
epoch 2, loss 0.000133
epoch 3, loss 0.000053
epoch 4, loss 0.000053
epoch 5, loss 0.000053


基于pytorch的线性模型的实现

  1. 相关数据和初始化与上面构建相同
  2. 定义模型
import torch
from torch import nn
class LinearNet(nn.Module):
    def __init__(self, n_feature):
        # 调用父类的初始化
        super(LinearNet,self).__init__()
        # Linear(输入特征数,输出特征的数量,是否含有偏置项)
        self.linera = nn.Linear(n_feature,1)

    def forward(self,x):
        y = self.linera(x)
        return y
#打印模型的结构:
net = LinearNet(num_input)
print(net) 
# LinearNet( (linera): Linear(in_features=2, out_features=1, bias=True)
)
  1. 初始化模型的参数
from torch.nn import init
init.normal_(net.linera.weight,mean=0,std= 0.1)
init.constant_(net.linera.bias ,val=0)
  1. 定义损失函数
loss = nn.MSELoss()

5.定义优化算法

import torch.optim as optim
optimizer =  optim.SGD(net.parameters(),lr = 0.03)
print(optimizer)
  1. 训练模型:
num_epochs = 3
for epoch in range(1,num_epochs+1):
    for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):
        output= net(X)
        l = loss(output,y.view(-1,1))
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
    print('epoch %d ,loss: %f' %(epoch,l.item()) )

epoch 1 ,loss: 0.000159
epoch 2 ,loss: 0.000089
epoch 3 ,loss: 0.000066

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2151554.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

正点原子阿尔法ARM开发板-IMX6ULL(八)——串口通信(寄存器解释)(补:有源蜂鸣器)

文章目录 一、蜂鸣器(待,理解)1.1 第一行1.2 第二行1.3 第三行 二、串口原理2.1 通信格式2.2 UART寄存器 一、蜂鸣器(待,理解) 1.1 第一行 对于第一行,首先先到fsl_iomuxc文件里面寻找IOMUXC_S…

探索C语言与Linux编程:获取当前用户ID与进程ID

探索C语言与Linux编程:获取当前用户ID与进程ID 一、Linux系统概述与用户、进程概念二、C语言与系统调用三、获取当前用户ID四、获取当前进程ID五、综合应用:同时获取用户ID和进程ID六、深入理解与扩展七、结语在操作系统与编程语言的交汇点,Linux作为开源操作系统的典范,为…

01-Mac OS系统如何下载安装Python解释器

目录 Mac安装Python的教程 mac下载并安装python解释器 如何下载和安装最新的python解释器 访问python.org(受国内网速的影响,访问速度会比较慢,不过也可以去我博客的资源下载) 打开历史发布版本页面 进入下载页 鼠标拖到页面…

安装Kali Linux后8件需要马上安排的事

目录 一、更新升级 二、 编辑器 三、用户与权限 四、 下载TOR 五、下载终端 一、更新升级 sudo apt update -y && sudo apt upgrade -y && sudo apt autoremove 二、 编辑器 VScode或者vim;点击.deb就会下载了 一般都会下载到Downloads文件夹中…

煤矸石检测数据集(yolo)

yolo煤矸石检测 数据集 pt模型 界面, ✓3091张图片和txt标签,标签类别两类:“coal”、“rock”。 ✓适用于煤矸石识别,深度学习,机器学习,yolov5 yolov6 yolov7 yolov8 yolov9 yolov10,Python 煤…

YOLOv5模型部署教程

一、介绍 YOLOv5模型是一种以实时物体检测闻名的计算机视觉模型,由Ultralytics开发,并于2020年年中发布。它是YOLO系列的升级版,继承了YOLO系列以实时物体检测能力而著称的特点。 二、基础环境 系统:Ubuntu系统,显卡…

企业内网安全

企业内网安全 1.安全域2.终端安全3.网络安全网络入侵检测系统异常访问检测系统隐蔽信道检测系统 4.服务器安全基础安全配置入侵防护检测 5.重点应用安全活动目录邮件系统VPN堡垒机 6.蜜罐体系建设蜜域名蜜网站蜜端口蜜服务蜜库蜜表蜜文件全民皆兵 1.安全域 企业出于不同安全防…

详读西瓜书+南瓜书第3章——线性回归

在这里,我们来深入探讨线性模型的相关内容,这章涵盖了从基础线性回归到更复杂的分类任务模型。我们会逐步分析其数学公式和实际应用场景。 3.1 基本形式 线性模型的核心是通过属性的线性组合来预测结果。具体形式为: 其中,w 是…

基于深度学习的花卉智能分类识别系统

温馨提示:文末有 CSDN 平台官方提供的学长 QQ 名片 :) 1. 项目简介 传统的花卉分类方法通常依赖于专家的知识和经验,这种方法不仅耗时耗力,而且容易受到主观因素的影响。本系统利用 TensorFlow、Keras 等深度学习框架构建卷积神经网络&#…

PLC通信协议的转化

在自动化程序设计中,常常需要对通信协议进行相互转化。例如,某个控制器需要通过PLC控制设备的某个部件的运动,但PLC只支持ModbusTCP协议,而控制器只支持CanOpen通讯协议。这时,就需要一个网关进行通信协议的转化。网关…

Thymeleaf模版引擎

Thymeleaf是面向Web和独立环境的现代服务器端Java模版引擎,能够处理HTML、XML、JavaScript、CSS甚至纯文本。Thymeleaf旨在提供一个优雅的、高度可维护的创建模版的方式。为了实现这一目标,Thymeleaf建立在自然模版的概念上,将其逻辑注入到模…

VUE3配置路由(超级详细)

第一步创建vue3的项目

(八)使用Postman工具调用WebAPI

访问WebAPI的方法&#xff0c;Postman工具比SoapUI好用一些。 1.不带参数的get请求 [HttpGet(Name "GetWeatherForecast")] public IEnumerable<WeatherForecast> Get() {return Enumerable.Range(1, 5).Select(index > new WeatherForecast{Date DateT…

【TabBar嵌套Navigation案例-JSON的简单使用 Objective-C语言】

一、JSON的简单使用 1.我们先来看一下示例程序里边,产品推荐页面, 在我们这个产品推荐页面里面, 它是一个CollectionViewController,注册的是一个xib的一个类型,xib显示这个cell,叫做item,然后,这个邮箱大师啊,包括这个图标,以及这些东西,都是从哪儿来的呢,都是从…

NLP 主要语言模型分类

文章目录 ngram自回归语言模型TransformerGPTBERT&#xff08;2018年提出&#xff09;基于 Transformer 架构的预训练模型特点应用基于 transformer&#xff08;2017年提出&#xff0c;attention is all you need&#xff09;堆叠层数与原transformer 的差异bert transformer 层…

SpringBoot 项目如何使用 pageHelper 做分页处理 (含两种依赖方式)

分页是常见大型项目都需要的一个功能&#xff0c;PageHelper是一个非常流行的MyBatis分页插件&#xff0c;它支持多数据库分页&#xff0c;无需修改SQL语句即可实现分页功能。 本文在最后展示了两种依赖验证的结果。 文章目录 一、第一种依赖方式二、第二种依赖方式三、创建数…

低空经济刚需篇:各种道路不畅地区无人机吊装详解

低空经济作为近年来备受关注的新兴经济形态&#xff0c;其核心在于利用3000米以下的低空空域进行各种飞行活动&#xff0c;以无人机、电动垂直起降飞行器(eVTOL)等为载体&#xff0c;推动交通、物流、巡检、农林植保、应急救援等多领域的变革。在道路不畅的地区&#xff0c;无人…

信息安全数学基础(20)中国剩余定理

前言 信息安全数学基础中的中国剩余定理&#xff08;Chinese Remainder Theorem&#xff0c;简称CRT&#xff09;&#xff0c;又称孙子定理&#xff0c;是数论中一个重要的定理&#xff0c;主要用于求解一次同余式组。 一、背景与起源 中国剩余定理最早见于我国南北朝时期的数学…

存储 NFS

目录 1.存储的应用场景 2.存储分类 3.NFS服务组成 4.环境说明 ​编辑 5.服务端部署 6.NFS服务端的配置 7.NFS服务端本地进行测试 1.存储的应用场景 存储一般用于上传网站数据&#xff08;内容&#xff09;&#xff0c;一般用于在网站集群中。使用存储的话用户上传的…

推荐系统-电商直播 多目标排序算法探秘

前言&#xff1a; 电商直播已经成为电商平台流量的主要入口&#xff0c;今天我们一起探讨推荐算法在直播中所面临的核心问题和解决方案。以下内容参考阿里1688的技术方案整理完成。 一、核心问题介绍 在电商网站中&#xff0c;用户的主要行为是在商品上的行为&#xff0c;直播…