CRWU凯斯西储大学轴承数据,12k频率,十分类

news2024/9/21 12:43:33

在这里插入图片描述
CRWU凯斯西储大学轴承数据,12k频率,十分类。

from torch.utils.data import Dataset, DataLoader
from scipy.io import loadmat
import numpy as np
import os
from sklearn import preprocessing  # 0-1编码
from sklearn.model_selection import StratifiedShuffleSplit  # 随机划分,保证每一类比例相同
import torch
from torch import nn
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torch.optim as optim

def prepro(d_path, length=0, number=0, normal=True, rate=[0, 0, 0], enc=False, enc_step=28):
    # 获得该文件夹下所有.mat文件名
    filenames = os.listdir(d_path)

    def capture(original_path):
        files = {}
        for i in filenames:
            # 文件路径
            file_path = os.path.join(d_path, i)
            file = loadmat(file_path)
            file_keys = file.keys()
            for key in file_keys:
                if 'DE' in key:
                    files[i] = file[key].ravel()
        return files

    def slice_enc(data, slice_rate= rate[1]):
        keys = data.keys()
        Train_Samples = {}
        Test_Samples = {}
        for i in keys:
            slice_data = data[i]

            all_lenght = len(slice_data)
            # end_index = int(all_lenght * (1 - slice_rate))
            samp_train = int(number * (1 - slice_rate))  # 1000(1-0.3)
            Train_sample = []
            Test_Sample = []

            for j in range(samp_train):
                sample = slice_data[j * 150: j * 150 + length]
                Train_sample.append(sample)

            # 抓取测试数据
            for h in range(number - samp_train):
                sample = slice_data[samp_train * 150 + length + h * 150: samp_train * 150 + length + h * 150 + length]
                Test_Sample.append(sample)
            Train_Samples[i] = Train_sample
            Test_Samples[i] = Test_Sample
        return Train_Samples, Test_Samples

    # 仅抽样完成,打标签
    def add_labels(train_test):
        X = []
        Y = []
        label = 0
        for i in filenames:
            x = train_test[i]
            X += x
            lenx = len(x)
            Y += [label] * lenx
            label += 1
        return X, Y

    def scalar_stand(Train_X, Test_X):
        # 用训练集标准差标准化训练集以及测试集
        data_all = np.vstack((Train_X, Test_X))
        scalar = preprocessing.StandardScaler().fit(data_all)
        Train_X = scalar.transform(Train_X)
        Test_X = scalar.transform(Test_X)
        return Train_X, Test_X

    def valid_test_slice(Test_X, Test_Y):

        test_size = rate[2] / (rate[1] + rate[2])
        ss = StratifiedShuffleSplit(n_splits=1, test_size=test_size)
        Test_Y = np.asarray(Test_Y, dtype=np.int32)

        for train_index, test_index in ss.split(Test_X, Test_Y):
            X_valid, X_test = Test_X[train_index], Test_X[test_index]
            Y_valid, Y_test = Test_Y[train_index], Test_Y[test_index]

            return X_valid, Y_valid, X_test, Y_test

    # 从所有.mat文件中读取出数据的字典
    data = capture(original_path=d_path)
    # 将数据切分为训练集、测试集
    train, test = slice_enc(data)
    # 为训练集制作标签,返回X,Y
    Train_X, Train_Y = add_labels(train)
    # 为测试集制作标签,返回X,Y
    Test_X, Test_Y = add_labels(test)
    # for i in Test_X:
    #     print(i.shape)
    # for i in Train_X:
    #     print(i.shape)
    # Train_X = np.stack(Train_X,axis=0)
    # Test_X = np.stack(Test_X,axis=0)
    # print(Train_X.shape,Test_X.shape)

    # 训练数据/测试数据 是否标准化.
    if normal:
        Train_X, Test_X = scalar_stand(Train_X, Test_X)

    # 将测试集切分为验证集和测试集.
    # Valid_X, Valid_Y, Test_X, Test_Y = valid_test_slice(Test_X, Test_Y)
    return Train_X, Train_Y,  Test_X, Test_Y


