Pytorch-Transformer轴承故障一维信号分类(三)

news2024/10/7 10:22:20

目录

前言

1 数据集制作与加载

1.1 导入数据

第一步,导入十分类数据

第二步,读取MAT文件驱动端数据

第三步,制作数据集

第四步,制作训练集和标签

1.2 数据加载,训练数据、测试数据分组,数据分batch

2 Transformer分类模型和超参数选取

2.1 定义Transformer分类模型,采用Transformer架构中的编码器:

2.2 定义模型参数

2.3 模型结构

3 Transformer模型训练与评估

3.1 模型训练

3.2 模型评估


往期精彩内容:

Python-凯斯西储大学(CWRU)轴承数据解读与分类处理

Python轴承故障诊断 (一)短时傅里叶变换STFT

Python轴承故障诊断 (二)连续小波变换CWT

Python轴承故障诊断 (三)经验模态分解EMD

Python轴承故障诊断 (四)基于EMD-CNN的故障分类

Python轴承故障诊断 (五)基于EMD-LSTM的故障分类

Pytorch-LSTM轴承故障一维信号分类(一)

Pytorch-CNN轴承故障一维信号分类(二)

前言

本文基于凯斯西储大学(CWRU)轴承数据,先经过数据预处理进行数据集的制作和加载,最后通过Pytorch实现Transformer模型对故障数据的分类,并介绍Transformer模型的超参数。凯斯西储大学轴承数据的详细介绍可以参考下文:

Python-凯斯西储大学(CWRU)轴承数据解读与分类处理

1 数据集制作与加载

1.1 导入数据

参考之前的文章,进行故障10分类的预处理,凯斯西储大学轴承数据10分类数据集:

第一步,导入十分类数据

import numpy as np
import pandas as pd
from scipy.io import loadmat

file_names = ['0_0.mat','7_1.mat','7_2.mat','7_3.mat','14_1.mat','14_2.mat','14_3.mat','21_1.mat','21_2.mat','21_3.mat']

for file in file_names:
    # 读取MAT文件
    data = loadmat(f'matfiles\\{file}')
    print(list(data.keys()))

第二步,读取MAT文件驱动端数据

# 采用驱动端数据
data_columns = ['X097_DE_time', 'X105_DE_time', 'X118_DE_time', 'X130_DE_time', 'X169_DE_time',
                'X185_DE_time','X197_DE_time','X209_DE_time','X222_DE_time','X234_DE_time']
columns_name = ['de_normal','de_7_inner','de_7_ball','de_7_outer','de_14_inner','de_14_ball','de_14_outer','de_21_inner','de_21_ball','de_21_outer']
data_12k_10c = pd.DataFrame()
for index in range(10):
    # 读取MAT文件
    data = loadmat(f'matfiles\\{file_names[index]}')
    dataList = data[data_columns[index]].reshape(-1)
    data_12k_10c[columns_name[index]] = dataList[:119808]  # 121048  min: 121265
print(data_12k_10c.shape)
data_12k_10c

第三步,制作数据集

train_set、val_set、test_set 均为按照7:2:1划分训练集、验证集、测试集,最后保存数据

第四步,制作训练集和标签

# 制作数据集和标签
import torch

# 这些转换是为了将数据和标签从Pandas数据结构转换为PyTorch可以处理的张量,
# 以便在神经网络中进行训练和预测。

def make_data_labels(dataframe):
    '''
        参数 dataframe: 数据框
        返回 x_data: 数据集     torch.tensor
            y_label: 对应标签值  torch.tensor
    '''
    # 信号值
    x_data = dataframe.iloc[:,0:-1]
    # 标签值
    y_label = dataframe.iloc[:,-1]
    x_data = torch.tensor(x_data.values).float()
    y_label = torch.tensor(y_label.values.astype('int64')) # 指定了这些张量的数据类型为64位整数,通常用于分类任务的类别标签
    return x_data, y_label

# 加载数据
train_set = load('train_set')
val_set = load('val_set')
test_set = load('test_set')

# 制作标签
train_xdata, train_ylabel = make_data_labels(train_set)
val_xdata, val_ylabel = make_data_labels(val_set)
test_xdata, test_ylabel = make_data_labels(test_set)
# 保存数据
dump(train_xdata, 'trainX_1024_10c')
dump(val_xdata, 'valX_1024_10c')
dump(test_xdata, 'testX_1024_10c')
dump(train_ylabel, 'trainY_1024_10c')
dump(val_ylabel, 'valY_1024_10c')
dump(test_ylabel, 'testY_1024_10c')

1.2 数据加载,训练数据、测试数据分组,数据分batch

import torch
from joblib import dump, load
import torch.utils.data as Data
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
# 参数与配置
torch.manual_seed(100)  # 设置随机种子,以使实验结果具有可重复性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU训练

