深度学习PyTorch 之 DNN-回归(多变量)

news2025/1/19 17:01:09

深度学习&PyTorch 之 DNN-回归中使用HR数据集进行了实现,但是HR数据集中只有一个变量,这里我们使用多变量在进行模拟一下

流程还是跟前面一样

数据导入
数据拆分
Tensor转换
数据重构
模型定义
模型训练
结果展示

1.1 数据导入

我们使用波士顿房价预测数据,这是个开源的数据集,所以通用性更强

data = pd.read_csv('./boston_house_prices.csv')
data

在这里插入图片描述

1.2 数据拆分

from sklearn.model_selection import train_test_split
train,test = train_test_split(data, train_size=0.7)

train_x = train[['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT']].values
test_x = test[['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT']].values

train_y = train.MEDV.values.reshape(-1, 1)
test_y = test.MEDV.values.reshape(-1, 1)

1.3 To Tensor

train_x = torch.from_numpy(train_x).type(torch.FloatTensor)
test_x = torch.from_numpy(test_x).type(torch.FloatTensor)
train_y = torch.from_numpy(train_y).type(torch.FloatTensor)
test_y = torch.from_numpy(test_y).type(torch.FloatTensor)

1.4 数据重构

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

train_ds = TensorDataset(X, Y)
train_dl = DataLoader(train_ds, batch_size=batch, shuffle=True)

train_ds = TensorDataset(train_x, train_y)
train_dl = DataLoader(train_ds, batch_size=batch, shuffle=True)

test_ds = TensorDataset(test_x, test_y)
test_dl = DataLoader(test_ds, batch_size=batch * 2)

与之前是一样的

1.5 网络定义

class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = nn.Linear(13, 1)
        
    def forward(self, inputs):
        logits = self.linear(inputs)
        return logits

我们这里有13个特征变量

1.6 训练

model = LinearModel()
loss_fn = nn.MSELoss()
opt = torch.optim.SGD(model.parameters(), lr=lr) # 定义优化器

train_loss = []
train_acc = []

test_loss = []
test_acc = []


for epoch in range(epochs+1):
    model.train()
    for xb, yb in train_dl:
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()
    if epoch%10==0:
        model.eval()
        with torch.no_grad():
            train_epoch_loss = sum(loss_fn(model(xb), yb) for xb, yb in train_dl)
            test_epoch_loss = sum(loss_fn(model(xb), yb) for xb, yb in test_dl)
        train_loss.append(train_epoch_loss.data.item() / len(train_dl))
        test_loss.append(test_epoch_loss.data.item() / len(test_dl))

        template = ("epoch:{:2d}, 训练损失:{:.5f}, 验证损失:{:.5f}")
    
        print(template.format(epoch, train_epoch_loss.data.item() / len(train_dl), test_epoch_loss.data.item() / len(test_dl)))
print('训练完成')

epoch: 0, 训练损失:469.15608, 验证损失:440.95737
epoch:10, 训练损失:101.80890, 验证损失:109.48333
epoch:20, 训练损失:91.18239, 验证损失:100.17014
epoch:30, 训练损失:100.83169, 验证损失:97.70323
epoch:40, 训练损失:89.96843, 验证损失:97.37273
epoch:50, 训练损失:94.20027, 验证损失:96.82300

epoch:480, 训练损失:74.97700, 验证损失:81.29946
epoch:490, 训练损失:74.74702, 验证损失:80.76858
epoch:500, 训练损失:89.31947, 验证损失:83.06767
训练完成

1.7 结果展示

import matplotlib.pyplot as plt

plt.plot(range(len(train_loss)), train_loss, label='train_loss')
plt.plot(range(len(test_loss)), test_loss, label='test_loss')
plt.legend()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/155308.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机原理三_进程管理

目录儿四、进程管理4.1 什么是进程4.1.1 进程的结构4.1.2 进程的特征4.1.3 进程与线程4.1.4 线程的实现方式用户级线程内核支持线程组合线程的调度4.2 进程是怎么运行的4.2.1 进程状态4.2.2 进程控制4.2.2.1 原语的概念4.2.2.2 挂起与激活4.2.3 进程调度4.2.3.1 调度层次4.2.3.…

CSS入门一、初识

零、文章目录 文章地址 个人博客-CSDN地址:https://blog.csdn.net/liyou123456789个人博客-GiteePages:https://bluecusliyou.gitee.io/techlearn 代码仓库地址 Gitee:https://gitee.com/bluecusliyou/TechLearnGithub:https:…

【BP靶场portswigger-服务端9】服务端请求伪造SSRF漏洞-7个实验(全)

前言: 介绍: 博主:网络安全领域狂热爱好者(承诺在CSDN永久无偿分享文章)。 殊荣:CSDN网络安全领域优质创作者,2022年双十一业务安全保卫战-某厂第一名,某厂特邀数字业务安全研究员&…

Episode 02 对称密码基础

一、从文字密码到比特序列密码 1、使用对称密钥进行加密 为了使原来的明文无法被推测出来,就要尽可能地打乱密文,这样才能达到加密的目的。密文打乱的是比特序列,无论是文本,图片还是音乐,只要能够将数据转换比特序列…

MSF后渗透持续后门

持续后门 ○ 利用漏洞取得的meterpreter shell运行于内存中,重启失效 ○ 重复exploit漏洞可能造成服务崩溃 ○ 持久后门保证漏洞修复后仍可远程控制 Meterpreter后门 run metsvc -A #删除-r use exploit/multi/handler set PAYLOAD windows/metsvc_bind_tcp se…

[22]. 括号生成

