深度学习PyTorch 之 DNN-二分类

news2025/1/16 1:47:33

本节开始说一下DNN分类的pytorch实现,先说一下二分类

流程还是跟前面一样

数据导入
数据拆分
Tensor转换
数据重构
模型定义
模型训练
结果展示

代码

1 数据导入

我们使用最常见的iris数据集

data = pd.read_csv('./iris.csv')
data.columns = ["f1","f2","f3","f4","label"]

data = data.head(99)
data

在这里插入图片描述

因为iris鸢尾花数据集是一个三分类的数据,我们只去前99条数据,这样的话就只有两个分类了。

2.数据拆分

from sklearn.model_selection import train_test_split
train,test = train_test_split(data, train_size=0.7)

train_x = train[[c for c in data.columns if c != 'label']].values
test_x = test[[c for c in data.columns if c != 'label']].values

train_y = train.label.values.reshape(-1, 1)
test_y = test.label.values.reshape(-1, 1)

3.To Tensor

train_x = torch.from_numpy(train_x).type(torch.FloatTensor)
test_x = torch.from_numpy(test_x).type(torch.FloatTensor)
train_y = torch.from_numpy(train_y).type(torch.FloatTensor)
test_y = torch.from_numpy(test_y).type(torch.FloatTensor)

train_x.shape, train_y.shape
#(torch.Size([69, 4]), torch.Size([69, 1]))

4.数据重构

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

train_ds = TensorDataset(train_x, train_y)
train_dl = DataLoader(train_ds, batch_size=batch, shuffle=True)

test_ds = TensorDataset(test_x, test_y)
test_dl = DataLoader(test_ds, batch_size=batch * 2)

5.网络定义

from torch import nn
import torch.nn.functional as F

class DNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(4, 64)
        self.hidden2 = nn.Linear(64, 64)
        self.hidden3 = nn.Linear(64, 1)
    def forward(self, input):
        x = F.relu(self.hidden1(input))
        x = F.relu(self.hidden2(x))
        x = torch.sigmoid(self.hidden3(x))
        return x
#二分类准确率计算函数
def accuracy(out, yb):
    preds = (out>0.5).type(torch.IntTensor)
    return (preds == yb).float().mean()

def get_model():
    model = DNN()
    return model, torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.BCELoss()
model, opt = get_model()
model#查看网络结构

DNN(
(hidden1): Linear(in_features=4, out_features=64, bias=True)
(hidden2): Linear(in_features=64, out_features=64, bias=True)
(hidden3): Linear(in_features=64, out_features=1, bias=True)
)
我们也可以根据上节课内容可视化一下
在这里插入图片描述

6. 训练

train_loss = []
train_acc = []

test_loss = []
test_acc = []


for epoch in range(epochs+1):
    model.train()
    for xb, yb in train_dl:
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()
    if epoch%1==0:
        model.eval()
        with torch.no_grad():
            train_epoch_loss = sum(loss_fn(model(xb), yb) for xb, yb in train_dl)
            test_epoch_loss = sum(loss_fn(model(xb), yb) for xb, yb in test_dl)
            acc_mean_train = np.mean([accuracy(model(xb), yb) for xb, yb in train_dl])
            acc_mean_val = np.mean([accuracy(model(xb), yb) for xb, yb in test_dl])
        train_loss.append(train_epoch_loss.data.item() / len(test_dl))
        test_loss.append(test_epoch_loss.data.item() / len(test_dl))
        train_acc.append(acc_mean_train)
        test_acc.append(acc_mean_val)
        template = ("epoch:{:2d}, 训练损失:{:.5f}, 训练准确率:{:.1f},验证损失:{:.5f}, 验证准确率:{:.1f}")
    
        print(template.format(epoch, train_epoch_loss.data.item() / len(test_dl), acc_mean_train*100, test_epoch_loss.data.item() / len(test_dl), acc_mean_val*100))
print('训练完成')

epoch: 0, 训练损失:3.09122, 训练准确率:57.0,验证损失:0.68206, 验证准确率:36.7
epoch: 1, 训练损失:2.87476, 训练准确率:54.3,验证损失:0.69797, 验证准确率:36.7
epoch: 2, 训练损失:2.62978, 训练准确率:61.0,验证损失:0.59363, 验证准确率:36.7
epoch: 3, 训练损失:2.30378, 训练准确率:100.0,验证损失:0.50508, 验证准确率:100.0
epoch: 4, 训练损失:2.05582, 训练准确率:100.0,验证损失:0.44803, 验证准确率:100.0
epoch: 5, 训练损失:1.76421, 训练准确率:100.0,验证损失:0.38924, 验证准确率:100.0
epoch: 6, 训练损失:1.54745, 训练准确率:100.0,验证损失:0.32642, 验证准确率:100.0