num_classes = 10  # 样本类别
length = 224*224  # 样本长度
number = 140  # 每类样本的数量
normal = True  # 是否标准化
rate = [0.5, 0.25, 0.25]  # 测试集验证集划分比例




class BearingDataset(Dataset):
    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

def create_dataloader(data, labels, batch_size=32, shuffle=True, num_workers=0):
    dataset = BearingDataset(data, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    return dataloader

# 使用前面定义的函数处理数据
path = './data12k'  # 注意路径格式可能需要根据您的操作系统调整
x_train, y_train,  x_test, y_test = prepro(
    d_path=path,
    length=112*112,  # 样本长度
    number=250,  # 每类样本的数量
    normal=True,  # 是否标准化
    rate=[0.8, 0.2]  # 测试集验证集划分比例
)

# 创建 DataLoader
train_loader = create_dataloader(x_train, y_train, batch_size=32, shuffle=True, num_workers=0)
test_loader = create_dataloader(x_test, y_test, batch_size=32, shuffle=False, num_workers=0)


class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        self.Conv1 = nn.Conv2d(1, 24, 15, 3, 2)  # [1, 48, 107, 107]
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(2)  # [1, 24, 35, 35]

        self.Conv2 = nn.Conv2d(24, 64, 5, 1, 2)  # [1, 64, 37, 37]
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(2)  # [1, 64, 18, 18]

        self.Conv3 = nn.Conv2d(64, 96, 2, 1, 1)  # [1, 96, 19, 19]
        self.relu3 = nn.ReLU()
        self.Conv4 = nn.Conv2d(96, 96, 2, 1, 1)  # [1, 96, 20, 20]
        self.relu4 = nn.ReLU()
        self.Conv5 = nn.Conv2d(96, 64, 2, 1, 1)  # [1, 64, 21, 21]
        self.relu5 = nn.ReLU()
        self.maxpool3 = nn.MaxPool2d(3)  # [1, 64, 7, 7]

        self.Dro1 = nn.Dropout(p=0.5)
        self.flatten = nn.Flatten()
        self.line1 = nn.Linear(64 * 3 * 3, 1000)
        self.relu6 = nn.ReLU()
        self.Dro2 = nn.Dropout(p=0.5)
        self.line2 = nn.Linear(1000, 1000)
        self.relu7 = nn.ReLU()
        self.line3 = nn.Linear(1000, 500)
        self.line4 = nn.Linear(500, 10)

    def forward(self, x):
        x = self.Conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.Conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x = self.Conv3(x)
        x = self.relu3(x)
        x = self.Conv4(x)
        x = self.relu4(x)
        x = self.Conv5(x)
        x = self.relu5(x)
        x = self.maxpool3(x)

        x = self.Dro1(x)
        x = self.flatten(x)
        x = self.line1(x)
        x = self.relu6(x)
        x = self.Dro2(x)
        x = self.line2(x)
        x = self.relu7(x)
        x = self.line3(x)
        x = self.line4(x)

        return x


def train_model(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs.view(-1,1,112,112))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(train_loader)

def evaluate_model(model, data_loader, device):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs.view(-1,1,112,112))
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1_score, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted')
    return accuracy, precision, recall, f1_score


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AlexNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    num_epochs = 50

    # 假设 train_loader 和 test_loader 已经被创建
    best_acc=0
    for epoch in range(num_epochs):
        train_loss = train_model(model, train_loader, criterion, optimizer, device)
        print(f'Epoch {epoch+1}, Loss: {train_loss:.4f}')
        accuracy, precision, recall, f1_score = evaluate_model(model, test_loader, device)
        print(f'Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1_score:.4f}')
        if best_acc<accuracy:
            best_acc = accuracy
            torch.save(model.state_dict(), 'best_alexnet.pth')

