【深度学习实验】线性模型(三):使用Pytorch实现简单线性模型:搭建、构造损失函数、计算损失值

news2025/1/3 1:52:00

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入库

1. 定义线性模型linear_model

2. 定义损失函数loss_function

3. 定义数据

4. 调用模型

5. 完整代码


一、实验介绍

  • 使用Pytorch实现
    • 线性模型搭建
    • 构造损失函数
    • 计算损失值

 二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        线性模型是一种基本的机器学习模型,用于建立输入特征与输出之间的线性关系。它是一种线性组合模型,通过对输入特征进行加权求和,再加上一个偏置项,来预测输出值。

        线性模型的一般形式可以表示为:y = w1x1 + w2x2 + ... + wnxn + b,其中y是输出变量,x1, x2, ..., xn是输入特征,w1, w2, ..., wn是特征的权重,b是偏置项。模型的目标是通过调整权重和偏置项,使预测值与真实值之间的差异最小化。

线性模型有几种常见的应用形式:

  1. 线性回归(Linear Regression):用于建立输入特征与连续输出之间的线性关系。它通过最小化预测值与真实值的平方差来拟合最佳的回归直线。

  2. 逻辑回归(Logistic Regression):用于建立输入特征与二分类或多分类输出之间的线性关系。它通过使用逻辑函数(如sigmoid函数)将线性组合的结果映射到概率值,从而进行分类预测。

  3. 支持向量机(Support Vector Machines,SVM):用于二分类和多分类问题。SVM通过找到一个最优的超平面,将不同类别的样本分隔开。它可以使用不同的核函数来处理非线性问题。

  4. 岭回归(Ridge Regression)和Lasso回归(Lasso Regression):用于处理具有多重共线性(multicollinearity)的回归问题。它们通过对权重引入正则化项,可以减小特征的影响,提高模型的泛化能力。

        线性模型的优点包括简单、易于解释和计算效率高。它们在许多实际问题中都有广泛的应用。然而,线性模型也有一些限制,例如对非线性关系的建模能力较弱。在处理复杂的问题时,可以通过引入非线性特征转换或使用核函数进行扩展,以提高线性模型的性能。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

0. 导入库

import torch

1. 定义线性模型linear_model

        该函数接受输入数据x,使用随机生成的权重w和偏置b,计算输出值output。这里的线性模型的形式为 output = x * w + b

def linear_model(x):
    w = torch.rand(1, 1, requires_grad=True)
    b = torch.randn(1, requires_grad=True)
    return torch.matmul(x, w) + b

2. 定义损失函数loss_function

      这里使用的是均方误差(MSE)作为损失函数,计算预测值与真实值之间的差的平方。

def loss_function(y_true, y_pred):
    loss = (y_pred - y_true) ** 2
    return loss

3. 定义数据

  • 生成一个随机的输入张量 x,形状为 (5, 1),表示有 5 个样本,每个样本的特征维度为 1。

  • 生成一个目标张量 y,形状为 (5, 1),表示对应的真实标签。

  • 打印数据的信息,包括每个样本的输入值x和目标值y
x = torch.rand(5, 1)
y = torch.tensor([1, -1, 1, -1, 1], dtype=torch.float32).view(-1, 1)
print("The data is as follows:")
for i in range(x.shape[0]):
    print("Item " + str(i), "x:", x[i][0], "y:", y[i])

4. 调用模型

  • 使用 linear_model 函数对输入 x 进行预测,得到预测结果 prediction

  • 使用 loss_function 计算预测结果与真实标签之间的损失,得到损失张量 loss

  • 打印了每个样本的损失值。
prediction = linear_model(x)
loss = loss_function(y, prediction)
print("The all loss value is:")
for i in range(len(loss)):
    print("Item ", str(i), "Loss:", loss[i])

5. 完整代码

import torch

def linear_model(x):
    w = torch.rand(1, 1, requires_grad=True)
    b = torch.randn(1, requires_grad=True)
    return torch.matmul(x, w) + b

def loss_function(y_true, y_pred):
    loss = (y_pred - y_true) ** 2
    return loss

