PyTorch demo——基于MLP的鸢尾花分类

news2024/9/21 22:36:18

系统框架

在这里插入图片描述

1. 数据集加载

  继承torch.utils.data.Dataset类,重写__getitem__和__len__方法,并在__getitem__中预处理数据。

# load.py
import torch


class IrisDataset(torch.utils.data.Dataset):
    def __init__(self, data_file, iris_class):
        super(IrisDataset, self).__init__()

        self.iris_class = iris_class

        self.all_data = []
        with open(data_file, 'r') as f:
            lines = f.readlines()
            lines = [line.rstrip() for line in lines]
            for l in lines:
                l = l.split(',')
                vec = [float(i) for i in l[:-1]]
                label = self.iris_class[str(l[-1])]
                self.all_data.append([vec, label])


    def __getitem__(self, item):
        fea, label = self.all_data[item]
        fea, label = torch.tensor(fea, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
		# No data augmentation
		
        return fea, label


    def __len__(self):

        return len(self.all_data)


if __name__ == "__main__":
    import config
    dataset = IrisDataset("iris/train", config.iris_class)
    print(dataset.__getitem__(0))

2. 网络模型——MLP

在这里插入图片描述

# net.py
import torch
import torch.nn as nn


class Net(torch.nn.Module):
    def __init__(self, input_dim=4, num_class=3):
        super(Net, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Linear(64, num_class),
            nn.Softmax()
        )

    def forward(self, x):

        return self.fc(x)


if __name__ == "__main__":
    net = Net()
    print(net)

    x = torch.randn(2, 4)
    print(net(x).shape)

3. 配置文件——网络参数、训练参数整理

# config.py
import warnings
warnings.filterwarnings('ignore')

"""dataset"""
iris_class = {
    "Iris-setosa": 0,
    "Iris-versicolor": 1,
    "Iris-virginica": 2
}

"""net args"""
input_dim = 4
num_class = 3

"""train & valid"""
train_data = 'iris/train'
valid_data = 'iris/valid'
batch_size = 10
nworks = 1
max_epoch = 200
lr = 1e-3
factor = 0.9

""" test """
test_data = "iris/test"
pre_model = "pth/model_100.pth"

4. 训练

# train.py
import torch, os, tqdm
from torch.utils.data import DataLoader

import load, net, config
import matplotlib.pyplot as plt


def train():
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda:" + str(0))   
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")
    print("current deveice:", DEVICE)

    # load dataset for train and eval
    train_dataset = load.IrisDataset(config.train_data, config.iris_class)
    train_batchs = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.nworks, pin_memory=True)
    valid_dataset = load.IrisDataset(config.valid_data, config.iris_class)
    valid_batchs = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.nworks, pin_memory=True)

    model = net.Net(config.input_dim, config.num_class)
    model = model.to(DEVICE)

    loss_criterion = torch.nn.CrossEntropyLoss()    
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

    os.makedirs("pth", exist_ok=True)

    plt.ion()
    train_loss, valid_loss, valid_acc = [], [], []

    for epoch in tqdm.tqdm(range(1, config.max_epoch+1)):

        optimizer.param_groups[0]['lr'] = config.lr * ((1 - (epoch-1)/ config.max_epoch)**config.factor)

        """ train """
        model.train()
        total_loss=0
        for batch, (fea, target) in enumerate(train_batchs):
            fea, target = fea.to(DEVICE), torch.nn.functional.one_hot(target, 3).float().to(DEVICE)
            pred = model(fea)

            loss = loss_criterion(pred, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        train_epoch_loss = total_loss / len(train_dataset)*config.batch_size
        # print("epoch",epoch,"loss:", train_epoch_loss)
        torch.save(model.state_dict(), os.path.join("pth", 'model_' + str(epoch) + '.pth'))

        """ valid """
        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            total_loss = 0
            for fea, labels in valid_batchs:
                labels = labels.to(DEVICE)
                fea, target = fea.to(DEVICE), torch.nn.functional.one_hot(labels, 3).float().to(DEVICE)

                pred = model(fea)

                loss = loss_criterion(pred, target)
                total_loss += loss.item()

                _, predicted = torch.max(pred.data, dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum()

            valid_epoch_loss = total_loss / len(valid_dataset) * config.batch_size
            # print('Accuracy test set: %d%%' % (100 * (correct / total)))

        train_loss.append(train_epoch_loss)
        valid_loss.append(valid_epoch_loss)
        valid_acc.append(correct.cpu() / total)

        plt.clf()
        plt.plot(train_loss, color='black', label="train loss")
        plt.plot(valid_loss, color='red', label="valid loss")
        plt.plot(valid_acc, color='green', label="valid acc")
        plt.grid()
        plt.legend()
        plt.savefig("train.jpg")

    plt.ioff()
    plt.close()


if __name__ == '__main__':

    train()

训练过程可视化
在这里插入图片描述

5.测试

# test.py
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report

import net, load, config


def test():
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda:" + str(0))
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")
    print("current deveice:", DEVICE)

    # load dataset
    test_dataset = load.IrisDataset(config.test_data, config.iris_class)
    test_batchs = DataLoader(test_dataset, batch_size=10, shuffle=False,
                             num_workers=0, pin_memory=True)

    # model
    model = net.Net(config.input_dim, config.num_class)
    model.load_state_dict(torch.load(config.pre_model, map_location='cpu'), strict=False)
    model = model.to(DEVICE)

    # test
    model.eval()
    with torch.no_grad():
        preds, labels = [], []
        for i, (fea, label) in enumerate(test_batchs):

            pred = model(fea.to(DEVICE))
            _, predicted = torch.max(pred.data, dim=1)

            preds.append(predicted)
            labels.append(label)

    # report
    preds = torch.stack(preds, dim=0).view(-1).cpu().numpy()
    labels = torch.stack(labels, dim=0).view(-1).numpy()

    report = classification_report(labels, preds, target_names = config.iris_class.keys())
    print(report)


if __name__ == '__main__':

    test()

测试集结果
在这里插入图片描述

6.文件结构

在这里插入图片描述

6.1 requirements.txt

matplotlib==3.7.2
scikit_learn==1.3.2
torch==2.0.0+cu118
tqdm==4.65.2

6.2附已划分的数据集

训练集——iris/train
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
验证集——iris/valid
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
测试集——iris/test
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2134303.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么护眼台灯性价比高又好用?良心推荐五款性价比高的护眼台灯

在家里,灯具是属于离不开的家具,每个大大小小的地方都需要的照亮,所以一盏好灯是必不可少的,每个发挥着作用。而护眼台灯就起了一个保护眼睛的作用,可以保护我们在学习,阅读的时候提供一个合适的光线环境&a…

Elasticsearch知识点整理

数据分类 非结构化数据 全文数据。不定长或无固定格式 报错xml,HTML,Word结构化数据 行数据,由二维表结构来逻辑表达和实现的数据 非结构化数据 对于非结构化的数据 搜索主要有两种方法 顺序扫描全文检索 顺序扫描 一般不建议这么做。例如给你一张报纸&…

PHP一键寄送尽在掌中快递寄件小程序

一键寄送尽在掌中 —— 快递寄件小程序全体验 🌟 开篇:告别繁琐,拥抱便捷新纪元 还在为寄快递而烦恼吗?排队等待、填写繁琐的单据、等待快递员上门...这些统统成为过去式!“一键寄送尽在掌中快递寄件小程序”。它就像…

红光一字激光器在工业中的性能指标怎样

红光一字激光器作为现代工业中不可或缺的重要设备,以其独特的性能和广泛的应用场景,成为众多行业的首选工具。本文就跟大家详细探讨红光一字激光器在工业中的性能指标,以及这些指标如何影响其在实际应用中的表现。 光束质量 红光一字激光器以…

气膜体育馆:为学校打造智能化运动空间—轻空间

随着教育体制的逐步升级,学校在提升学生综合素质方面的需求日益增长,特别是在体育场地方面。气膜体育馆作为一种新型的运动空间形式,正在迅速成为学校体育设施的优选方案。凭借其快速搭建、节能环保等优势,气膜馆在全国各地的校园…

STM32 的 RTC(实时时钟)详解

目录 一、引言 二、RTC 概述 三、RTC 的工作原理 1.时钟源 2.计数器 3.闹钟功能 4.备份寄存器 四、RTC 寄存器 1.RTC_TR(Time Register,时间寄存器) 2.RTC_DR(Date Register,日期寄存器) 3.RTC_S…

R语言统计分析——功效分析(比例、卡方检验)

参考资料:R语言实战【第2版】 1、比例检验 当比较两个比例时,可使用pwr.2p.test()函数进行功效分析。格式为: pwr.2p.test(h, n, sig.level, power, alternative) 其中,h是效应值,n是各相同的样本量。效应值h的定义如…

性能测试-jmeter提取器(十三)

一、jmeter的常用关联 正则表达式提取器xpath提取器json提取器 二、正则表达式提取器 注&#xff1a;&#xff08;正则表达式的变量与引用的变量的区别&#xff1a;引用变量多加了"_1"后缀&#xff09; 需求&#xff1a;将www.itcast.cn网页时&#xff0c;<ti…

【JAVA开源】基于Vue和SpringBoot的校园管理系统

本文项目编号 T 026 &#xff0c;文末自助获取源码 \color{red}{T026&#xff0c;文末自助获取源码} T026&#xff0c;文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 管…

每日一练:游游的u

1.题目 2.代码 #include <iostream> using namespace std;int main() {int q 0;cin >> q;int a,b,c;while(q--){cin >> a >> b >> c;int you min(a,min(b,c)) * 2;int ooo max(b-(you/2)-1,0);cout << (you ooo) << endl;}retu…

【计算机毕设-大数据方向】基于Hadoop的社交媒体数据分析可视化系统的设计与实现

&#x1f497;博主介绍&#xff1a;✌全平台粉丝5W,高级大厂开发程序员&#x1f603;&#xff0c;博客之星、掘金/知乎/华为云/阿里云等平台优质作者。 【源码获取】关注并且私信我 【联系方式】&#x1f447;&#x1f447;&#x1f447;最下边&#x1f447;&#x1f447;&…

[ComfyUI]Flux:写真新篇章!字节PuLID率先开启一致性风格迁移,无损画手和优质画面保持

前言 Flux&#xff1a;PuLID率先开启F1写真新篇章 所有的AI设计工具&#xff0c;模型和插件&#xff0c;都已经整理好了&#xff0c;&#x1f447;获取~ Flux PuLID简介 在Flux出来后短时间内&#xff0c;社区生态反响和发展足够的迅猛快速。至今为止&#xff0c;社区LORA模…

力扣每日一题:236.二叉树的最近公共祖先

题目 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为&#xff1a;“对于有根树 T 的两个节点 p、q&#xff0c;最近公共祖先表示为一个节点 x&#xff0c;满足 x 是 p、q 的祖先且 x 的深度尽可能大&#xff08;一个节点也可以是它…

<<编码>> 第 11 章 逻辑门电路(Gates)--猫咪选择电路 示例电路

使用门电路的猫咪选择电路 info::操作说明 鼠标单击开关切换开合状态 primary::在线交互操作链接 https://cc.xiaogd.net/?startCircuitLinkhttps://book.xiaogd.net/code-hlchs-examples/assets/circuit/code-hlchs-ch11-16-cat-circuit-with-gate.txt 集成的猫咪选择电路 in…

html+css+js网页设计 旅游 厦门旅游网11个页面

htmlcssjs网页设计 旅游 厦门旅游网11个页面 网页作品代码简单&#xff0c;可使用任意HTML辑软件&#xff08;如&#xff1a;Dreamweaver、HBuilder、Vscode 、Sublime 、Webstorm、Text 、Notepad 等任意html编辑软件进行运行及修改编辑等操作&#xff09;。 获取源码 1&am…

WPF 手撸插件 八 依赖注入

本文内容大量参考了&#xff1a;https://www.cnblogs.com/Chary/p/11351457.html 而且这篇文章总结的非常好。 1、注意想使用Autofac&#xff0c;Autofac是一个轻量级、‌高性能的依赖注入&#xff08;‌DI&#xff09;‌框架&#xff0c;‌主要用于.NET应用程序的组件解耦和…

Halcon 深度学习 分类预处理

文章目录 read_dl_dataset_classification 产生一个深度学习数据集算子split_dl_dataset 将样本分为训练、验证和测试子集create_dl_preprocess_param 使用预处理参数创建字典preprocess_dl_dataset 预处理DLDataset中声明的整个数据集write_dict 写入字典文件find_dl_samples …

[网络]TCP/IP五层协议之应用层,传输层(1)

文章目录 一. 应用层二. 传输层端口号传输层的协议UDPTCPTCP报头TCP协议的核心机制 一. 应用层 应用层是和应用程序直接相关, 和程序猿打交道最多的一层 应用层协议, 里面描述的内容, 就是你写的程序, 通过网络具体按照啥样的形式来传输数据 不同的应用程序, 就可以用不同的应…

PHP 使用Spreadsheet写excel缓存导致内存不断增加

这里写自定义目录标题 问题描述问题解决 问题描述 新增了 Spreadsheet 用于写 excle 文件。 从网上查找一些实例后&#xff0c;封装成 createExcelFormData 函数如下&#xff1a; /*** brief 按照指定的键&#xff0c;将 array2(关联数组) 合并到 array1(关…

【C#】VS插件

翻译 目前推荐较多的 可以单词发言&#xff0c;目前还在开发阶段 TranslateIntoChinese - Visual Studio Marketplace 下载量最高的(推荐) Visual-Studio-Translator - Visual Studio Marketplace 支持翻译的版本较多&#xff0c;在 Visual Studio 代码编辑器中通过 Googl…