27、ResNet50处理STEW数据集,用于情感三分类+全备的代码

news2025/2/4 23:51:16

1、数据介绍

IEEE-Datasets-STEW:SIMULTANEOUS TASK EEG WORKLOAD DATASET :

该数据集由48名受试者的原始EEG数据组成,他们参加了利用SIMKAP多任务测试进行的多任务工作负荷实验。受试者在休息时的大脑活动也在测试前被记录下来,也包括在其中。Emotiv EPOC设备,采样频率为128Hz,有14个通道,用于获取数据,每个案例都有2.5分钟的EEG记录。受试者还被要求在每个阶段后以1到9的评分标准对其感知的心理工作量进行评分,评分结果在单独的文件中提供。

说明:每个受试者的数据遵循命名惯例:subno_task.txt。例如,sub01_lo.txt将是受试者1在休息时的原始脑电数据,而sub23_hi.txt将是受试者23在多任务测试中的原始脑电数据。每个数据文件的行对应于记录中的样本,列对应于EEG设备的14个通道: AF3, F7, F3, FC5, T7, P7, O1, O2, P8, T8, FC6, F4, F8, AF4。

数据说明、下载地址:

STEW: Simultaneous Task EEG Workload Data Set | IEEE Journals & Magazine | IEEE Xplore

2、代码

本次使用ResNet50,去做此情感数据的分类工作,数据导入+模型训练+测试代码如下:

import torch
import torchvision.datasets
from torch.utils.data import Dataset        # 继承Dataset类
import os
from PIL import Image
import numpy as np
from torchvision import transforms
 
 
# 预处理
data_transform = transforms.Compose([
    transforms.Resize((224,224)),           # 缩放图像
    transforms.ToTensor(),                  # 转为Tenso
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))       # 标准化
])


path =  r'C:\STEW\test'

for root,dirs,files in os.walk(path):
        print('root',root) #遍历到该目录地址
        print('dirs',dirs) #遍历到该目录下的子目录名 []
        print('files',files)  #遍历到该目录下的文件  []
 
def read_txt_files(path):
    # 创建文件名列表
    file_names = []
    # 遍历给定目录及其子目录下的所有文件
    for root, dirs, files in os.walk(path):
        # 遍历所有文件
        for file in files:
            # 如果是 .txt 文件,则加入文件名列表
            if file.endswith('.txt'): # endswith () 方法用于判断字符串是否以指定后缀结尾,如果以指定后缀结尾返回True,否则返回False。
                file_names.append(os.path.join(root, file))
    # 返回文件名列表
    return file_names

class DogCat(Dataset):      # 数据处理
    def __init__(self,root,transforms = None):                  # 初始化,指定路径,是否预处理等等
 
        #['cat.15454.jpg', 'cat.445.jpg', 'cat.46456.jpg', 'cat.656165.jpg', 'dog.123.jpg', 'dog.15564.jpg', 'dog.4545.jpg', 'dog.456465.jpg']
        imgs = os.listdir(root)
 
        self.imgs = [os.path.join(root,img) for img in imgs]    # 取出root下所有的文件
        self.transforms = data_transform                        # 图像预处理
 
    def __getitem__(self, index):       # 读取图片
        img_path = self.imgs[index]
        label = 1 if 'dog' in img_path.split('/')[-1] else 0 
        #然后,就可以根据每个路径的id去做label了。将img_path 路径按照 '/ '分割,-1代表取最后一个字符串,如果里面有dog就为1,cat就为0.
 
        data = Image.open(img_path)
 
        if self.transforms:     # 图像预处理
            data = self.transforms(data)
 
        return data,label
 
    def __len__(self):
        return len(self.imgs)
 
dataset = DogCat('./data/',transforms=True)
 
for img,label in dataset:
    print('img:',img.size(),'label:',label)
'''
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
'''

import os
 
# 获取file_path路径下的所有TXT文本内容和文件名
def get_text_list(file_path):
    files = os.listdir(file_path)
    text_list = []
    for file in files:
        with open(os.path.join(file_path, file), "r", encoding="UTF-8") as f:
            text_list.append(f.read())
    return text_list, files
 