epoch:98, 训练损失:0.00304, 训练准确率:100.0,验证损失:0.00067, 验证准确率:100.0
epoch:99, 训练损失:0.00311, 训练准确率:100.0,验证损失:0.00067, 验证准确率:100.0
epoch:100, 训练损失:0.00300, 训练准确率:100.0,验证损失:0.00068, 验证准确率:100.0
训练完成

7.查看结果

import matplotlib.pyplot as plt
#损失值
plt.plot(range(len(train_loss)), train_loss, label='train_loss')
plt.plot(range(len(test_loss)), test_loss, label='test_loss')
plt.legend()

在这里插入图片描述

# 准确率
plt.plot(range(len(train_acc)), train_acc, label='train_acc')
plt.plot(range(len(test_acc)), test_acc, label='test_acc')
plt.legend()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/162421.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从校园到职场,听听他们的成长之路

背景介绍 这次分享主题是「从校园到职场 -- 我的成长之路」,视频内容可以查看 B 站链接:从校园到实习再到秋招。 上次的面试分享之后,阿卡拉提到关于刚毕业的学生也会有很多找工作的困扰,而且这个阶段能获取到的信息相对比较有限&…

Java API文档的使用详解

文章目录1. 概念2. 使用Java编程基础教程系列学会使用 API 文档是一个开发者基本的素养,而许多初学者并不会在意 API 文档的使用,甚至从来没有接触过,所以写下这篇文章探讨 API 文档的使用,希望能够帮助到你,先赞后看&…

正点原子嵌入式linux第二期

目录 第5讲 IMX6U芯片介绍 第6讲 6.1汇编LED驱动实验-原理分析 6.2 汇编LED驱动实验-汇编基本语法 ​编辑6.3 驱动编写 6.4 编写驱动 6.5烧写bin文件到SD卡并运行 第七讲 IMX启动方式(没怎么听懂) 7.1启动设备的选择 7.2 IVT表和BootData详解 7.3D…

从面试官的角度带你从源码分析关于vue(v2.7.10)的面试题

我们在面试的时候经常会被问到vue框架的原理类问题,我今天整理了一些常见问题和答案,希望有不正确之处还请指正。 1.new Vue时发生了什么 首先实例化一个对象,该对象执行init方法初始化生命周期等等,随后执行$mount方法开始生成v…

时间序列模型SCINet(代码解析)

前言 SCINet模型,精度仅次于NLinear的时间序列模型,在ETTh2数据集上单变量预测结果甚至比NLinear模型还要好。在这里还是建议大家去读一读论文,论文写的很规范,很值得学习,论文地址SCINet模型Github项目地址&#xff…

SpringBoot文件上传功能实现、异常处理

目录 一、文件上传 1、页面表单 2、文件上传代码 3、自动配置原理 二、异常处理 错误处理 1、默认规则 2、定制错误处理逻辑 3、异常处理自动配置原理 4、异常处理步骤流程 一、文件上传 1、页面表单 <form method"post" action"/upload" e…

详细实例说明+典型案例实现 对递归法进行全面分析 | C++

第二章 递归法 目录 ●第二章 递归法 ●前言 ●一、递归法是什么&#xff1f; 1.简要介绍 2.生活实例 ●二、递归法的典型案例——阶乘函数&斐波那契数列 1.阶乘函数 2.斐波那契数列 ●总结 前言 简单的来说&#xff0c;算法就是用计算机程序代码来实现数学…

static关键字的作用

目录 C语言中static关键字的作用 1.static关键字修饰局部变量 2.static关键字修饰全局变量 3.static关键字修饰函数 在C中static关键的作用 1.静态成员变量 2.静态成员函数 C语言中static关键字的作用 1.static关键字修饰局部变量 概念&#xff1a; static修饰局部变量就…

这才是2023年csdn最系统的网络安全学习路线(建议收藏)

01 什么是网络安全 网络安全可以基于攻击和防御视角来分类&#xff0c;我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术&#xff0c;而“蓝队”、“安全运营”、“安全运维”则研究防御技术。 无论网络、Web、移动、桌面、云等哪个领域&#xff0c;都有攻与防两面…

2. 认识O(logN)的排序