if __name__ == "__main__":
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1679071.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

量化交易:日内回转交易策略

哈喽&#xff0c;大家好&#xff0c;我是木头左&#xff01; 引言 本文将介绍日内回转交易策略的原理&#xff0c;并通过Python代码示例展示如何在掘金平台实现该策略。本文将深入探讨一种基于1分钟MACD&#xff08;Moving Average Convergence Divergence&#xff0c;即移动平…

C++ LeetCode 刷题经验、技巧及踩坑记录【三】

C LeetCode 刷题经验、技巧及踩坑记录【三】 前言vector 计数vector 逆序vector 删除首位元素vector二维数组排序vector二维数组初始化C 不同进制输出C 位运算C lower_bound()C pairC stack 和 queue 前言 记录一些小技巧以及平时不熟悉的知识。 vector 计数 计数 //记录与首…

C# Winform+Halcon结合标准视觉工具

介绍 winform与halcon结合标准化工具实例 软件架构 软件架构说明 基于NET6 WINFORMHALCON 实现标准化视觉检测工具 集成相机通讯 集成PLC通讯 TCP等常见通讯 支持常见halcon算子 图形采集blob分析高精度匹配颜色提取找几何体二维码提取OCR识别等等 。。。 安装教程 …

MQTT_客户端安装_1.4

下载地址 MQTTX 下载 下一步直接安装即可 界面介绍

GhostNetV2 Enhance Cheap Operation with Long-Range Attention 论文学习

论文地址&#xff1a;https://arxiv.org/abs/2211.12905 代码地址&#xff1a;https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/ghostnetv2_pytorch 解决了什么问题&#xff1f; 在计算机视觉领域&#xff0c;深度神经网络在诸多任务上扮演着重要角色。为…

从源头到洞察:大数据时代的数据提取与分析实战指南

随着科技的飞速发展&#xff0c;大数据已经成为现代社会的核心驱动力之一。从商业决策到科学研究&#xff0c;从政策制定到个人生活&#xff0c;数据无处不在&#xff0c;影响着我们的每一个决策。然而&#xff0c;如何从海量的数据中提取有价值的信息&#xff0c;并转化为深刻…

LVM - Linux磁盘逻辑卷管理器概念讲解、实践及所遇到的问题

