Pytorch_CPU鸢尾花lirsDataset 尝试

news2024/9/29 7:30:42

鸢尾花数据集(lris Dataset)

(1)下载地址【引用】:鸢尾花数据集下载

(2)鸢尾花数据集特点

茑尾花数据集有150 条样本记录,分为3个类别,每个类别有 50 个样本,每条记录有 4个特征(花萼长度、花萼宽度、花瓣长度、花瓣宽度)。

鸢尾花数据集(lris Dataset)数据加载

dataLoader.py

#导入相关模块
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset #用于创建可用于DataLoader的自定义数据集类

# 类IrisDataset 继承自 torch.utils.data.Dataset
# 表示这是一个pytorch数据集类
class IrisDataset(Dataset):
    def __init__(self, data_path):
        # 保存数据集的路径
        self.data_path = data_path
        # 使用assert 来检查data_path是否存在,如果路径无效,则抛出错误
        assert os.path.exists(data_path), "Dataset does not exist."

        # 使用pandas 读取CSV数据,其中将列名指定[0, 1, 2, 3, 4]
        df = pd.read_csv(self.data_path, names=[0, 1, 2, 3, 4])
        # 通过map将鸢尾花的三个类别(setosa、versicolor、virginica)映射为数字标签0、1、2
        label_mapping = {"setosa": 0, "versicolor": 1, "virginica": 2}
        df[4] = df[4].map(label_mapping)

        # 提取特征列(前4列)
        features = df.iloc[:, :4]
        # 提取类别标签(最后一列)
        labels = df.iloc[:, 4:]

        # Standardization (Z-score normalization)
        # 归一化(Z值化)处理,对数据进行标准化
        features = (features - np.mean(features) / np.std(features))

        # Convert data to tensors
        # 将数据转化为PyTorch张量
        self.features = torch.from_numpy(np.array(features, dtype="float32"))
        self.labels = torch.from_numpy(np.array(labels, dtype="int"))

        # Dataset size
        # 保存数据集的样本数量
        self.dataset_size = len(labels)
        print(f"Dataset size: {self.dataset_size}")

    # 获取数据集的长度
    def __len__(self):
        return self.dataset_size
    # 获取样本
    def __getitem__(self, index):
        return self.features[index], self.labels[index] # 返回指定索引的样本

使用PyTorch实现训练和评估流程

nn.py

import os.path
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# 从 dataLoader 文件中导入自定义的 iris_dataloader类,用于加载鸢尾花数据集。
from dataLoader import iris_dataloader


# 定义一个名为 NeuralNetwork 的类继承自nn.Module,构建一个三层的全连接神经网络。
class NeuralNetwork(nn.Module):
    # 初始化函数,定义了1个输入层、2个隐藏层、1个输出层
    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim):
        super().__init__()

        self.layer1 = nn.Linear(input_dim, hidden_dim1)
        self.layer2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.layer3 = nn.Linear(hidden_dim2, output_dim)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x


# 设置计算设备,如有GPU则使用CUDA,否则使用CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print(device)


# 加载和划分数据集
# 加载鸢尾花数据集,并划分为训练集、验证集和测试集
iris_dataset = iris_dataloader("./iris.txt")
train_size = int(len(iris_dataset) * 0.7)
val_size = int(len(iris_dataset) * 0.2)
test_size = len(iris_dataset) - train_size - val_size

# random_split 根据比例 70%/20%/10% 将数据集划分为训练集、验证集和测试集。
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    iris_dataset, [train_size, val_size, test_size]
)
# DataLoader 用于加载数据,并设置批次大小和是否随机打乱数据。
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

print(f"训练集大小: {len(train_loader) * 16}, 验证集大小: {len(val_loader)}, 测试集大小: {len(test_loader)}")


# 推断函数,计算模型在数据集上的准确率
def evaluate(model, data_loader, device):
    model.eval()
    correct_predictions = 0
    # torch.no_grad() 禁用梯度计算,从而加快推断速度并节省显存
    with torch.no_grad():
        for data in data_loader:
            inputs, labels = data
            outputs = model(inputs.to(device))
            if outputs.dim() > 1:
                # 提取模型输出中概率最大的类别
                predicted_labels = torch.max(outputs, dim=1)[1]
            else:
                predicted_labels = torch.max(outputs, dim=0)[1]
            correct_predictions += torch.eq(predicted_labels, labels.to(device)).sum().item()
    accuracy = correct_predictions / len(data_loader)
    return accuracy


