【深度学习实验】线性模型(五):使用Pytorch实现线性模型:基于鸢尾花数据集,对模型进行评估(使用随机梯度下降优化器)

news2025/1/10 23:50:32

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入库

1. 线性模型linear_model

2. 损失函数loss_function

3. 鸢尾花数据预处理

4. 初始化权重和偏置

5. 优化器

6. 迭代

7. 测试集预测

8. 实验结果评估

9. 完整代码


一、实验介绍

        线性模型是机器学习中最基本的模型之一,通过对输入特征进行线性组合来预测输出。本实验旨在展示使用随机梯度下降优化器训练线性模型的过程,并评估模型在鸢尾花数据集上的性能。

 二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 导入库

import torch
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn import metrics
  • PyTorch
    • 优化器模块(optim
  • scikit-learn
    • 数据模块(load_iris)
    • 数据划分(train_test_split)
    • 评估指标模块(metrics

1. 线性模型linear_model

        该函数接受输入数据x,使用随机生成的权重w和偏置b,计算输出值output。这里的线性模型的形式为 output = x * w + b

def linear_model(x):
    return torch.matmul(x, w) + b

2. 损失函数loss_function

      这里使用的是均方误差(MSE)作为损失函数,计算预测值与真实值之间的差的平方。

def loss_function(y_true, y_pred):
    loss = (y_pred - y_true) ** 2
    return loss

3. 鸢尾花数据预处理

  • 加载鸢尾花数据集并进行预处理

    • 将数据集分为训练集和测试集

    • 将数据转换为PyTorch张量

iris = load_iris()
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
x_train = torch.tensor(x_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
x_test = torch.tensor(x_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)

4. 初始化权重和偏置

w = torch.rand(1, 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)

5. 优化器

        使用随机梯度下降(SGD)优化器进行模型训练,指定学习率和待优化的参数w, b。

optimizer = optim.SGD([w, b], lr=0.01) # 使用SGD优化器

        

6. 迭代

num_epochs = 100
for epoch in range(num_epochs):
    optimizer.zero_grad()       # 梯度清零
    prediction = linear_model(x_train, w, b)
    loss = loss_function(y_train, prediction)
    loss.mean().backward()      # 计算梯度
    optimizer.step()            # 更新参数

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.mean().item()}")
  • 在每个迭代中:

    • 将优化器的梯度缓存清零,然后使用当前的权重和偏置对输入 x 进行预测,得到预测结果 prediction

    • 使用 loss_function 计算预测结果与真实标签之间的损失,得到损失张量 loss

    • 调用 loss.mean().backward() 计算损失的平均值,并根据计算得到的梯度进行反向传播。

    • 调用 optimizer.step() 更新权重和偏置,使用优化器进行梯度下降更新。

    • 每隔 10 个迭代输出当前迭代的序号、总迭代次数和损失的平均值。

7. 测试集预测

        在测试集上进行预测,使用训练好的模型对测试集进行预测

with torch.no_grad():
    test_prediction = linear_model(x_test, w, b)
    test_prediction = torch.round(test_prediction) # 四舍五入为整数
    test_prediction = test_prediction.detach().numpy()

8. 实验结果评估

  • 使用 metrics 模块计算分类准确度(accuracy)、精确度(precision)、召回率(recall)和F1得分(F1 score)。
  • 输出经过优化后的参数 w 和 b,以及在测试集上的评估指标。
accuracy = metrics.accuracy_score(y_test, test_prediction)
precision = metrics.precision_score(y_test, test_prediction, average='macro')
recall = metrics.recall_score(y_test, test_prediction, average='macro')
f1 = metrics.f1_score(y_test, test_prediction, average='macro')
print("The optimized parameters are:")
print("w:", w.flatten().tolist())
print("b:", b.item())

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

        本实验使用随机梯度下降优化器训练线性模型,并在鸢尾花数据集上取得了较好的分类性能。实验结果表明,经过优化后的模型能够对鸢尾花进行准确的分类,并具有较高的精确度、召回率和F1得分。

9. 完整代码

import torch
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn import metrics

def linear_model(x, w, b):
    return torch.matmul(x, w) + b

def loss_function(y_true, y_pred):
    loss = (y_pred - y_true) ** 2
    return loss

# 加载鸢尾花数据集并进行预处理
iris = load_iris()
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
x_train = torch.tensor(x_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
x_test = torch.tensor(x_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)

w = torch.rand(x_train.shape[1], 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
optimizer = optim.SGD([w, b], lr=0.01) # 使用SGD优化器

num_epochs = 100
for epoch in range(num_epochs):
    optimizer.zero_grad()       # 梯度清零
    prediction = linear_model(x_train, w, b)
    loss = loss_function(y_train, prediction)
    loss.mean().backward()      # 计算梯度
    optimizer.step()            # 更新参数

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.mean().item()}")

# 在测试集上进行预测
with torch.no_grad():
    test_prediction = linear_model(x_test, w, b)
    test_prediction = torch.round(test_prediction) # 四舍五入为整数
    test_prediction = test_prediction.detach().numpy()

accuracy = metrics.accuracy_score(y_test, test_prediction)
precision = metrics.precision_score(y_test, test_prediction, average='macro')
recall = metrics.recall_score(y_test, test_prediction, average='macro')
f1 = metrics.f1_score(y_test, test_prediction, average='macro')
print("The optimized parameters are:")
print("w:", w.flatten().tolist())
print("b:", b.item())

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1021455.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大语言模型的机遇和挑战

自然语言处理包含自然语言理解和自然语言生成两个方面, 常见任务包括文本分类, 结构分析 (词法分析, 分词, 词性标注, 句法分析, 篇章分析), 语义分析, 知识图谱, 信息提取, 情感计算, 文本生成, 自动文摘, 机器翻译, 对话系统, 信息检索和自动问答等. 在神经网络方法出现之前,…

Vue3_vite

使用Vue-cli创建 使用vite创建 Composition API 组合API setup 1.Vue3中的一个新的配置项,值为一个函数 2.可以将组件中所用到的数据,方法等配置在setup中. 3.setup函数的两种返回值 3.1若返回一个对象,则对象中的属性,方法,在模板中均可以直接使用. 3.2若返回一个渲染函数…

Leetcode.337 打家劫舍 III

题目链接 Leetcode.337 打家劫舍 III mid 题目描述 小偷又发现了一个新的可行窃的地区。这个地区只有一个入口,我们称之为 root 。 除了 root 之外,每栋房子有且只有一个“父“房子与之相连。一番侦察之后,聪明的小偷意识到“这个地方的所有…

「聊设计模式」之建造者模式(Builder)

🏆本文收录于《聊设计模式》专栏,专门攻坚指数级提升,助你一臂之力,带你早日登顶🚀,欢迎持续关注&&收藏&&订阅! 前言 设计模式是众多优秀软件开发实践的总结和提炼,…

STM32 ADC介绍和应用

目录 1.ADC是什么? 2.ADC的性能指标 3.ADC特性 4.ADC通道 5.ADC转换顺序 6.ADC触发方式 7.ADC转化时间 8.ADC转化模式 扫描模式 单次转换/连续转换 9.ADC实验 使用ADC读取烟雾传感器的值 代码实现思路: 1.ADC是什么? 全称&#…

DMNet复现(一)之数据准备篇:Density map guided object detection in aerial image

一、生成密度图 密度图标签生成 采用以下代码,生成训练集密度图gt: import cv2 import glob import h5py import scipy import pickle import numpy as np from PIL import Image from itertools import islice from tqdm import tqdm from matplotli…

UG NX二次开发(C#)-计算直线到各个坐标系轴向的投影角度

文章目录 1、前言2、需求分析3、NXOpen方法实现3.1 创建基准坐标系3.2 然后计算直线到基准坐标系的轴向角度3.3 代码调用4、测试效果为:1、前言 最近有个粉丝问我如何计算直线到坐标系各个轴向的角度,这里用UG NX二次开发(C#)实现。当然,这里的内容是经验之谈,如果有更好的…

基于matlab实现的船舶横摇运动仿真程序

完整程序: clc clear syms w we; w0.4:0.05:1.6;mu90;v6;%kb1;kt1;%航速6m/s,航向90度,即横浪,cos(90)0 T3;B10;Sw0.785;%船宽10米,吃水3米,水线面系数假设为0.785 weww.^2.*v/9.8; for i1:24 delta_we(i)we(i1)-…

【计算机网络】——数据链路层(应用:局域网、广域网、设备 )

//仅做个人复习和技术交流,图片取自王道考研,侵删 一、大纲 1、介质访问控制 信道划分介质访问控制 随机访问介质访问控制 2、局域网 3、广域网 4、数据链路层设备 二、局域网 1、局域网基本概念和体系结构 局域网(LocalArea Network): 简称LAN&…

Stable Diffusion - 采样器 DPM++ 3M SDE Karras 与 SDXL Refiner 测试

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/132978866 Paper: DPM-Solver: Fast Solver for Guided Sampling of Diffusion Probabilistic Models 扩散概率模型(DPMs)…

基于matlab实现的多普勒脉冲雷达回波仿真

完整程序: clear all;clc;close all; fc3e9; %载波频率 PRF2000; Br5e6; %带宽 fs10*Br; %采样频率 Tp5e-6; %脉宽 KrBr/Tp; %频率变化率 c3e8; %光速 lamda…

linux入门---共享内存

目录标题 共享内存的原理共享内存的理解shmget函数key和shmid的区别ipcs -m和shmctlshmatshmdt共享内存的通信共享内存的优点共享内存的缺点共享内存的特点 共享内存的原理 通过前面的内容我们知道不同的进程通过虚拟地址空间和页表能够将自己的数据映射到内存上的不同地方比如…

2023全新TwoNav开源网址导航系统源码 | 去授权版

2023全新TwoNav开源网址导航系统源码 已过授权 所有功能可用 测试环境:NginxPHP7.4MySQL5.6 一款开源的书签导航管理程序,界面简洁,安装简单,使用方便,基础功能免费。 TwoNav可帮助你将浏览器书签集中式管理&#…

Qt5开发及实例V2.0-第三章-Qt布局管理

Qt5开发及实例V2.0-第三章-Qt布局管理 第3章 Qt 5布局管理3.1 分割窗口QSplitter类3.2 停靠窗口QDockWidget类3.3 堆栈窗体QStackedWidget类3.4 基本布局(QLayout) 本章相关例程源码下载1.Qt5开发及实例_CH301.rar 下载2.Qt5开发及实例_CH302.rar 下载3.…

将json-bigint处理为数值分区数组的字段全部自动转为字符串

json-bigint虽然能帮我们处理好id 但 他的模式 显然不是直接可以用的 我们如果要到业务逻辑单独处理 那就太麻烦了 对系统也非常不友好 我们可以在vue项目中 src目录下创建一个utils 下面创建一个conversionLong.js 这个名字大家随便取 参考代码如下 var data {}; const Br…

黑马JVM总结(十四)

(1)分代回收_1 Java虚拟机都是结合前面几种算法,让他们协同工作,具体实现是虚拟机里面一个叫做分代的垃圾回收机制,把我们堆内存大的区域划分为两块新生代、老年代 新生代有划分为伊甸园、幸存区Form、幸存区To 为什…

Linux常用工具

文章目录 前言一、Linux编辑器-vim使用1.vim的基本概念2. vim的基本操作3. vim命令集1. 正常模式1. 模式切换和光标移动2. 删除文字及复制3. 其他操作 2. 底行模式 二、Linux编译器-gcc/g使用1. 命令和选项2. 预处理3. 编译4. 汇编(生成机器可识别代码)5. 连接(生成可执行文件或…

工业相机镜头选型相关内容参数(1)

工业相机镜头选型相关内容参数(1)https://www.bilibili.com/video/BV1PF411r7Yy/?spm_id_from333.999.0.0

C#通过重写Panel改变边框颜色与宽度的方法

在C#中,Panel控件是一个容器控件,用于在窗体或用户控件中创建一个可用于容纳其他控件的面板。Panel提供了一种将相关控件组合在一起并进行布局的方式。以下是Panel控件的详细使用方法: 在窗体上放置 Panel 控件: 在 Visual Studio 的窗体设计器中,从工具箱中拖动并放置一…

接口测试以及接口测试用例设计

1. 测试点 功能测试 单接口功能: 手工测试中的单个业务模块,一般对应一个接口 登录业务---->登录接口加入购物车业务---->加入购物车接口订单业务---->订单业务接口支付业务--->支付业务接口借助工具、代码以此绕开前端界面,组织接口所需要…