class ImageFolderCustom(Dataset):

    # 2. Initialize with a targ_dir and transform (optional) parameter
    def __init__(self, targ_dir: str, transform=None) -> None:

        # 3. Create class attributes
        # Get all image paths
        self.paths = list(pathlib.Path(targ_dir).glob("*/*.jpg")) # note: you'd have to update this if you've got .png's or .jpeg's
        # Setup transforms
        self.transform = transform
        # Create classes and class_to_idx attributes
        self.classes, self.class_to_idx = find_classes(targ_dir)

    # 4. Make function to load images
    def load_image(self, index: int) -> Image.Image:
        "Opens an image via a path and returns it."
        image_path = self.paths[index]
        return Image.open(image_path) 

    # 5. Overwrite the __len__() method (optional but recommended for subclasses of torch.utils.data.Dataset)
    def __len__(self) -> int:
        "Returns the total number of samples."
        return len(self.paths)

    # 6. Overwrite the __getitem__() method (required for subclasses of torch.utils.data.Dataset)
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        "Returns one sample of data, data and label (X, y)."
        img = self.load_image(index)
        class_name  = self.paths[index].parent.name # expects path in data_folder/class_name/image.jpeg
        class_idx = self.class_to_idx[class_name]

        # Transform if necessary
        if self.transform:
            return self.transform(img), class_idx # return data, label (X, y)
        else:
            return img, class_idx # return data, label (X, y)
                  
import torchvision as tv
import numpy as np
import torch
import time
import os
from torch import nn, optim
from torchvision.models import resnet50
from torchvision.transforms import transforms
 
os.environ["CUDA_VISIBLE_DEVICE"] = "0,1,2"
 
# cifar-10进行测验

class Cutout(object):
    """Randomly mask out one or more patches from an image.
    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length
 
    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes of dimension length x length cut out of it.
        """
        h = img.size(1)
        w = img.size(2)
 
        mask = np.ones((h, w), np.float32)
 
        for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)
 
            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)
 
            mask[y1: y2, x1: x2] = 0.
 
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask
 
        return img
 
def load_data_cifar10(batch_size=128,num_workers=2):
    # 操作合集
    # Data augmentation
    train_transform_1 = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),  # 随机水平翻转
        transforms.RandomRotation(degrees=(-80,80)),  # 随机角度翻转
        transforms.ToTensor(),
        transforms.Normalize(
            (0.491339968,0.48215827,0.44653124), (0.24703233,0.24348505,0.26158768)  # 两者分别为(mean,std)
        ),
        Cutout(1, 16),  # 务必放在ToTensor的后面
    ])
    train_transform_2 = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # 两者分别为(mean,std)
        )
    ])
    test_transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # 两者分别为(mean,std)
        )
    ])
    # 训练集1
    trainset1 = tv.datasets.CIFAR10(
        root='data',
        train=True,
        download=False,
        transform=train_transform_1,
    )
    # 训练集2
    trainset2 = tv.datasets.CIFAR10(
        root='data',
        train=True,
        download=False,
        transform=train_transform_2,
    )
    # 测试集
    testset = tv.datasets.CIFAR10(
        root='data',
        train=False,
        download=False,
        transform=test_transform,
    )
    # 训练数据加载器1
    trainloader1 = torch.utils.data.DataLoader(
        trainset1,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=(torch.cuda.is_available())
    )
    # 训练数据加载器2
    trainloader2 = torch.utils.data.DataLoader(
        trainset2,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=(torch.cuda.is_available())
    )
    # 测试数据加载器
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=(torch.cuda.is_available())
    )
 
    return trainloader1,trainloader2,testloader
 
