机器学习 - multi-class 数据集训练 (含代码)

news2025/1/15 8:14:48

直接上代码

# Multi-class dataset

import numpy as np
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
N = 100 # number of points per class
D = 2 # dimensionality
K = 3 # number of classes
X = np.zeros((N*K, D))
y = np.zeros(N*K, dtype='uint8')
for j in range(K):
  ix = range(N*j, N*(j+1))
  r = np.linspace(0.0, 1, N)  # radius
  t = np.linspace(j*4, (j+1)*4, N) + np.random.randn(N)*0.2
  X[ix] = np.c_[r*np.sin(t), r*np.cos(t)]
  y[ix] = j
plt.scatter(X[:, 0], X[:, 1], c = y, s = 40, cmap=plt.cm.RdYlBu)
plt.show()

# Turn data into tensors
X = torch.from_numpy(X).type(torch.float)
y = torch.from_numpy(y).type(torch.LongTensor)

# Create train and test splits
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = RANDOM_SEED)

#!pip -q install torchmetrics
from torchmetrics import Accuracy
acc_fn = Accuracy(task = "multiclass", num_classes = 3).to("cpu")

class SpiralModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.linear1 = nn.Linear(in_features = 2, out_features = 10)
    self.linear2 = nn.Linear(in_features = 10, out_features = 10)
    self.linear3 = nn.Linear(in_features = 10, out_features = 3)
    self.relu = nn.ReLU()

  def forward(self, x):
    return self.linear3(self.relu(self.linear2(self.relu(self.linear1(x)))))

model_1 = SpiralModel().to("cpu")

# Setup data to be device agnostic
X_train, y_train = X_train.to("cpu"), y_train.to("cpu")
X_test, y_test = X_test.to("cpu"), y_test.to("cpu")

# Setup Loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_1.parameters(),
                             lr = 0.02)

# Build a training loop for the model
epochs = 1000

# Loop over data
for epoch in range(epochs):
  ## Training
  model_1.train()

  y_logits = model_1(X_train)
  y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)

  loss = loss_fn(y_logits, y_train)
  acc = acc_fn(y_pred, y_train)

  optimizer.zero_grad()

  loss.backward()

  optimizer.step()

  ## Training
  model_1.eval()
  with torch.inference_mode():
    test_logits = model_1(X_test)
    test_pred = torch.softmax(test_logits, dim=1).argmax(dim=1)

    test_loss = loss_fn(test_logits, y_test)
    test_acc = acc_fn(test_pred, y_test)
  
  if epoch % 100 == 0:
    print(f"Epoch: {epoch} | Loss: {loss:.2f} Acc: {acc:.2f} | Test loss: {test_loss:.2f} Test acc: {test_acc:.2f}")

# Plot decision boundaries for training and test sets 
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Train")
plot_decision_boundary(model_1, X_train, y_train)
plt.subplot(1, 2, 2)
plt.title("Test")
plot_decision_boundary(model_1, X_test, y_test)

结果如下

Epoch: 0 | Loss: 1.13 Acc: 0.32 | Test loss: 1.12 Test acc: 0.37
Epoch: 100 | Loss: 0.09 Acc: 0.98 | Test loss: 0.06 Test acc: 1.00
Epoch: 200 | Loss: 0.04 Acc: 0.99 | Test loss: 0.01 Test acc: 1.00
Epoch: 300 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 400 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 500 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 600 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 700 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 800 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 900 | Loss: 0.01 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00

结果1
结果2

给个赞呗~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1572921.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

cJSON(API的详细使用教程)

我们今天来学习一般嵌入式的必备库,JSON库 1,json和cJSON 那什么是JSON什么是cJSON,他们之间有什么样的关联呢,让我们一起来探究一下吧。 JSON(JavaScript Object Notation)是一种轻量级的数据交换格式&…

tomcat 结构目录

bin 启动,关闭和其他脚本。这些 .sh文件(对于Unix系统)是这些.bat文件的功能副本(对于Windows系统)。由于Win32命令行缺少某些功能,因此此处包含一些其他文件。比如说:windows下启动tomcat用的是…

物理层习题及其相关知识(谁看谁不迷糊呢)

1. 对于带宽为50k Hz的信道,若有4种不同的物理状态来表示数据,信噪比为20dB 。(1) 按奈奎斯特定理,信道的最大传输数据速率是多少?(2) 按香农定理,信道的最大传输数据速度…

JAVAEE之Spring Boot日志

1. 日志概述 1.1 学习日志的原因 ⽇志对我们来说并不陌生, 从JavaSE部分, 我们就在使用 System.out.print 来打印日志了. 通过打印日志来发现和定位问题, 或者根据日志来分析程序的运行过程. 在Spring的学习中, 也经常根据控制台的日志来分析和定位问题. 随着项⽬的复杂…

记录Linux系统中vim同时开多个窗口编辑文件

在使用Linux进行文本编辑的时候,通常使用vim编辑器编辑文件,当然啦,vim也可以创建文件,如果只是一个一个创建,只需要vim创建即可,但是如何一次性打开多个窗口编辑呢? 目录 1、目标:…

