使用PyTorch处理多维特征输入的完美指南

news2025/1/16 7:51:48

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
请添加图片描述

文章目录

  • 🥦引言
  • 🥦前期的回顾与准备
  • 🥦代码实现
  • 🥦总结

🥦引言

在机器学习和深度学习领域,我们经常会面对具有多维特征输入的问题。这种情况出现在各种应用中,包括图像识别、自然语言处理、时间序列分析等。PyTorch是一个强大的深度学习框架,它提供了丰富的工具和库,可以帮助我们有效地处理这些多维特征输入数据。在本篇博客中,我们将探讨如何使用PyTorch来处理多维特征输入数据。

🥦前期的回顾与准备

这里我们采用一组预测糖尿病的数据集,如下图
在这里插入图片描述
这里的每一行代表一个样本,同样的,每一列代表什么呢,代表一个特征,如下图。所以糖尿病的预测由下面这八个特征共同进行决定
在这里插入图片描述
按照过去的逻辑回归,应该是下图所示的,因为这是单特征值
在这里插入图片描述
但是现在由单特征值已经转变为多特征值了,所以我们需要对每个特征值进行处理,如下图
在这里插入图片描述
中间的特征值与权重的点乘可以从矩阵的形式进行表现
在这里插入图片描述
因为逻辑回归所以还有套一个Sigmoid函数,通常情况下我们将函数内的整体成为z(i)
在这里插入图片描述

注意: Sigmoid函数是一个按向量方式实现的

下面我们从矩阵相乘的形式进行展示,说明可以将一组方程合并为矩阵运算,可以想象为拼接哈。这样的目的是转化为并行运算,从而实现更快的运行速度。
在这里插入图片描述
所以从代码的角度去修改,我们只需要改变一下维度就行了

class Model(torch.nn.Module):
	def __init__(self):
		super(Model, self).__init__()
		self.linear = torch.nn.Linear(8, 1) 
		self.sigmoid = torch.nn.Sigmoid()
	def forward(self, x):
		x = self.sigmoid(self.linear(x)) 
		return x
model = Model()

这里的输入维度设置为8,就像上图中展示的x一样是N×8形式的矩阵,而 y ^ \hat{y} y^是一个N×1的矩阵。
这里我们将矩阵看做是一个空间变换的函数

我们可以从下图很好的展示多层神经网络的变换
在这里插入图片描述

从一开始的属于8维变为输出6维,再从输入的6维变为输出的4维,最后从输入的4维变为输出的1维。

如果从代码的角度去写,可以从下面的代码进行实现

class Model(torch.nn.Module):
	def __init__(self):
		super(Model, self).__init__()
		self.linear1 = torch.nn.Linear(8, 6) 
		self.linear2 = torch.nn.Linear(6, 4) 
		self.linear3 = torch.nn.Linear(4, 1) 
		self.sigmoid = torch.nn.Sigmoid()
	def forward(self, x):
		x = self.sigmoid(self.linear1(x)) 
		x = self.sigmoid(self.linear2(x)) 
		x = self.sigmoid(self.linear3(x)) 
		return x
model = Model()

这里我说明一下下面这条语句

  • self.sigmoid = torch.nn.Sigmoid():这一行创建了一个 Sigmoid 激活函数的实例,用于在神经网络的正向传播中引入非线性。

后面的前向计算就是一层的输出是另一层输入进行传,最后将 y ^ \hat{y} y^返回


同时我们的损失函数也没有变化,更新函数也没有变化,采用交叉熵和梯度下降
在这里插入图片描述

刘二大人这里没有使用Mini-Batch进行批量,后续的学习应该会更新
在这里插入图片描述

🥦代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import datasets
from sklearn.model_selection import train_test_split
import numpy as np

# 载入Diabetes数据集
diabetes = datasets.load_diabetes()

# 将数据集拆分为特征和目标
X = diabetes.data  # 特征
y = diabetes.target  # 目标

# 数据预处理
X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)  # 特征标准化

# 拆分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 转换为PyTorch张量
X_train = torch.FloatTensor(X_train)
y_train = torch.FloatTensor(y_train).view(-1, 1)  # 将目标变量转换为列向量
X_test = torch.FloatTensor(X_test)
y_test = torch.FloatTensor(y_test).view(-1, 1)