x = torch.rand(5, 1)
y = torch.tensor([1, -1, 1, -1, 1], dtype=torch.float32).view(-1, 1)
print("The data is as follows:")
for i in range(x.shape[0]):
    print("Item " + str(i), "x:", x[i][0], "y:", y[i])

prediction = linear_model(x)
loss = loss_function(y, prediction)
print("The all loss value is:")
for i in range(len(loss)):
    print("Item ", str(i), "Loss:", loss[i])


注意:

        本实验的线性模型仅简单地使用随机权重和偏置,计算了模型在训练集上的均方误差损失,没有使用优化算法进行模型参数的更新。

        通常情况下会使用梯度下降等优化算法来最小化损失函数,并根据训练数据不断更新模型的参数,具体内容请听下回分解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1017713.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

5-1 Dataset和DataLoader

Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道。 Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素。 而DataLoader定义了按batch加载数据集的方法,它是…

无涯教程-JavaScript - EVEN函数

描述 EVEN函数返回四舍五入到最接近的偶数整数的数字。您可以使用此功能来处理两个项目。 语法 EVEN (number)争论 Argument描述Required/OptionalNumberThe value to round.Required Notes 如果数字为非数字,则EVEN返回#VALUE!错误值。 不管数字的符号如何,当从零开始调…

VisualStudio配置驱动远程部署

目标机器开启ping命令 默认情况下,Windows出于安全考虑不允许外部主机对其进行Ping测试。 允许ICMP回显 设置如下: 打开win7防火墙设置界面 左边的菜单中选择 【高级设置】 在弹出的 【高级安全 Windows 防火墙】 界面,选择 【入站规则】 …

Java常见面试题(含答案,持续更新中~~)

目录 1、JVM、JRE和JDK的关系 2、什么是字节码?采用字节码的最大好处是什么 3、Java和C的区别与联系 4、Java和GO的区别与联系 5、 和 equals 的区别是什么? 6、Oracle JDK 和 OpenJDK 的对比 7、String 属于基础的数据类型吗? 8、fi…

VHOST-SCSI代码分析(3)数据流处理