微信小程序uniapp+vue.js旅游攻略系统9krxx

实现了一个完整的旅游攻略小程序系统,其中主要有用户模块、用户表模块、token表模块、收藏表模块、视频信息模块、视频类型模块、景点资讯模块、门票购买模块、旅游攻略模块、景点信息模块、论坛表模块、视频信息评论表模块、旅游攻略评论表模块、景点信息评论表模块…

python 02字符串

字符串可能是用到最多的数据类型了,所有标准序列操作(索引、切片、乘法、成员资格检查、长度、最小值和最大值)都适用于字符串 但别忘了字符串是不可变的,因此所有的元素赋值和切片赋值都是非法的。 1.居中效果 默认为空格 可…

搭建电商购物独立站抓取主流电商产品数据的方法:工具+电商数据采集API接口

分享一个抓取数据产品的方法,也是别人给我说的。 想做一个联盟产品相关的网站,然后需要采集电商网站的产品。咨询大佬告诉我,大量级电商商品数据的采集可以接入专业的电商数据采集API接口,也可以用webscrsper,于是乎就…

【Linux】环境基础开发工具使用——vim使用

Linux 软件包管理器 yum 什么是软件包 1.在 Linux 下安装软件 , 一个通常的办法是下载到程序的源代码 , 并进行编译 , 得到可执行程序 . 2.但是这样太麻烦了 , 于是有些人把一些常用的软件提前编译好 , 做成软件包 ( 可以理解成 windows 上的安装程序) 放在一个服务器…

LangChain - Retrieval

LangChain - Retrieval 文章目录 LangChain - Retrieval文件装载机 Document loaders文本分割 Text Splitting文本嵌入模型 Text embedding models向量存储Retrievers索引 APIClassesFunctions 官方文档:https://python.langchain.com/docs/modules/data_connection…

IJKPLAYER源码分析-mediacodec硬解

前言 近期腾出了点时间,拟对IJKPLAYER做更完整的源码分析,并对关键实现细节,作为技术笔记,记录下来。包括Android端硬解码/AudioTrack/OpenSL播放,以及iOS端硬解码/AudioUnit播放,以及OpenGL渲染和Android/…

文件服务器之二:SAMBA服务器

文章目录 什么是SAMBASAMBA的发展历史与名称的由来SAMBA常见的应用 SAMBA服务器基础配置配置共享资源Windows挂载共享Linux挂载共享 什么是SAMBA 下图来自百度百科 SAMBA的发展历史与名称的由来 Samba是一款开源的文件共享软件,它基于SMB(Server Messa…

使用Element Plus

1. 官网安装 安装 | Element Plus (gitee.io) 安装: npm install element-plus --save 在main.ts中全局注册ElementPlus并使用 //加入element-plus import ElementPlus from element-plus; //加入element-plus样式 import element-plus/dist/index.css; import…

Day106:代码审计-PHP原生开发篇文件安全上传监控功能定位关键搜索1day挖掘

目录 emlog-文件上传&文件删除 emlog-模板文件上传 emlog-插件文件上传 emlog-任意文件删除 通达OA-文件上传&文件包含 知识点: PHP审计-原生开发-文件上传&文件删除-Emlog PHP审计-原生开发-文件上传&文件包含-通达OA emlog-文件上传&文件…

Unknown redis exception; event execu tor terminated;解决

最近查看服务器日记是不是报发现有台服务器报错: rocessing failed; nested exception is org.springframework.data.redis.RedisSystemException: Unknown redis exception; nested exception is java.util.concurrent.RejectedExecutionException: event execu …

Stale Diffusion、Drag Your Noise、PhysReaction、CityGaussian

本文首发于公众号:机器感知 Stale Diffusion、Drag Your Noise、PhysReaction、CityGaussian Drag Your Noise: Interactive Point-based Editing via Diffusion Semantic Propagation Point-based interactive editing serves as an essential tool to compleme…

Python实现特征模态分解(FMD)

大家好,我是带我去滑雪! 特征模态分解(Feature Mode Decomposition,FMD)是一种信号处理技术,用于从数据中提取特征,并将其表示为一组特定的模态成分。与其他分解方法类似,如小波变换…

RUST语言值所有权之内存复制与移动

1.RUST中每个值都有一个所有者,每次只能有一个所有者 String::from函数会为字符串hello分配一块内存 内存示例如下: 在内存分配前调用s1正常输出 在分配s1给s2后调用报错 因为s1分配给s2后,s1的指向自动失效 s1被move到s2 s1自动释放 字符串克隆使用

Oracle 中 where 和 on 的区别

1.Oracle 中 where 和 on 的区别 on:会先根据on后面的条件进行筛选,条件为真时返回该行,由于on的优先级高于left join,所以left join关键字会把左表中没有匹配的所有行也都返回,然后生成临时表返回,执行优先级高于…

Python 基于列表实现的通讯录管理系统(有完整源码)

目录 通讯录管理系统 PersonInformation类 ContactList类 menu函数 main函数 程序的运行流程 完整代码 运行示例 通讯录管理系统 这是一个基于文本的界面程序,用户可以通过命令行与之交互,它使用了CSV文件来存储和读取联系人信息,这…