鸢尾花数据集分类(PyTorch实现)

news2025/1/6 20:28:23

一、数据集介绍

在这里插入图片描述
Data Set Information:
This is perhaps the best known database to be found in the pattern recognition literature. Fisher’s paper is a classic in the field and is referenced frequently to this day. (See Duda & Hart, for example.) The data set contains 3 classes of 50 instances each, where each class refers to a type of iris plant. One class is linearly separable from the other 2; the latter are NOT linearly separable from each other.

Attribute Information:

  1. sepal length in cm
  2. sepal width in cm
  3. petal length in cm
  4. petal width in cm
  5. class:
    – Iris Setosa
    – Iris Versicolour
    – Iris Virginica

二、使用贝叶斯分类

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import datasets

# 读取数据
iris = datasets.load_iris()
iris_data = iris.data
iris_target = iris.target

# 划分训练集和测试集
data_size = iris_data.shape[0]
train_data = iris_data[: int(data_size * 0.8)]
train_target = iris_target[: int(data_size * 0.8)]
test_data = iris_data[int(data_size * 0.8):]
test_target = iris_target[int(data_size * 0.8):]


# 定义模型
class BayesianModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(BayesianModel, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):
        x = self.fc(x)
        return x


# 实例化模型
model = BayesianModel(input_size=4, output_size=3)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.03)

# 训练模型
for epoch in range(10000):
    optimizer.zero_grad()
    outputs = model(torch.tensor(train_data, dtype=torch.float32))
    loss = criterion(outputs, torch.tensor(train_target, dtype=torch.long))
    loss.backward()
    optimizer.step()
    if epoch % 1000 == 0:
        print("Epoch: %d, Loss: %.4f" % (epoch, loss.item()))

# 评估模型
with torch.no_grad():
    outputs = model(torch.tensor(test_data, dtype=torch.float32))
    _, predicted = torch.max(outputs, 1)
    accuracy = (predicted == torch.tensor(test_target, dtype=torch.long)).sum().item() / len(test_target)
    print("Accuracy: %.2f %%" % (accuracy * 100))

在这里插入图片描述

三、使用支持向量机分类

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import datasets
from sklearn.model_selection import train_test_split

# 加载数据集
iris = datasets.load_iris()
X = iris["data"]
y = iris["target"]

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# 转换为PyTorch tensor
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)


# 定义SVM模型
class SVM(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.linear(x)


# 创建模型实例
input_dim = X_train.shape[1]
num_classes = 3
model = SVM(input_dim, num_classes)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
num_epochs = 1000
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}")

# 评估模型
with torch.no_grad():
    outputs = model(X_test)
    _, pred = torch.max(outputs, 1)
    correct = (pred == y_test).sum().item()
    accuracy = correct / y_test.shape[0]
    print("Accuracy: %.2f %%" % (accuracy * 100))

在这里插入图片描述

四、使用神经网络分类

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn import datasets
import numpy as np

# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris["data"].astype(np.float32)
y = iris["target"].astype(np.int64)

# 将数据分为训练集和测试集
train_ratio = 0.8
index = np.random.permutation(X.shape[0])
train_index = index[:int(X.shape[0] * train_ratio)]
test_index = index[int(X.shape[0] * train_ratio):]
X_train, y_train = X[train_index], y[train_index]
X_test, y_test = X[test_index], y[test_index]

# 定义神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, 3)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 初始化模型、损失函数和优化器
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# 训练模型
num_epochs = 100
for epoch in range(num_epochs):
    inputs = torch.from_numpy(X_train)
    labels = torch.from_numpy(y_train)
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

# 评估模型
with torch.no_grad():
    inputs = torch.from_numpy(X_test)
    labels = torch.from_numpy(y_test)
    outputs = model(inputs)
    _, predictions = torch.max(outputs, 1)
    accuracy = (predictions == labels).float().mean()
    print("Accuracy: %.2f %%" % (accuracy.item() * 100))

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/338808.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[Android Studio]Android 数据存储-文件存储学习笔记-结合保存QQ账户与密码存储到指定文件中的演练

🟧🟨🟩🟦🟪 Android Debug🟧🟨🟩🟦🟪 Topic 发布安卓学习过程中遇到问题解决过程,希望我的解决方案可以对小伙伴们有帮助。 📋笔记目…

戴尔游匣G16电脑U盘安装系统操作教程分享

戴尔游匣G16电脑U盘安装系统操作教程分享。有用户在使用戴尔游匣G16电脑的时候遇到了系统问题,比如电脑蓝屏、自动关机重启、驱动不兼容等问题。遇到这些问题如果无法进行彻底解决,我们可以通过U盘重新安装系统的方法来解决,因为这些问题一般…

I.MX6ULL内核开发7:led字符设备驱动实验

目录 一、led字符设备驱动实验 二、驱动模块初始化 三、虚拟地址读写 四、自定义led的file_operation接口 五、拷贝数据 六、register_chrdev函数 七、 __register_chrdev函数 八、编译执行 一、led字符设备驱动实验 驱动模块内核模块(.ko)驱动接口(file_operations) …

Mysql 增删改查(一) —— 查询(条件查询where、分页limits、排序order by)

查询 select 可以认为是四个基本操作中使用最为频繁的操作,然而数据量比较大的时候,我们不可能查询所有内容,我们一般会搭配其他语句进行查询: 假如要查询某一个字段的内容,可以使用 where假如要查询前几条记录&#…

STM32----搭建Arduino开发环境