VHOST SCSI数据流如下所示: IO下发过程 虚拟机中应用态程序下发IO,依次经过VFS/文件系统层,BLOCK层,SCSI层,经VIRTIO SCSI驱动virtscsi_commit_rqs访问寄存器通知HOST内核中VHOST设备(VHOST KICK过程&#…

tolua源码分析(十一)代码生成

tolua源码分析(十一)代码生成 上一节我们分析了tolua中struct数据在lua和C#之间传递的过程,这一节我们来看一下tolua自动生成各种辅助代码的流程。 生成所有代码的入口位于ToLuaMenu.cs的GenLuaAll: [MenuItem("Lua/Genera…

达梦数据库-DW-国产化--九五小庞

武汉达梦数据库股份有限公司成立于2000年,是国内领先的数据库产品开发服务商,国内数据库基础软件产业发展的关键推动者。公司为客户提供各类数据库软件及集群软件、云计算与大数据等一系列数据库产品及相关技术服务,致力于成为国际顶尖的全栈…

读取yaml文件的值

记录一下,读取yaml文件中属性的值,这里用Kubernetes的deployment.yaml文件来举例。 读取yaml文件中的image的值 yaml文件 apiVersion: apps/v1 # 1.9.0 之前的版本使用 apps/v1beta2,可通过命令 kubectl api-versions 查看 kind: Deploy…

获取中文词组的汉语拼音首字母拼接

我们需要一个快捷批量处理&#xff1a;中文词组获取其汉语拼音首字母并拼接起来。 比如&#xff1a; 输出功率3&#xff1a;SCGL3 一鸣惊人&#xff1a;YMJR 我们可以采用字符字典法&#xff0c;穷举出所有的汉字【暂只考虑简体中文】 Dictionary<char,string> dict…

【数据分享】2006-2021年我国省份级别的市容环境卫生相关指标(20多项指标)

《中国城市建设统计年鉴》中细致地统计了我国城市市政公用设施建设与发展情况&#xff0c;在之前的文章中&#xff0c;我们分享过基于2006-2021年《中国城市建设统计年鉴》整理的2006—2021年我国省份级别的市政设施水平相关指标、2006-2021年我国省份级别的各类建设用地面积数…

【C++】STL—— unordered_map的介绍和使用、 unordered_map的构造函数和迭代器、 unordered_map的增删查改函数

文章目录 1. unordered_map的介绍2. unordered_map的使用2.1unordered_map的构造函数2.2unordered_map的迭代器2.3unordered_map的容量和访问函数2.4unordered_map的增删查改函数 1. unordered_map的介绍 unordered_map的介绍 &#xff08;1&#xff09;unordered_map是存储&l…

ElasticSearch(ES)简答了解

ES简介 Elasticsearch&#xff08;通常简称为ES&#xff09;是一个开源的分布式搜索和分析引擎&#xff0c;旨在处理各种类型的数据&#xff0c;包括结构化、半结构化和非结构化数据。它最初是为全文搜索而设计的&#xff0c;但随着时间的推移&#xff0c;它已经演变成一个功能…

web系统安全设计原则

一、前言 近日&#xff0c;针对西工大网络被攻击&#xff0c;国家计算机病毒应急处理中心和360公司对一款名为“二次约会”的间谍软件进行了技术分析。分析报告显示&#xff0c;该软件是美国国家安全局&#xff08;NSA&#xff09;开发的网络间谍武器。当下&#xff0c;我们发现…

【骑行之旅】昆明草海湿地公园和海晏村的美丽邂逅

这是一个九月的星期六&#xff0c;在昆明的大观公园门口&#xff0c;我们集合了一群热爱骑行的骑友。今天&#xff0c;阳光明媚&#xff0c;天空湛蓝&#xff0c;一切都充满了活力。我们的旅程从这里开始&#xff0c;一路向西&#xff0c;向着下一站&#xff0c;美丽的草海湿地…

虚拟机用户切换及设置root权限的密码

天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物。 每个人都有惰性&#xff0c;但不断学习是好好生活的根本&#xff0c;共勉&#xff01; 文章均为学习整理笔记&#xff0c;分享记录为主&#xff0c;如有错误请指正&#xff0c;共同学习进步。…

驱动开发,IO多路复用(select,poll,epoll三种实现方式的比较)

1.IO多路复用介绍 在使用单进程或单线程情况下&#xff0c;同时处理多个输入输出请求&#xff0c;需要用到IO多路复用&#xff1b;IO多路复用有select/poll/epoll三种实现方式&#xff1b;由于不需要创建新的进程和线程&#xff0c;减少了系统资源的开销&#xff0c;减少了上下…

从0到1搭建Halo博客系统教程

前期准备 云服务器&#xff0c;域名&#xff0c;命令工具&#xff08;这里使用是Mobaxterm&#xff09; 安装环境 宝塔面板 yum install -y wget && wget -O install.sh https://download.bt.cn/install/install_6.0.sh && sh install.sh ed8484bec在命令工…

【计算机视觉】Image Generation Models算法介绍合集

文章目录 一、Diffusion二、Guided Language to Image Diffusion for Generation and Editing&#xff08;GLIDE&#xff09;三、classifier-guidance四、Blended Diffusion五、DALLE 2六、AltDiffusion七、Group Decreasing Network八、Make-A-Scene九、Iterative Inpainting十…

c++ - 抽象类 和 使用多态当中一些注意事项

抽象类 纯虚函数 在虚函数的后面写上 0 &#xff0c;则这个函数为纯虚函数。 class A { public:virtual void func() 0; }; 纯虚函数不需要写函数的定义&#xff0c;他有类似声明一样的结构。 抽象类概念 我们把具有纯虚函数的类&#xff0c;叫做抽象类。 所谓抽象就是&a…

docker gitlab+jenkins搭建

一&#xff1a;gitlab搭建: 1&#xff1a;docker部署 2&#xff1a;修改root密码 3&#xff1a;创建普通账户 4&#xff1a;设置sshken 二&#xff1a;jenkins搭建 配置脚本 bash -x /var/jenkins_home/shell/game01.sh