# 主函数,执行模型训练与验证
def main(lr=0.005, epochs=20):
    model = NeuralNetwork(4, 12, 6, 3).to(device)
    loss_function = nn.CrossEntropyLoss()

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=lr)
    save_path = os.path.join(os.getcwd(), "results/weights")
    if not os.path.exists(save_path):
        os.makedirs(save_path)

# 模型训练:每个 epoch 开始时模型设置为训练模式(model.train()),计算损失并通过反向传播优化权重。
    for epoch in range(epochs):
        model.train()
        correct_predictions = torch.zeros(1).to(device)
        total_samples = 0

        train_progress = tqdm(train_loader, file=sys.stdout, ncols=100)
        for batch in train_progress:
            inputs, labels = batch
            labels = labels.squeeze(-1)
            total_samples = inputs.shape[0]

            optimizer.zero_grad()
            outputs = model(inputs.to(device))
            predicted_labels = torch.max(outputs, dim=1)[1]
            correct_predictions += torch.eq(predicted_labels, labels.to(device)).sum()

            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            
            train_accuracy = correct_predictions / total_samples
            train_progress.set_description(f"train epoch[{epoch + 1}/{epochs}] loss:{loss:.3f}")
        # 验证集评估:在每个 epoch 结束时,通过 evaluate 函数计算验证集的准确率。
        val_accuracy = evaluate(model, val_loader, device)
        print(f"train epoch[{epoch + 1}/{epochs}] loss:{loss:.3f} train_acc:{train_accuracy:.3f} val_acc:{val_accuracy:.3f}")
        # 保存模型权重:每次迭代结束后保存模型参数。
        torch.save(model.state_dict(), os.path.join(save_path, "nn.pth"))

    print("Training completed.")
    test_accuracy = evaluate(model, test_loader, device)
    print(f"test_acc: {test_accuracy}")


if __name__ == "__main__":
    main(lr=0.005, epochs=20)

 运行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2137089.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

学习笔记JVM篇(一)

1、类加载的过程 加载->验证->准备->解析->初始化->使用->卸载 2、JVM内存组成部分(HotSpot) 名称作用特点元空间(JDK8之前在方法区)用于存储类的元数信息,例如名称、方法名、字段等;…

【程序分享】express 程序:可扩展的高级工作流程,用于更快速的从头算材料建模

分享一个 express 程序:可扩展的高级工作流程,用于更快速的从头算材料建模。 感谢论文的原作者! 主要内容 “在这项工作中,我们介绍了一个开源的Julia项目express,这是一个可扩展的、轻量级的、高通量的高级工作流框…

学python要下什么包吗,有推荐的教程或者视频吗?

初学者可以尝试三种方法来学习Python第三方库,第一种传统,第二种省心,第三种轻量。 1、安装PythonPycharm,通过pip进行包管理,或者Pycharm后台也可以 2、安装Anaconda,预装了几百个数据科学包&#xff0c…

模仿抖音用户ID加密ID的算法MB4E,提高自己平台ID安全性

先看抖音的格式 对ID加密的格式 MB4EENgLILJPeQKhJht-rjcc6y0ECMk_RGTceg6JBAA 需求是 同一个ID 比如 413884936367560 每次获取得到的加密ID都是不同的,最终解密的ID都是413884936367560 注意这是一个加密后可解密原文的方式,不是单向加密 那么如下进行…

Windows 环境下 vscode 配置 C/C++ 环境