# 加载数据集
def dataloader(batch_size, workers=2):
    # 训练集
    train_xdata = load('trainX_1024_10c')
    train_ylabel = load('trainY_1024_10c')
    # 验证集
    val_xdata = load('valX_1024_10c')
    val_ylabel = load('valY_1024_10c')
    # 测试集
    test_xdata = load('testX_1024_10c')
    test_ylabel = load('testY_1024_10c')

    # 加载数据
    train_loader = Data.DataLoader(dataset=Data.TensorDataset(train_xdata, train_ylabel),
                                   batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)
    val_loader = Data.DataLoader(dataset=Data.TensorDataset(val_xdata, val_ylabel),
                                 batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)
    test_loader = Data.DataLoader(dataset=Data.TensorDataset(test_xdata, test_ylabel),
                                  batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)
    return train_loader, val_loader, test_loader

batch_size = 32
# 加载数据
train_loader, val_loader, test_loader = dataloader(batch_size)

2 Transformer分类模型和超参数选取

2.1 定义Transformer分类模型,采用Transformer架构中的编码器:

注意:输入数据进行了堆叠 ,把一个1*1024 的序列 进行划分堆叠成形状为 32 * 32, 就使输入序列的长度降下来了

2.2 定义模型参数

# 模型参数
input_dim = 32 # 输入维度
hidden_dim = 512  # 注意力维度
output_dim  = 10  # 输出维度
num_layers = 4   # 编码器层数
num_heads = 8    # 多头注意力头数
batch_size = 32
# 模型
model = TransformerModel(input_dim, output_dim, hidden_dim, num_layers, num_heads, batch_size)  
model = model.to(device)
loss_function = nn.CrossEntropyLoss(reduction='sum')  # loss
learn_rate = 0.0003
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)  # 优化器

2.3 模型结构

3 Transformer模型训练与评估

3.1 模型训练

训练结果

100个epoch,准确率将近90%,Transformer模型分类效果良好,参数过拟合了,适当调整模型参数,降低模型复杂度,还可以进一步提高分类准确率。

注意调整参数:

  • 可以适当增加 Transforme编码器层数 和隐藏层的维度,微调学习率;

  • 调整多头注意力的头数,增加更多的 epoch (注意防止过拟合)

  • 可以改变一维信号堆叠的形状(设置合适的长度和维度)

3.2 模型评估

# 模型 测试集 验证  
import torch.nn.functional as F

# 加载模型
model =torch.load('best_model_transformer.pt')
# model = torch.load('best_model_cnn2d.pt', map_location=torch.device('cpu'))

# 将模型设置为评估模式
model.eval()
# 使用测试集数据进行推断
with torch.no_grad():
    correct_test = 0
    test_loss = 0
    for test_data, test_label in test_loader:
        test_data, test_label = test_data.to(device), test_label.to(device)
        test_output = model(test_data)
        probabilities = F.softmax(test_output, dim=1)
        predicted_labels = torch.argmax(probabilities, dim=1)
        correct_test += (predicted_labels == test_label).sum().item()
        loss = loss_function(test_output, test_label)
        test_loss += loss.item()

test_accuracy = correct_test / len(test_loader.dataset)
test_loss = test_loss / len(test_loader.dataset)
print(f'Test Accuracy: {test_accuracy:4.4f}  Test Loss: {test_loss:10.8f}')

Test Accuracy: 0.9570  Test Loss: 0.12100271

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1305778.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

科技提升安全,基于YOLOv6开发构建商超扶梯场景下行人安全行为姿态检测识别系统

在商超等人流量较为密集的场景下经常会报道出现一些行人在扶梯上摔倒、受伤等问题,随着AI技术的快速发展与不断普及,越来越多的商超、地铁等场景开始加装专用的安全检测预警系统,核心工作原理即使AI模型与摄像头图像视频流的实时计算&#xf…

基于Qt的蓝牙Bluetooth在ubuntu实现模拟

​# 前言 Qt 官方提供了蓝牙的相关类和 API 函数,也提供了相关的例程给我们参考。笔者根据 Qt官方的例程编写出适合我们 Ubuntu 和 gec6818开发板的例程。注意 Windows 上不能使用 Qt 的蓝牙例程,因为底层需要有 BlueZ协议栈,而 Windows 没有。Windows 可能需要去移植。笔者…

交友系统:打造独具魅力的社交平台!APP小程序H5三端源码交付,支持二开!

随着社交媒体的兴起,交友系统成为了现代社会不可或缺的一部分。人们希望通过网络结识新朋友,拓展社交圈,寻找志同道合的伙伴,甚至找到自己的爱情。本文将为您介绍交友系统的定义、功能以及如何打造一个独具魅力的社交平台。 一个成…

鸿蒙开发 - ohpm安装第三方库

前端开发难免使用第三方库,鸿蒙亦是如此,在使用 DevEco Studio 开发工具时,如何引入第三方库呢?操作步骤如下,假设你使用的是MacOS,假设你已经创建了了一个项目: 一、配置 HTTP Proxy 在打开了…

鸿蒙开发之状态管理@State

1、视图数据双向绑定 鸿蒙开发采用的声明式UI,利用状态驱动UI的更新。其中State被称作装饰器,是一种状态管理的方式。 状态:指的是被装饰器装饰的驱动视图更新的数据。 视图:是指用户看到的UI渲染出来的界面。 之所以成为双向…

数据采集网关:工业数据采集上云