1. 递归 递归arr[L…R]范围上求最大值 流程分析如下: java代码: package paixu.class01;public class Code08_GetMax {public static void main(String[] args) {int[] arr {3,2,5,6,7,4};System.out.println(getMax(arr));}public static int getMax(int[] arr) {return p…

浅谈非类型模板参数、模板的特化

非类型模板参数 1.模板参数分类类型形参与非类型形参。 2.类型形参即&#xff1a;出现在模板参数列表中&#xff0c;跟在class或者typename之类的参数类型名称。类型参数也可以给缺省值 3.非类型形参&#xff0c;就是用一个常量作为类(函数)模板的一个参数&#xff0c;在类(函…

FARO RevEng Software 22.3.9 Crack

FARO RevEng Software 22.3.9 用于反向工程的三维点云捕捉和网格生成 先进的 FARO RevEng 软件平台能为用户带来全面的数字设计体验。该反向工程软件有助于利用三维点云创建和编辑高质量的网格和 CAD 表面&#xff0c;以实现反向工程工作流程。然后&#xff0c;工业设计师可以利…

Zookeeper 集群安装

Zookeeper 集群 主机 IP SoftWare Port OS Myidnode1 192.168.230.128 apache-zookeeper-3.7.1 2181 Centos 7 1 node2 192.168.230.129 apache-zookeeper-3.7.1 2181 Centos 7 2 node3 192.168.230.130 apache-zookeeper-3.7.1 2181 Centos 7 31. 下载 Apache Downloads 下…

2022简单一年

牙齿软件决定开发的时间是2021年底&#xff0c; 老板说2022年任务是要开发牙齿咬合力的软件&#xff0c; 功能主要借鉴美国的一款软件,老板给了我一份软件的说明书&#xff0c; 包含了软件的所有功能。 看到软件第一感觉是&#xff0c; 做的确实是牛逼&#xff0c; 并且各…

【十 二】Netty 文件传输

概念介绍 文件是最常见的数据源之一&#xff0c;程序经常需要在文件中读取数据&#xff0c;也要将数据保存在文件中&#xff0c;进行持久化。 文件是计算机中一种基本的数据存储形式。即使计算机关机&#xff0c;文件的数据还是存在的&#xff0c;但是内存的数据就会丢失。 相…

联合证券|五定增项目同日被否 保荐机构该不该“背锅”?

一天之内5家上市公司定增一起被拒&#xff0c;这一音讯瞬间引发商场重视。 1月11日&#xff0c;浙江世宝、铭普光磁、胜华新材、日辰股份、振华科技等5家上市公司一起公告称&#xff0c;定增不被证监会受理&#xff0c;理由均是证监会以为请求资料不符合法定方式。 投行业界人…

18.Isaac教程--坐标系

坐标系 本节介绍相机、网格/矩阵/图像和机器人坐标系。 ISAAC教程合集地址: https://blog.csdn.net/kunhe0512/category_12163211.html 文章目录坐标系网格像素中心网格/矩阵/图像坐标系相机坐标系机器人坐标系网格像素中心 存储网格 GGG 上的值&#xff0c;使得网格单元将值…

Crack:CAD Exchanger GUI/CAD Exchanger Lab 不是SDK

CAD Exchanger GUI/CAD Exchanger Lab 用于查看、转换和分析 CAD、BIM 和 3D 数据 在 Windows、Mac 和 Linux 上加载和转换模型&#xff0c;而无需处理昂贵的 CAD 系统。 使用 CATIA、SOLIDWORKS、Creo、STEP、JT、IFC 和更多格式。 非常适合您的 3D 数据工作流程 连接不同的软…

多频电磁仪在2018年杭州电磁大会的报告(ICEEG)

本篇是对多频电磁方法,应用的解读。 本汇报讲述了EMI传感器的基本情况,以及用手持多频电磁仪进行实际探测应用的例子。 什么是电磁感应?用发射装置(TX)激发谐波,产生一次场(Primary field),地下导体目标会相应产生涡流电磁场,产生二次场,被接收装置(RX)探测到。 …

超参数、划分数据集、偏差与方差、正则化

目录1.超参数(hyperparameters)参数(Parameters)&#xff1a;&#xff0c;&#xff0c;&#xff0c;&#xff0c;&#xff0c;...超参数&#xff1a;能够控制参数W,b的参数&#xff0c;是在开始学习之前设置的参数。比如&#xff1a;学习率、梯度下降循环的数量#iterations、隐…