vscode Visual Studio Code(简称 VSCode)是一个由微软开发的免费、开源的代码编辑器。它支持多种编程语言,并提供了代码高亮、智能代码补全、代码重构、调试等功能,非常适合开发者使用。VSCode 通过安装扩展(Extension…

abVIEW 可以同时支持脚本编程和图形编程

LabVIEW 可以同时支持脚本编程和图形编程,但主要依赖其独特的 图形编程 环境(G语言),其中程序通过连线与节点来表示数据流和功能模块。不过,LabVIEW 也支持通过以下方式实现脚本编程的能力: 1. 调用外部脚本…

第4步CentOS配置SSH服务用SSH终端XShell等连接方便文件上传或其它操作

宿主机的VM安装CENTOS文件无法快速上传,也不方便输入命令行,用SSH终端xshell连接虚拟机的SSH工具就方便多了,实现VM所在宿主机Win10上的xshell能连接vm的centos要实现以下几个环节 1、确保宿主机与虚拟机的连通性。 2、虚拟机安装SSH服务&…

ESP8266_MicroPython——GPIO_LED_KEY_外部中断

MicroPython 文章目录 MicroPython前言一、安装软件二、点亮第一颗LED灯三、KEY按键四、外部中断总结 前言 MicroPython比较简单但是没有系统的更新过文章,准备写一下ESP8266——MicroPython的文章做一个系列。 一、安装软件 安装开发软件 Thonny,安装…

豆包MarsCode编程助手:产品功能解析与应用场景探索!

随着现代技术的不断进化升级,人工智能正在逐步改变着我们的日常工作方式。特别是对于复杂的项目,代码编写、优化、调试、测试等环节充满挑战。为了简化这些环节、提高开发效率,许多智能编程工具应运而生,豆包MarsCode 编程助手就是…

瑞芯微Android6 内核编译报错解决方案

1、报错内容如下图所示 错误内容: Kernel: arch/arm/boot/zImage is ready make: *** [kernel.img] Error 127 2、分析与解决方法 由于之前在ubuntu环境下编译没问题,现在是在centos环境下重新编译的时候报错,所以经过分析对比两个环境的…

非关系型数据库Redis

文章目录 一,关系型数据库和非关系型数据可区别1.关系型数据库2.非关系型数据库3.区别3.1存储方式3.2扩展方式3.2事务性的支持 二,非关系型数据为什么产生三,Redis1.Redis是什么2.Redis优点3.Redis适用范围4. Redis 快的原因4.1 基于内存运行…

1-4微信小程序基础

模板配置 🌮🌮目标 1.能够使用WXML模板语法渲染页面结构2.能够使用WXSS样式渲染标签样式3.能够使用app.json对小程序进行全局配置4.能够使用page.json对小程序页面进行个性化配置5.如何发起网络数据请求 数据绑定的基本原则 在data中定义数据在WXML中…

(论文解读)Visual-Language Prompt Tuning with Knowledge-guided Context Optimization

Comment: accepted by CVPR2023 基于知识引导上下文优化的视觉语言提示学习 摘要 提示调优是利用任务相关的可学习标记将预训练的视觉语言模型(VLM)适应下游任务的有效方法。基于CoOp的代表性的工作将可学习的文本token与类别token相结合,…

Linux环境使用Git同步教程

📖 前言:由于CentOS 7已于2024年06月30日停止维护,为了避免操作系统停止维护带来的影响,我们将把系统更换为Ubuntu并迁移数据,在此之前简要的学习Git的上传下载操作。 目录 🕒 1. 连接🕘 1.1 配…

Effective C++笔记之二十二:C++临时变量的析构

先来看段代码 #include <iostream> #include <string>std::string myBlog() {return "https://blog.csdn.net/caoshangpa"; }int main() {const char *p myBlog().c_str();std::cout << p << std::endl;return 0; } 预期输出&#xff1a;…

Netty笔记06-组件ByteBuf

文章目录 概述ByteBuf 的特点ByteBuf的组成ByteBuf 的生命周期 ByteBuf 相关api1. ByteBuf 的创建2. 直接内存 vs 堆内存3. 池化 vs 非池化4. ByteBuf写入代码示例 5. ByteBuffer扩容6. ByteBuf 读取7. retain() & release()TailContext 释放未处理消息逻辑HeadContext 8. …

【新片场-注册安全分析报告-无验证方式导致安全隐患】

前言 由于网站注册入口容易被黑客攻击&#xff0c;存在如下安全问题&#xff1a; 1. 暴力破解密码&#xff0c;造成用户信息泄露 2. 短信盗刷的安全问题&#xff0c;影响业务及导致用户投诉 3. 带来经济损失&#xff0c;尤其是后付费客户&#xff0c;风险巨大&#xff0c;造…

面试爱考 | 设计模式

一、概述二、创建型 1. 单例&#xff08;Singleton&#xff09; IntentClass DiagramImplementationExamplesJDK 2. 简单工厂&#xff08;Simple Factory&#xff09; IntentClass DiagramImplementation 3. 工厂方法&#xff08;Factory Method&#xff09; IntentClass Diagr…

饿了么基于Flink+Paimon+StarRocks的实时湖仓探索

摘要&#xff1a;本文整理自饿了么大数据架构师、Apache Flink Contributor 王沛斌老师在8月3日 Streaming Lakehouse Meetup Online&#xff08;Paimon x StarRocks&#xff0c;共话实时湖仓架构&#xff09;上的分享。主要分为以下三个内容&#xff1a; 饿了么实时数仓演进之…