def main():
    start = time.time()
    batch_size = 128
    cifar_train1,cifar_train2,cifar_test = load_data_cifar10(batch_size=batch_size)
    model = resnet50().cuda()
    # model.load_state_dict(torch.load('_ResNet50.pth'))
    # 存在已保存的参数文件
    # model = nn.DataParallel(model,device_ids=[0,])  # 又套一层
    model = nn.DataParallel(model,device_ids=[0,1,2])
    loss = nn.CrossEntropyLoss().cuda()
    optimizer = optim.Adam(model.parameters(),lr=0.001)
    for epoch in range(50):
        model.train()  # 训练时务必写
        loss_=0.0
        num=0.0
        # train on trainloader1(data augmentation) and trainloader2
        for i,data in enumerate(cifar_train1,0):
            x, label = data
            x, label = x.cuda(),label.cuda()
            # x
            p = model(x) #output
            l = loss(p,label) #loss
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            loss_ += float(l.mean().item())
            num+=1
        for i, data in enumerate(cifar_train2, 0):
            x, label = data
            x, label = x.cuda(), label.cuda()
            # x
            p = model(x)
            l = loss(p, label)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            loss_ += float(l.mean().item())
            num += 1
        model.eval()  # 评估时务必写
        print("loss:",float(loss_)/num)
        # test on trainloader2,testloader
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_train2:
                # [b, 3, 32, 32]
                # [b]
                x, label = x.cuda(), label.cuda()
                # [b, 10]
                logits = model(x)
                # [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] => scalar tensor
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
                # print(correct)
            acc_1 = total_correct / total_num
        # Test
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # [b, 3, 32, 32]
                # [b]
                x, label = x.cuda(), label.cuda()
                # [b, 10]
                logits = model(x) #output
                # [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] => scalar tensor
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
                # print(correct)
            acc_2 = total_correct / total_num
            print(epoch+1,'train acc',acc_1,'|','test acc:', acc_2)
    # 保存时只保存model.module
    torch.save(model.module.state_dict(),'resnet50.pth')
    print("The interval is :",time.time() - start)
 
 
if __name__ == '__main__':
    main()

3、对你有帮助的话,给个关注吧~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1328342.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java 中的内部类的定义