# 构建包含多个线性层的神经网络模型
class DiabetesModel(nn.Module):
    def __init__(self, input_size):
        super(DiabetesModel, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)  # 第一个线性层
        self.fc2 = nn.Linear(64, 32)  # 第二个线性层
        self.fc3 = nn.Linear(32, 1)  # 最终输出线性层

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # ReLU激活函数
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# 初始化模型
input_size = X_train.shape[1]
model = DiabetesModel(input_size)

# 定义损失函数和优化器
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
num_epochs = 1000
for epoch in range(num_epochs):
    # 前向传播
    outputs = model(X_train)
    loss = criterion(outputs, y_train)

    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# 在测试集上进行预测
model.eval()
with torch.no_grad():
    y_pred = model(X_test)

# 计算性能指标
mse = nn.MSELoss()(y_pred, y_test)
print(f"均方误差 (MSE): {mse.item():.4f}")

运行结果如下
在这里插入图片描述

感兴趣的同学可以使用不同的激活函数一一测试一下

比如我使用tanh函数测试后得到的均方误差就小了许多
在这里插入图片描述

此链接是GitHub上的大佬做的可视化函数https://dashee87.github.io/deep%20learning/visualising-activation-functions-in-neural-networks/

🥦总结

这就是使用PyTorch处理多维特征输入的基本流程。当然,实际应用中,你可能需要更复杂的神经网络结构,更大的数据集,以及更多的调优和正则化技巧。但这个指南可以帮助你入门如何处理多维特征输入的问题,并利用PyTorch构建强大的深度学习模型。希望这篇博客对你有所帮助!

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1081418.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WorkPlus私有部署即时通信助力企业信息安全与高效协作

在当今快速发展的商业环境中,高效的内部沟通对企业的成功至关重要。然而,在保障信息安全的同时,如何实现高效的协作和沟通一直是企业所面临的挑战。传统的公共即时通信平台,尽管提供了便利的沟通工具,但其数据存储和控…

python- excel 创建/写入/删sheet+花式遍历

文章目录 前言python- excel 创建/写入/删sheet花式遍历1. excel 创建2. 写入excel3. 创建写入excel demo实战4. 删除sheet5. excel 花式遍历 demo实战5.1. 获取 A1的值5.2. 获取指定列的切片数据,获取 B1到B5的值5.3. 循环整个excel的这个sheet5.4. 遍历指定行&…

VR全景云端看车,让你享受不一样的购车体验

这个“黄金周”可谓是热闹非凡,不仅房企大降价,部分车企也在“黄金周”发力,降价优惠促销量,那么你准备买车了吗?买车也是一个大件,需要多家店去走动比对价位,难免会挑花了眼,其实我…

中国SaaS行业等待“渡劫时刻”

期待、追捧、失望、质疑,中国SaaS行业激荡十几年,尝遍了市场浮沉,如今潮水褪去,SaaS企业们到了见真章的时刻。 一开始,SaaS行业被人们寄予厚望。互联网的蓬勃发展,数字化转型的历史进程,似乎都…

快讯|Tubi 有 Rabbit AI 啦

在每月一期的 Tubi 快讯中,你将全面及时地获取 Tubi 最新发展动态,欢迎星标关注【比图科技】微信公众号,一起成长变强! Tubi 推出 Rabbit AI 帮助用户找到喜欢的视频内容 Tubi 于今年九月底推出了 Rabbit AI,这是一项…

基于 gin框架搭建入门项目

go mod init gin-ranking go: creating new go.mod: module gin-ranking go: to add module requirements and sums:go mod tidy下载gin框架 cmd窗口中执行命令: go get -u github.com/gin-gonic/ginpackage mainimport ("github.com/gin-gonic/gin"&qu…

圭亚那奥罗拉金矿配电工程中AM5SE系列微机保护装置

安科瑞 崔丽洁 摘要:目前,微机保护装置广泛应用于电力系统中,该类装置能够有效监测电力系统的运行状况,并实时记录电力系统出现故障的位置及性质,从而为故障的快速处理提供有效的参考信息。本文介绍的AM5SE系列微机保…

android:can not find libdevapi.so

一、为什么会出现这样的报错? 引用了一些第三方的sdk的so库之后通常都会遇到这样的错误,(“nativeLibraryDirectories”[/data/app/com.lukouapp-1/lib/arm64, /vendor/lib64, /system/lib64]]] couldnt find "libxxxx.so"&#x…

zookeeper节点数据类型介绍及集群搭建