[22]. 括号生成题目算法设计:回溯算法设计:空间换时间题目 传送门:https://leetcode.cn/problems/generate-parentheses/ 算法设计:回溯 括号问题可以分成俩类: 括号的合法性判断,主要是用栈括号的合法生…

【自然语言处理】Word2Vec 词向量模型详解 + Python代码实战

文章目录一、词向量引入二、词向量模型三、训练数据构建四、不同模型对比4.1 CBOW4.2 Skip-gram 模型4.3 CBOW 和 Skip-gram 对比五、词向量训练过程5.1 初始化词向量矩阵5.2 训练模型六、Python 代码实战6.1 Model6.2 DataSet6.3 Main6.4 运行输出一、词向量引入 先来考虑一个…

IDEA远程快速部署SpringBoot项目到Docker环境

1.LInux上先安装docker环境 https://blog.csdn.net/YXWik/article/details/128643662 2.配置Docker远程连接端口 1. vim /usr/lib/systemd/system/docker.service 2. 找到ExecStar 在后面添加 -H tcp://0.0.0.0:2375 3. 退出编辑界面:先按esc,然后"…

【JAVA程序设计】(C00100)基于Springboot+html的前后端分离停车场管理系统

基于Springboothtml的前后端分离停车场管理系统项目简介项目获取开发环境项目技术运行截图项目简介 基于SpringBoothtml的前后端分离的停车场管理系统,本系统分为二种角色:管理员和收银员。 1.登录:管理员可以通过系统分配的账号…

Android 系统框架结构

目录 1.应用层(System Apps): 2.应用框架层(Java API Framework): 3.系统运行库层(Native): 4.硬件抽象层(HAL): 5.Linux内核层(Linux Kernel): 大部分开发的同学是不太清楚Android的系统的…

解决企业微信启动报错:0x0000142无法打开

解决企业微信启动报错:0x0000142无法打开1.问题描述2.问题查找3.问题解决4.事后感悟系统:Win10 WXWork:4.0.20.6020 1.问题描述 不知道从啥时候开始,打开企业微信会报错(见下图),报错代码是&am…

【Redis】缓存穿透问题及其解决方案

【Redis】缓存穿透问题及其解决方案 文章目录【Redis】缓存穿透问题及其解决方案1. 缓存穿透概念及原因2. 解决方案2.1 缓存空对象2.1.1 缓存空对象的优缺点2.1.2 改进代码2.2 布隆过滤2.2.1 布隆过滤的优缺点1. 缓存穿透概念及原因 缓存穿透:客户端请求的数据在 缓…

HTML与CSS基础(十)—— 综合项目

应用前面技术知识 完成小兔鲜儿项目设计图素材下载:链接: https://pan.baidu.com/s/1o5mWkgEfaTAA5spxMLuXEQ?pwdex7e 提取码: ex7e 一、Header 部分开发 布局分析:header布局分析:xtx-shortcut ①布局分析:xtx-shortcut ②布局分…

Hudi系列3:Hudi核心概念

文章目录Hudi架构一. 时间轴(TimeLine)1.1 时间轴(TimeLine)概念1.2 Hudi的时间线由组成1.3 时间线上的Instant action操作类型1.4 时间线上State状态类型1.5 时间线官网实例二. 文件布局三. 索引3.1 简介3.2 对比Hive没有索引的区别3.3 Hudi索引类型3.4 全局索引与非全局索引四…

数学建模-回归分析(Stata)

注意:代码文件仅供参考,一定不要直接用于自己的数模论文中国赛对于论文的查重要求非常严格,代码雷同也算作抄袭 如何修改代码避免查重的方法:https://www.bilibili.com/video/av59423231 //清风数学建模 一、基础知识 1.简介 …

不得不面对的随机MAC问题

一、现状 为了完善安全机制、保护用户隐私,各个设备厂商开发了 MAC 地址随机功能,防止用户信息泄露。随机 MAC 地址,就是一个随机生成的伪 MAC 地址,一个假 MAC 地址,使用随机 MAC 地址进行网络通信,而不是…

全网圣诞树最全完整源码下载合集【可下载】

文章目录一、全部源码打包下载:二、效果预览001-html版本 豪华动态圣诞树 抖音同款002-圣诞树灯光跟随音乐节拍一起呼吸点亮下雪动画效果代码003-圣诞树彩带飘动节日快乐效果代码004-圣诞树带音乐旋转拉伸动画效果005-python版本python取消延迟秒出图版 【全网最强无…

当FutureTask遇上DiscardPolicy,有坑

文章目录有啥坑呢?知识回顾问题触发条件问题复现问题分析问题修复扩展哈喽,你好,我是余数。今天来了解下当 FutureTask 遇上 DiscardPolicy 或 DiscardOldestPolicy 时容易掉的坑,然后分析分析问题产生的原因以及如何规避这类问题…

LVS+Keepalived+Nginx具体配置步骤

视频链接:4-6 搭建LVS-DR模式- 为两台RS配置虚拟IP_哔哩哔哩_bilibili 视频笔记链接:笔记 一、服务器与Ip约定 LVS DIP: 192.168.1.151 VIP: 192.168.1.150 Nginx1 RIP: 192.168.1.171 VIP: 192.168.1.150 Nginx2 RIP: 192.168.1.172 VIP: 192.168…

力扣 2283. 判断一个数的数字计数是否等于数位的值

题目 给你一个下标从 0 开始长度为 n 的字符串 num &#xff0c;它只包含数字。 如果对于 每个 0 < i < n 的下标 i &#xff0c;都满足数位 i 在 num 中出现了 num[i]次&#xff0c;那么请你返回 true &#xff0c;否则返回 false 。 示例 输入&#xff1a;num “1…