1、lvm概念 逻辑卷管理器(LogicalVolumeManager)本质上是一个虚拟设备驱动,是在内核中块设备和物理设备之间添加的一个新的抽象层次,它可以将几块磁盘(物理卷,PhysicalVolume)组合起来形成一个存储池或者卷组(VolumeGroup)。LVM可以每次从卷组中划分出不同大小的逻辑卷(Logi…

iOS 主要语言切换问题

前言 上架时需要把主要语言切换成英文&#xff0c;存储时一直提示“因为您必须先为使用这种语言的每个版本提供所有必需的截屏”错误。 错误截图 解决方案&#xff1a; 1、增加英文的截图去审核&#xff0c;审核过了再切换主要语言 官方文档出处 END.

uniapp小程序使用scroll-view组件实现上下左右滚动触发事件

在做uniapp开发小程序的时候&#xff0c;有一个需求是在一个表格区域里面可以上下左右滑动元素&#xff0c;并实现表头和左侧的标签联动效果&#xff0c;就想趣运动里面选择场地的效果一样&#xff0c;这里就用到了scroll-view组件&#xff0c;scroll-view官网文档地址&#xf…

安卓、iOS、iPad三端搞定,不再剧荒!

哈喽&#xff0c;各位小伙伴们好&#xff0c;我是给大家带来各类黑科技与前沿资讯的小武。 之前给大家推荐过各种看剧姿势&#xff0c;但很多苹果、平板端的小伙伴还是存在更好的需求体验&#xff0c;今天给大家推荐这款可以在安卓、iOS和平板上都能安装使用&#xff0c;不再剧…

音视频捕捉技术:LCC382 SDI采集卡深度解析

在日新月异的多媒体时代&#xff0c;高质量的音视频采集已成为众多领域不可或缺的一环。为此&#xff0c;灵卡科技精心打造了LCC382 —— 一款集高效性、灵活性与前沿技术于一身的SDI输入与环出、HDMI输出音视频采集卡&#xff0c;旨在满足从专业直播、视频会议到医疗影像、安防…

F5 Big-IP的一些查看命令

1 查看主机名&#xff0c;序列号&#xff0c;版本号 system —>configuration—>Device

Linux系统中pts和tty会话删除

一、背景 一台CentOS6.7主机存在iscsi盘&#xff0c;为了正常卸载此iscsi盘&#xff0c;需要先将所有相关会话退出使用该iscsi盘。 检查发现存在多个系统用户登录的情况。 二、问题 无法使用kill -9删除linux会话&#xff0c;提示信息为“-bash: kill: (16680) - Operation not…

开发利器 - docker 安装运行 mysql

本文选择安装的mysql版本为5.7 &#xff0c;安装环境 mac 1、查看镜像是否存在 docker search mysql:5.7 2、拉取镜像 docker pull mysql:5.7 3、运行镜像 docker run --name mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORDroot1234 -d mysql:5.7 --name&#xff1a;指定容器…

[AI]-(第1期):OpenAI-API调用

文章目录 一、OpenAI API中使用GPT-3.5-turbo模型充值方式使用模型计费方式价格说明相关限制和条款 二、接入一个OpenAI API流程1. 获取OpenAI API 密钥2. 集成ChatGPT到小程序3. 处理用户输入4. 调用OpenAI API5. 返回回复至小程序6. 持续优化7. Postman请求示例 三、通用AI客…

2024最新洗地机推荐,洗地机怎么选?热门品牌哪个最好用?

在现代生活中&#xff0c;忙碌的日常让家庭清洁变得更加繁重和耗时。然而&#xff0c;洗地机的引入彻底改变了这一状况。凭借其强大的清洁效果和简便的使用方式&#xff0c;洗地机能够迅速清除地面上的各种污垢&#xff0c;使清洁工作变得轻松自如。正因为如此&#xff0c;洗地…

windows编译opencv4.9

opencv很多人在windows上编译感觉特别麻烦&#xff0c;没有linux下方便&#xff0c;设定以下三点&#xff0c;我们几乎会无障碍。 1 安装cuda&#xff0c;cudnn 安装好cuda&#xff0c;cudnn&#xff0c;把cudnn的头文件&#xff0c;库等等拷贝到cuda的安装目录下面&#xff…

抖音电商发展受限,视频号反而成了短视频电商风口?这是为什么?

哈喽~我是电商月月 抖音小店发展的如火如荼间&#xff0c;视频号也正式推出了自己的电商平台 视频号小店的推出&#xff0c;引的众多商家讨论 很多人都觉得视频号的流量比不过抖音&#xff0c;玩互联网的人群【年轻群体】都集中在抖音上了&#xff0c;有抖音在&#xff0c;视…

惠普打印机无线网络连接设置

休息一下&#xff0c;灌个水。这次没多少内容&#xff0c;具体步骤惠普官网上都有&#xff0c;唯一增加的是对安装过程中踩的坑做了一个说明。 一&#xff0e;打印机无线网络连接设置步骤 惠普打印机设置无线网络连接&#xff0c;共16个步骤。 1. 在电脑上打开任意浏览器&am…

k8s证书续期

证书即将到期了如何进行证书续签 k8s版本V1.23.6 1.查看证书期限 kubeadm certs check-expiration如果证书即将到期&#xff0c;此处的天数应该是几天&#xff0c;在过期之前进行续期&#xff0c;保证集群的可用 2. 备份证书 避免出现问题可以回退 cp -r /etc/kubernetes …