一、zookeeper介绍 zookeeper官网:Apache ZooKeeper zookeeper是一个分布式协调框架,保证的是CP,即一致性和分区容错性;zookeeper是一个分布式文件存储系统,文件节点可以存储数据,监听子文件节点等可以实…

绥化市中心广场焕发新活力:OLED透明拼接屏的奇观展示

OLED透明拼接屏技术在绥化城市的应用引起了广泛关注。 绥化市位于中国东北地区,是黑龙江省的一个重要城市。 该市拥有悠久的历史,历经多个朝代的兴衰。绥化的历史背景赋予了这座城市独特的文化底蕴和魅力。 绥化市内有许多著名景点,其中最…

day

#include <iostream> using namespace std; class Per {//算术运算符friend const Per operator(const Per &k1,const Per &k2);friend const Per operator-(const Per &k1,const Per &k2);friend const Per operator*(const Per &k1,const Per &…

vue项目npm intall时发生版本冲突的解决办法

在日常使用命令npm install / npm install XX下载依赖的操作中&#xff0c;我经常会遇到无法解析依赖树的问题&#xff08;依赖冲突&#xff09; 当遇到这种情况的时候&#xff0c;可以通过以下命令完成依赖安装&#xff1a; npm install --legacy-peer-deps npm install xxx…

国产化即时通讯平台WorkPlus,助力企业实现自主可控的沟通与协作

作为在线沟通与协作的重要工具&#xff0c;即时通讯平台在企业中扮演着不可或缺的角色。然而&#xff0c;为了保护企业的核心数据和机密信息&#xff0c;越来越多的企业开始转向国产化的即时通讯平台。在这一背景下&#xff0c;国内软件品牌WorkPlus应运而生&#xff0c;为企业…

基于 chinese-roberta-wwm-ext 微调训练中文命名实体识别任务

一、模型和数据集介绍 1.1 预训练模型 chinese-roberta-wwm-ext 是基于 RoBERTa 架构下开发&#xff0c;其中 wwm 代表 Whole Word Masking&#xff0c;即对整个词进行掩码处理&#xff0c;通过这种方式&#xff0c;模型能够更好地理解上下文和语义关联&#xff0c;提高中文文…

mongodb简介、安装、搭建复制集

一、 简介 NoSQL数据库四大家族&#xff1a;列存储 Hbase&#xff0c;键值(Key-Value)存储 Redis&#xff0c;图像存储 Neo4j&#xff0c;基于分布式文档存储的数据库MongoDb。 MongoDB 和关系型数据库对比 关系型数据库MongoDBdatabase&#xff08;库&#xff09;database&…

.net也能写内存挂

最近在研究.net的内存挂。 写了很久的c,发现c#写出来的东西实在太香。 折腾c#外挂已经有很长时间了。都是用socket和c配合。 这个模式其实蛮成功的&#xff0c;用rpc调用的方式加上c#的天生await 非常好写逻辑 类似这样 最近想换个口味。注入托管dll到非托管进程 这样做只…

基于 SOFAJRaft 实现注册中心

文章目录 1.前言2.git 示例地址3.官网示例分析3.SOFAJRAFT 注册中心实现&#xff08;服务端&#xff09;3.1 核心功能3.2 模块设计3.3 请求消息数据结构设计3.3.1 Registration 注册消息3.3.2 GetServiceInstancesRequest 获取服务实例请求3.3.3 GetServiceInstancesResponse 获…

Android中级——ListView和RecycleView解析

ListView和RecycleView ListViewRecycleView ListView 使用步骤可看Android基础——ListView&#xff0c;其setAdapter()如下&#xff0c;回调getCount()获取Item个数 Override public void setAdapter(ListAdapter adapter) {if (mAdapter ! null && mDataSetObserv…

v-model绑定input、textarea、checkbox、radio、select

1.input <div><!-- v-model绑定input --><input type"text" v-model"message"><h2>{{message}}</h2></div><script>const App{template:#my-app,data() {return {message:Hello World,}},}Vue.createApp(App).…

Java:设计模式之结构型-装饰者模式(decorator pattern)

装饰者模式(decorator pattern): 动态地将责任附加到对象上 意图&#xff1a;为对象动态添加功能 类图 实现 设计不同种类的饮料&#xff0c;饮料可以添加配料&#xff0c;比如可以添加牛奶&#xff0c;并且支持动态添加新配料。每增加一种配料&#xff0c;该饮料的价格就…