搭建Arduino开发环境前言一、Arduino软件1.软件下载2.软件安装3.软件操作二、Cortex官方内核三、烧录下载四、其他第三方内核1.Libmaple内核2.Steve改进的LibMaple 内核3.STMicroelectronics(ST)公司编写的内核总结前言 本章介绍搭建STM32搭建Arduino开发环境,包括…

leetcode470 用Rand7()实现Rand10()

力扣470 第一步:根据Rand7()函数制作一个可以随机等概率生成0和1的函数rand_0and1 调用Rand7()函数,随机等概率生成1,2,3,4,5,6,7 这时我们设置:生成1,2&a…

“深度学习”学习日记。卷积神经网络--用CNN的实现MINIST识别任务

2023.2.11 通过已经实现的卷积层和池化层,搭建CNN去实现MNIST数据集的识别任务; 一,简单CNN的网络构成: 代码需要在有网络的情况下运行,因为会下载MINIST数据集,运行后会生成params.pkl保留训练权重&…

【吉先生的Java全栈之路】

吉士先生Java全栈学习路线🧡第一阶段Java基础: 在第一阶段:我们要认真听讲,因为基础很重要!基础很重要!基础很重要!!! 重要的事情说三遍。在这里我们先学JavaSE路线;学完之后我们要去学第一个可视化组件编程《GUI》;然后写个《贪吃蛇》游戏耍…

微搭低代码从入门到精通05-变量定义

我们上一篇对应用编辑器有了一个整体的介绍。要想零基础开发小程序,就得从各种概念开始学起。 如果你是零基础学习开发,无论学习哪一门语言,第一个需要掌握的知识点就是变量。 那么什么是变量?变量其实就是存放数据的一个容器&a…

专题 | 防抖和节流

一 防抖:单位时间内,频繁触发事件,只执行最后一次 场景:搜索框搜索输入(利用定时器,每次触发先清掉以前的定时器,从新开始) 节流:单位时间内,频繁触发事件&…

Yii2模板:自定义头部脚部文件,去掉头部脚部文件

一、yii安装完成之后,运行结果如下图二、如何自定义头部脚部文件呢0、默认展示1、在类里定义,在整个类中生效2、在方法中定义,在当前方法中生效3、home模板介绍三、去掉头部脚部文件1、控制 $layout 的值2、把action中的render改为renderPart…

前端对于深拷贝和浅拷贝的应用和思考

浅拷贝 浅拷贝 : 浅拷贝是指对基本类型的值拷贝,以及对对象类型的地址拷贝。它是将数据中所有的数据引用下来,依旧指向同一个存放地址,拷贝之后的数据修改之后,也会影响到原数据的中的对象数据。最简单直接的浅拷贝就…

java ssm集装箱码头TOS系统调度模块的设计与实现

由于历史和经济体制的原因,国内码头物流企业依然保持大而全的经营模式。企业自己建码头、场地、经营集装箱运输车辆。不过近几年来随着经济改革的进一步深入和竞争的激烈,一些大型的码头物流企业逐步打破以前的经营模式,其中最明显的特征就是…

利用机器学习(mediapipe)进行人脸468点的3D坐标检测--视频实时检测

上期文章,我们分享了人脸468点的3D坐标检测的图片检测代码实现过程,我们我们介绍一下如何在实时视频中,进行人脸468点的坐标检测。 import cv2 import mediapipe as mp mp_drawing = mp.solutions.drawing_utils mp_face_mesh = mp.solutions.face_mesh face_mesh = mp_fac…

ubuntu 驱动更新后导致无法进入界面

**问题描述: **安装新ubuntu系统后未禁止驱动更新导致无法进入登录界面。 解决办法: 首先在进入BIOS中,修改设置以进行命令行操作,然后卸载已有的系统驱动,最后安装新的驱动即可。 开机按F11进入启动菜单栏&#xf…

【JavaScript 逆向】安居客滑块逆向分析

声明本文章中所有内容仅供学习交流,相关链接做了脱敏处理,若有侵权,请联系我立即删除!案例目标验证码:aHR0cHM6Ly93d3cuYW5qdWtlLmNvbS9jYXB0Y2hhLXZlcmlmeS8/Y2FsbGJhY2s9c2hpZWxkJmZyb209YW50aXNwYW0以上均做了脱敏处…

如何准备大学生电子设计竞赛

大学生电子设计竞赛难度中上,一般有好几个类型题目可以选择,参赛者可以根据自己团队的能力、优势去选择合适自己的题目,灵活自主空间较大。参赛的同学们可以在暑假好好学习相关内容,把往年的题目拿来练练手。这个比赛含金量还是有…

数据可视化,流程化处理pycharts-

本文直接进入可视化,输入讲解输入列表生成图片,关于pandas操作看这篇pandas matplotlib 导包后使用 import matplotlib.pyplot as plt饼图 使用 plt.figure 函数设置图片的大小为 15x15 使用 plt.pie 函数绘制饼图,并设置相关的参数&…

详细的从零部署ChatGPT

chatgpt产品机遇: 1. chatgpt 所带来的机遇: 下一代 AI 搜索引擎,解决目前搜索引擎结果多样复杂、需要人工判断准确定的问题;替代低端劳动岗位、释放部分脑力活动、即将变革多个行业 ; 2. chatgpt 我分析将带来多个新的工作岗位机…

【Opencv实战】想给图片去水印?这样操作,几百张图片1分钟无痕去水印,这款去水印神器终于被我找到啦~(超厉害的)

前言 🚀 作者 :“程序员梨子” 🚀 **文章简介 **:本篇文章主要是写了opencv的人脸检测、猫脸检测小程序。 🚀 **文章源码免费获取 : 为了感谢每一个关注我的小可爱💓每篇文章的项目源码都是无…