数据采集网关,以其高效、便捷的特点,成为了现代工业物联网数据采集处理的重要工具。它是连接不同数据源和数据接收设备的桥梁,将各种形式和格式的数据快速、安全地汇聚到一起。通过数据采集网关,企业可以轻松实现数据的整合、转换…

指针浅谈(三)

在指针浅谈(二)http://t.csdnimg.cn/SKAkD中我们讲到了const修饰指针、指针运算、野指针、assert断言和传址调用的内容,今天我们继续学习有关数组名、指针访问数组、一维数组传参的本质相关的内容,内容比较深入,如果觉得哪里讲解的不行&#…

【Apollo】ubuntu20.04源码安装apollo8.0

官方源码安装教程 https://blog.csdn.net/weixin_45929038/article/details/120113008 安装NVIDIA GPU驱动 Apollo 8.0 的一些模块的编译和运行需要依赖 NVIDIA GPU 环境(例如感知模块),如果有编译和运行这类模块的需求,则需要安…

时间序列预测 — BiLSTM实现多变量多步光伏预测(Tensorflow)

目录 1 数据处理 1.1 导入库文件 1.2 导入数据集 1.3 缺失值分析 2 构造训练数据 3 模型训练 3.1 BiLSTM网络 3.2 模型训练 4 模型预测 1 数据处理 1.1 导入库文件 import time import datetime import pandas as pd import numpy as np import matplotlib.pyplot…

已经写完的论文怎么降低查重率 papergpt

大家好,今天来聊聊已经写完的论文怎么降低查重率,希望能给大家提供一点参考。 以下是针对论文重复率高的情况,提供一些修改建议和技巧: 已经写完的论文怎么降低查重率 背景介绍 在学术界,论文的查重率是评价论文质量的…

QT----第三天,Visio stdio自定义封装控件

目录 第三天1 自定义控件封装 源码:CPP学习代码 第三天 1 自定义控件封装 新建一个QT widgetclass,同时生成ui,h,cpp文件 在smallWidget.ui里添加上你想要的控件并调试大小 回到mainwidget.ui,拖入一个widget(因为我们封装的也…

MES系统在制造企业数字化工厂中扮演着什么角色?

MES是制造执行系统(Manufacturing Execution System)的缩写。它是一种用于监控和管理制造过程的数字化管理系统,旨在优化生产流程、提高效率并确保产品质量。通过整合各种生产环节,MES系统为企业提供了更高效、更智能的生产管理方…

LangChain学习二:提示-实战(下半部分)

文章目录 上一节内容:LangChain学习二:提示-实战(上半部分)学习目标:提示词中的示例选择器和输出解释器学习内容一:示例选择器1.1 LangChain自定义示例选择器1.2 实现自定义示例选择器1.2.1实战&#xff1a…

【大数据】Doris 架构

Doris 架构 Doris 的架构很简洁,只设 FE(Frontend)、BE(Backend)两种角色、两个进程,不依赖于外部组件,方便部署和运维,FE、BE 都可线性扩展。 ✅ Frontend(FE&#xff0…

MySQL概述

数据库相关概念 名称全称简称数据库存储数据的仓库,数据是有组织的进行存储DataBase (DB)数据库管理系统操纵和管理数据库的大型软件。有关系型数据库(RDBMS)与非关系型数据库(NoSQL)两种DataBase Management System (DBMS)SQL操作关系型数据库的编程语言&#xff…

MySQL 8.x temp空间不足问题

目录 一、系统环境 二、问题报错 三、问题回顾 四、解决问题 一、系统环境 系统Ubuntu20.04 数据库版本MySQL 8.0.21 二、问题报错 在MySQL上执行一个大的SQL查询报错Error writing file /tmp/MYfd142 (OS errno 28 - No space left on device) Exception in thread …

记录 | linux安装Manim

linux 安装 Manim sudo apt update sudo apt install build-essential python3-dev libcairo2-dev libpango1.0-dev ffmpeg sudo apt install xdg-utilsconda create manim_py39 python3.9 conda activate manim_py39pip install manim安装好环境后来测试一个例程,…

GO闭包实现原理(汇编级讲解)

go语言闭包实现原理(汇编层解析) 1.起因 今天开始学习go语言,在学到go闭包时候,原本以为go闭包的实现方式就是类似于如下cpp lambda value通过值传递,mutable修饰可以让value可以修改,但是地址不可能一样value通过引用传递,但是在其他地方调用时,这个value局部变量早就释放,…

freemarker+Aspose.word实现模板生成word并转成pdf

需求:动态生成pdf指定模板 实现途径:通过freemarker模板,导出word文档,同时可将word转为pdf。 技术选择思路 思路一:直接导出pdf 使用itext模板导出pdf 适用范围 业务生成的 pdf 是具有固定格式或者模板的文字及其…

Spring Boot 3 整合 Mybatis-Plus 动态数据源实现多数据源切换

🚀 作者主页: 有来技术 🔥 开源项目: youlai-mall 🍃 vue3-element-admin 🍃 youlai-boot 🌺 仓库主页: Gitee 💫 Github 💫 GitCode 💖 欢迎点赞…