目录 一、成员内部类 二、静态内部类 三、局部内部类 四、匿名内部类 一、成员内部类 public class InnerClass {String name;private Integer age;static String hobby;/*** 成员内部类* 1、成员内部类中只能定义非静态属性和方法* 2、成员内部类中可以访问外部类的成员&a…

【深度学习实践】换脸应用dofaker本地部署

本文介绍了dofaker换脸应用的本地部署教程,dofaker支持windows、linux、cpu/gpu推理,不依赖于任何深度学习框架,是一个非常好用的换脸工具。 本教程的部署系统为windows 11,使用CPU推理。 注意: 1、请确保您的所有路…

晋江IP影视化频频折戟,网文陷入工业化困境

在影视行业进入IP时代的2023年,晋江文学城(以下简称晋江)IP影视化却迎来了大溃败。 2023年,晋江IP在影视行业依旧十分抢手,多部热门网文被影视化,其中不乏头部视频网站的S大制作,但播出效果却有…

.NET core 自定义过滤器 Filter 实现webapi RestFul 统一接口数据返回格式

之前写过使用自定义返回类的方式来统一接口数据返回格式,.Net Core webapi RestFul 统一接口数据返回格式-CSDN博客 但是这存在一个问题,不是所有接口会按照定义的数据格式返回,除非每个接口都返回我们自定义的类,这种实现起来不…

Mybatis-plus动态条件查询QueryWrapper的函数用法

目录 前言1. QueryWrapper2. 函数3. Demo 前言 原本都是在Mapper文件中修改,直到看到项目中使用了QueryWrapper这个函数,大致了解了用法以及功能,发现还可以! 对此此贴为科普帖以及笔记帖 1. QueryWrapper MyBatis-Plus 是 My…

你知道海外云手机可以用于外贸测评吗?

目前随着外贸行业的发展,像亚马逊、速卖通、eBay等海外电商平台越来越火热。在这些平台,过硬的产品质量、优秀的服务、合适的价格,再加上适量的跨境电商测评,很容易就能吸引不少的客户。那么如何利用海外云手机进行外贸测评&#…

rk3588多模型检测部署quickrun

quickrun 是一款rk3588 rknn多模型高效高并发部署软件 软件框架 采用session思想,可以定义多个session满足不同模型的义务需求。比如充电桩检测,垃圾分类,悬崖检测,模型共用一个摄像头,采用yolov5的模型。 采用消息…

【C语言】动态内存管理基础知识——动态通讯录,如何实现通讯录容量的动态化

引言 动态内存管理的函数有:malloc,calloc,ralloc,free,本文讲解动态内存函数和使用,如何进行动态内存管理,实现通讯录联系人容量的动态化,对常见动态内存错误进行总结。 ✨ 猪巴戒:个人主页✨ 所属专栏:《C语言进阶》…

【大模型】快速体验百度智能云千帆AppBuilder搭建知识库与小助手

文章目录 前言千帆AppBuilder什么是千帆AppBuilderAppBuilder能做什么 体验千帆AppBuilderJava知识库高考作文小助手 总结 前言 前天,在【百度智能云智算大会】上,百度智能云千帆AppBuilder正式开放服务。这是一个AI原生应用开发工作台,可以…

业务逻辑漏洞有哪些?漏洞攻击防御及代码示例

文章目录 简介危害成因攻击防御代码示例1. 未经验证的重要操作2. 认证绕过3. 逻辑时间窗口漏洞4. 负载测试漏洞 修复 业务逻辑漏洞是指软件或系统的逻辑设计上的缺陷,这些缺陷可能被攻击者利用,从而导致意料之外的行为。下面是对业务逻辑漏洞的简介、危害…

大数据技术基本功-数据采集

产品指南|DataScale自定义采集器功能介绍产品指南|开发 DataScale Collector​​​​​​​

ubuntu换源

1 首先备份Ubuntu源列表 sudo cp /etc/apt/sources.list /etc/apt/sources.list.backup 2 查看自己Ubuntu版本 命令 lsb_release -a precise为源里面的关键字,根据实际情况,自行修改 3 修改更新源 先删除原文件里面的内容 sudo gedit /etc/apt/sources.list 用下面内容替…

哈希表..

文章目录 1. 两数之和-力扣 1 题 1. 两数之和-力扣 1 题 思路: 循环遍历数组,拿到每个数字x以target-x作为key到map中查找 若没找到,将x 作为key,它的索引作为value 存入map 若找到了,返回 x 和它配对数的索引即可 …

【让云服务器更灵活】iptables转发tcp/udp端口请求

iptables转发tcp/udp端口请求 文章目录 前言一、路由转发涉及点二、转发如何配置本机端口转发到本机其它端口本机端口转发到其它机器 三、固化iptables总结 前言 路由转发是计算机网络中的一种重要概念,特别是在网络设备和系统之间。它涉及到如何处理和传递数据包&…

Redis(非关系型数据库)

Redis(非关系型数据库) 文章目录 Redis(非关系型数据库)认识Redis(Remote Dictionary Server)1.Redis的基本介绍2.Redis的应用场景2.1 取最新N个数据的操作2.2 排行榜应用,取TOP N操作2.3 需要精准设定过期时间的应用2.4 计数器应用2.5 Uniq 操作,获取某段时间所有数…

爬虫API|批量抓取电商平台商品数据,支持高并发

随着互联网的快速发展,电商平台如雨后春笋般涌现,为消费者提供了丰富的购物选择。然而,对于许多商家和数据分析师来说,如何快速、准确地获取电商平台上的商品数据成为了一个难题。为了解决这个问题,我们开发了一个爬虫…

GBASE南大通用数据库在Windows和Linux中创建数据源

Windows 中数据源信息可能存在于两个地方:在 Windows 注册表中(对 Windows 系统), 或在一个 DSN 文件中(对任何系统)。 如果信息在 Windows 注册表中,它叫做“机器数据源”。它可能是一个“用 …

基于若依的ruoyi-nbcio流程管理系统增加待办通知个性化设置

更多ruoyi-nbcio功能请看演示系统 gitee源代码地址 前后端代码: https://gitee.com/nbacheng/ruoyi-nbcio 演示地址:RuoYi-Nbcio后台管理系统 1、在每个节点可以设置扩展属性是todo的属性值,如下: 2、在需要审批或启动的时候获…

华为鸿蒙操作系统简介及系统架构分析(2)

接前一篇文章:华为鸿蒙操作系统简介及系统架构分析(1) 本文部分内容参考: 鸿蒙系统学习笔记(一) 鸿蒙系统介绍 特此致谢! 上一回对于华为的鸿蒙操作系统(HarmonyOS)进行了介绍并说明了其层次化…

韵达快递查询入口,一键将退回件筛选出来

批量查询韵达快递单号的物流信息,并将退回件一键筛选出来。 所需工具: 一个【快递批量查询高手】软件 韵达快递单号若干 操作步骤: 步骤1:运行【快递批量查询高手】软件,并登录 步骤2:点击主界面左上角的…