【时间序列篇】基于LSTM的序列分类-Pytorch实现 part1 案例复现

news2024/11/18 4:22:32

系列文章目录

【时间序列篇】基于LSTM的序列分类-Pytorch实现 part1 案例复现
【时间序列篇】基于LSTM的序列分类-Pytorch实现 part2 自有数据集构建
【时间序列篇】基于LSTM的序列分类-Pytorch实现 part3 化为己用

本篇文章是对已有一篇文章的整理归纳,并对文章中提及的模型用Pytorch实现。

文章目录

  • 系列文章目录
  • 前言
  • 一、任务问题和数据集
    • 1 任务问题
    • 2 数据集
    • 3 数据集读取并展示
  • 二、模型实现
    • 1 数据导入
    • 2 数据预处理
    • 3 数据集划分
    • 4 网络模型及实例化
    • 5 训练过程
  • 三、总结


前言

序列,可以是采样得到的信号样本,也可以是传感器数据。

对于序列分类任务,常用的思路有两种:
1、原理统计相关,分解序列的相关性质研究规律(人工设计特征,再分类)
2、数据挖掘思路,机器学习做特征工程,模型拟合(自动学习特征,再分类)

  • 人工设计特征方法:
    基于序列距离:计算距离进行分类(类别模板or聚类)
    基于统计特征:时序特征提取 (均值,方差,差分)再分类

  • 自动学习特征方法:
    深度学习端到端(RNN, LSTM)

本文通过LSTM来实现对序列信号的分类。


主要思想和代码框架来自参考文献[1]

一、任务问题和数据集

1 任务问题

人体运动估计:
传感器生成高频数据,对不同状态下采集的数据进行分类,可以识别其范围内对象的移动。通过设置多个传感器并对信号进行采样分析,可以识别物体的运动方向。

“ 室内用户运动预测 ”问题:
在该任务中,多个运动传感器被放置在不同房间中,目标基于运动传感器捕获的数据来识别个体是否已经移动穿过房间。

两个房间有四个运动传感器(A1,A2,A3,A4)。
下图说明了传感器在每个房间中的位置。
在这里插入图片描述
一个人可以沿着上图中所示的六个预定义路径中的任何一个移动。每个路径都生成一个 RSS 测量的轨迹样本,从轨迹的开始一直到标记点,在图中表示为 M。标记 M 对于所有运动都是相同的,因此不能仅仅根据在 M 处收集的 RSS 值来区分不同的路径。
该图还显示了所考虑的用户轨迹类型的简化说明,直线路径导致于空间变化,曲线路径导致空间不变。有在房间内移动和在房间之间移动两种类别。

2 数据集

文件含义
RSS_Position_dataset/dataset样本数据
RSS_Position_dataset/groups标签文件和组别文件(划分数据集)
RSS_Position_dataset/MovementAAL.jpg上面的示意图

数据集最重要的有316个csv文件:

  • 【dataset 文件夹】
    314 个MovementAAL csv文件,是序列样本,每个文件都包含与输入 RSS 数据的一个序列数据(每个文件记录一个用户轨迹)。该数据集包含314个序列数据(样本csv文件)。
    1个 MovementAAL_target.csv 文件,是每个MovementAAL文件对应的标签(类别)。每一个样本对应的类别,表明用户的轨迹是否会导致空间变化(例如房间的变化)。特别地,标签为+1与位置变化相关联,而标签为 -1与位置保留相关联。
  • 【groups 文件夹】
    MovementAAL_DatasetGroup.csv文件,用于划分数据集

3 数据集读取并展示

import pandas as pd
# ----------------------------------------------------#
#   路径指定,文件读取
# ----------------------------------------------------#
df1 = pd.read_csv("DATA/RSS_Position_dataset/dataset/MovementAAL_RSS_1.csv")
df2 = pd.read_csv("DATA/RSS_Position_dataset/dataset/MovementAAL_RSS_2.csv")

df1.head()  # 返回一个新的DataFrame或Series对象,默认返回前5行。
df1.shape  # 返回文件的size,不同文件的len(行数)不同

二、模型实现

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

1 数据导入

'''
/****************************************************/
    导入数据集
/****************************************************/
'''
# ----------------------------------------------------#
#   数据集样本
# ----------------------------------------------------#
path = "DATA/RSS_Position_dataset/dataset/MovementAAL_RSS_"
sequences = list()
for i in range(1, 315):  # 315为样本数
    file_path = path + str(i) + '.csv'
    df = pd.read_csv(file_path, header=0)
    values = df.values
    sequences.append(values)

# ----------------------------------------------------#
#   数据集标签
# ----------------------------------------------------#
targets = pd.read_csv('DATA/RSS_Position_dataset/dataset/MovementAAL_target.csv')
targets = targets.values[:, 1]

# ----------------------------------------------------#
#   数据集划分
# ----------------------------------------------------#
groups = pd.read_csv('DATA/RSS_Position_dataset/groups/MovementAAL_DatasetGroup.csv', header=0)
groups = groups.values[:, 1]

分析:

  1. 数据集样本:将所有的样本读入sequences列表中,列表长度为样本数,列表中每一个元素为一个样本。
  2. 数据集标签:targets 中存放。
  3. 数据集划分:数据集是在三对不同的房间中收集的,因此有三组。此信息可用于将数据集划分为训练集,测试集和验证集。

2 数据预处理

由于时间序列数据的长度不同,sequences列表中每个元素长度不一。无法直接在此数据集上构建模型。需要统一。原文中的思想是填充使相等。
这里是对样本,即sequences列表变量进行处理。

# ----------------------------------------------------#
#   Padding the sequence with the values in last row to max length
# ----------------------------------------------------#
# 函数用于填充和截断序列
def pad_truncate_sequences(sequences, max_len, dim=4, truncating='post', padding='post'):
    # 初始化一个空的numpy数组,用于存储填充后的序列
    padded_sequences = np.zeros((len(sequences), max_len, dim))
    for i, one_seq in enumerate(sequences):
        if len(one_seq) > max_len:  # 截断
            if truncating == 'pre':
                padded_sequences[i] = one_seq[-max_len:]
            else:
                padded_sequences[i] = one_seq[:max_len]
        else:  # 填充
            padding_len = max_len - len(one_seq)
            to_concat = np.repeat(one_seq[-1], padding_len).reshape(dim, padding_len).transpose()
            if padding == 'pre':
                padded_sequences[i] = np.concatenate([to_concat, one_seq])
            else:
                padded_sequences[i] = np.concatenate([one_seq, to_concat])
    return padded_sequences

# 使用自定义函数进行填充和截断
seq_len = 32
# truncate or pad the sequence to seq_len
final_seq = pad_truncate_sequences(sequences, max_len=seq_len, dim=4, truncating='post', padding='post')

对数据集来说,标签 +1/-1 不利于模型输出,变为 1/0。
这里是对标签,即targets类别变量进行处理。

# 设置标签从 +1/-1 ,变为 1/0
targets = np.array(targets)
final_targets = (targets+1)/2

3 数据集划分

# ----------------------------------------------------#
#   数据集划分
# ----------------------------------------------------#
# 将numpy数组转换为PyTorch张量
final_seq = torch.tensor(final_seq, dtype=torch.float)

# 划分样本为 训练集,验证集 和 测试集
train = [final_seq[i] for i in range(len(groups)) if groups[i] == 1]
validation = [final_seq[i] for i in range(len(groups)) if groups[i] == 2]
test = [final_seq[i] for i in range(len(groups)) if groups[i] == 3]
# 标签同理
train_target = [final_targets[i] for i in range(len(groups)) if groups[i] == 1]
validation_target = [final_targets[i] for i in range(len(groups)) if groups[i] == 2]
test_target = [final_targets[i] for i in range(len(groups)) if groups[i] == 3]

# 转换为PyTorch张量
train = torch.stack(train)
train_target = torch.tensor(train_target).long()

validation = torch.stack(validation)
validation_target = torch.tensor(validation_target).long()

test = torch.stack(test)
test_target = torch.tensor(test_target).long()

4 网络模型及实例化

'''
/****************************************************/
    网络模型
/****************************************************/
'''
# ----------------------------------------------------#
#   LSTM 模型
# ----------------------------------------------------#
class TimeSeriesClassifier(nn.Module):
    def __init__(self, n_features, hidden_dim=256, output_size=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size=n_features, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_size)  # output_size classes

    def forward(self, x):
        x, _ = self.lstm(x)  # LSTM层
        x = x[:, -1, :]  # 只取LSTM输出中的最后一个时间步
        x = self.fc(x)  # 通过一个全连接层
        return x


# ----------------------------------------------------#
#   模型实例化 和 部署
# ----------------------------------------------------#
n_features = 4  # 根据你的特征数量进行调整
output_size = 2
model = TimeSeriesClassifier(n_features=n_features, output_size=output_size)

# 打印模型结构
print(model)

5 训练过程

'''
/****************************************************/
    训练过程
/****************************************************/
'''
# 设置训练参数
epochs = 100  # 训练轮数,根据需要进行调整
batch_size = 4  # 批大小,根据你的硬件调整

# DataLoader 加载数据集
# 将数据集转换为张量并创建数据加载器
train_dataset = torch.utils.data.TensorDataset(train, train_target)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

validation_dataset = torch.utils.data.TensorDataset(validation, validation_target)
validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=batch_size, shuffle=True)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()

# 学习率和优化策略
learning_rate = 1e-3
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=5e-4)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)  # 设置学习率下降策略


# ----------------------------------------------------#
#   训练
# ----------------------------------------------------#
def calculate_accuracy(y_pred, y_true):
    _, predicted_labels = torch.max(y_pred, 1)
    correct = (predicted_labels == y_true).float()
    accuracy = correct.sum() / len(correct)
    return accuracy


for epoch in range(epochs):
    model.train()  # 将模型设置为训练模式
    train_epoch_loss = []
    train_epoch_accuracy = []
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data  # 获取输入数据和标签
        optimizer.zero_grad()  # 清零梯度
        outputs = model(inputs)  # 前向传播
        loss = criterion(outputs, labels)
        loss.backward()  # 反向传播和优化
        optimizer.step()

        # 打印统计信息
        # train_epoch_loss.append(loss.item())
        # accuracy = calculate_accuracy(outputs, labels)
        # train_epoch_accuracy.append(accuracy.item())
        #
        # train_running_loss = np.average(train_epoch_loss)
        # train_running_accuracy = np.average(train_epoch_accuracy)
        #
        # if i % 10 == 9:  # 每10个批次打印一次
        #     print("--------------------------------------------")
        #     print(f'Epoch {epoch + 1}, Loss: {train_running_loss}, accuracy: {train_running_accuracy}')

    # Validation accuracy
    model.eval()
    valid_epoch_accuracy = []
    with torch.no_grad():
        for inputs, labels in validation_loader:  # Assuming validation_loader is defined
            outputs = model(inputs)

            accuracy = calculate_accuracy(outputs, labels)
            valid_epoch_accuracy.append(accuracy.item())
    # 计算平均精度
    valid_running_accuracy = np.average(valid_epoch_accuracy)
    print(f'Epoch {epoch + 1}, Validation Accuracy: {valid_running_accuracy:.4f}')

print('Finished Training')

三、总结

在验证集上的分类准确率最高才70%。emmm我猜是数据少。

CSDN: 进行时间序列分类实践–python实战

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1413713.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[ESP32]在Thonny IDE中,如何將MicroPython firmware燒錄到ESP32開發板中?

[ESP32 I MicroPython] Flash Firmware by Thonny(4.1.4) IDE 正常安裝流程,可參考上述影片。然而,本篇文章主要是紀錄安裝過程遇到的bug, 供未來查詢用,也一併供有需要的同好參考。 問題:安裝後,Thonny互動介面顯示一堆亂碼和co…

新建react项目,react-router-dom配置路由,引入antd

提示:reactrouter6.4版本,与reactrouter5.0的版本用法有区别,互不兼容需注意 文章目录 前言一、创建项目二、新建文件并引入react-router-dom、antd三、配置路由跳转四、效果五、遇到的问题六、参考文档总结 前言 需求:新建react项…

python-自动化篇-运维-监控-简单实例-道出如何使⽤Python进⾏系统监控?

如何使⽤Python进⾏系统监控? 使⽤Python进⾏系统监控涉及以下⼀般步骤: 选择监控指标: ⾸先,确定希望监控的系统指标,这可以包括 CPU 利⽤率、内存使⽤情况、磁盘空间、⽹络流量、服务可⽤性等。选择监控⼯具&#x…

tf卡被格式化怎么恢复里面的数据?恢复指南在此

在日常生活中,我们经常使用TF卡来存储各种数据,如照片、视频、文档等。然而,有时候我们会误将TF卡格式化,导致其中的数据丢失。为了挽救这些宝贵的数据,我们需要采取一些措施来进行恢复。本文将为你介绍如何恢复TF卡中…

架构整洁之道——价值维度与编程范式

1 设计与架构究竟是什么 结论:二者没有任何区别,一丁点区别都没有。 架构图里实际上包含了所有底层设计细节,这些细节信息共同支撑了顶层的架构设计,底层设计信息和顶层架构设计共同组成了整个架构文档。底层设计细节和高层架构信…

滑木块H5小游戏

欢迎来到程序小院 滑木块 玩法&#xff1a;点击木块横着的只能左右移动&#xff0c;竖着的只能上下移动&#xff0c; 移动到箭头的位置即过关&#xff0c;不同关卡不同的木块摆放&#xff0c;快去滑木块吧^^。开始游戏https://www.ormcc.com/play/gameStart/260 html <can…

JavaEE 网络编程

JavaEE 网络编程 文章目录 JavaEE 网络编程引子1. 网络编程-相关概念1.1 基本概念1.2 发送端和接收端1.3 请求和响应1.4 客户端和服务端 2. Socket 套接字2.1 数据包套接字通信模型2.2 流套接字通信模型2.3 Socket编程注意事项 3. UDP数据报套接字编程3.1 DatagramSocket3.2 Da…

pip 安装出现报错 SSLError(SSLError(“bad handshake

即使设置了清华源&#xff1a; pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simplepip 安装包不能配置清华源&#xff0c;出现报错: Retrying (Retry(total2, connectNone, readNone, redirectNone, statusNone)) after connection broken by ‘SSLE…

适用于 Windows 的 10 款免费 MP4 转 MP3 转换神器

每当我们观看歌曲或视频剪辑时&#xff0c;我们经常会想到将其转换为 MP3 格式&#xff0c;以便我们可以将其保存在设备上&#xff0c;因为它占用的空间更少。在将 MP4 转换为 MP3 的过程中&#xff0c;第一步也是最重要的一步是选择正确的工具来转换它&#xff0c;如果您想添加…

API网关-Apisix RPM包方式自动化安装配置教程

文章目录 前言一、简介1. etcd简介2. APISIX简介3. apisix-dashboard简介 二、Apisix安装教程1. 复制脚本2. 增加执行权限3. 执行脚本4. 浏览器访问5. 卸载Apisix 三、命令1. Apisix命令1.1 启动apisix服务1.2 停止apisix服务1.3 优雅地停止apisix服务1.4 重启apisix服务1.5 重…

SG-8506CA 可编程晶体振荡器 (SPXO)

输出: LV-PECL频率范围: 50MHz ~ 800MHz电源电压: 2.5V to 3.3V外部尺寸规格: 7.0 5.0 1.5mm (8引脚)特性:用户指定一个起始频率, 7-bit I2C 地址:用户可编程: I2C 接口:基频的高频晶体:低抖动PLL技术应用:OTN, BTS, 测试设备 规格&#xff08;特征&#xff09; *1 这包括初…

链表--543. 二叉树的直径/medium 理解度C

543. 二叉树的直径 1、题目2、题目分析3、复杂度最优解代码示例4、适用场景 1、题目 给你一棵二叉树的根节点&#xff0c;返回该树的 直径 。 二叉树的 直径 是指树中任意两个节点之间最长路径的 长度 。这条路径可能经过也可能不经过根节点 root 。 两节点之间路径的 长度 …

Python Flask与APScheduler构建简易任务监控

1. Flask Web Flask诞生于2010年&#xff0c;是用Python语言&#xff0c;基于Werkzeug工具箱编写的轻量级、灵活的Web开发框架&#xff0c;非常适合初学者或小型到中型的 Web 项目。 Flask本身相当于一个内核&#xff0c;其他几乎所有的功能都要用到扩展&#xff08;邮件扩展…

案例分享 | 助力数字化转型:嘉为科技项目管理平台上线

嘉为科技项目管理平台&#xff08;一期&#xff09;基于易趋&#xff08;EasyTrack&#xff09;进行实施&#xff0c;通过近一年的开发及试运行&#xff0c;现已成功交付上线、推广使用&#xff0c;取得了良好的应用效果。 1.关于广州嘉为科技有限公司&#xff08;以下简称嘉为…

外卖跑腿系统开发:构建高效、安全的服务平台

在当今快节奏的生活中&#xff0c;外卖跑腿系统的开发已成为技术领域的一个重要课题。本文将介绍如何使用一些常见的编程语言和技术框架&#xff0c;构建一个高效、安全的外卖跑腿系统。 1. 技术选择 在开始开发之前&#xff0c;我们需要选择适合的技术栈。常用的技术包括&a…

idea使用注释时如何不从行首开始

1、File—>setting 2、找到Editor&#xff0c;点Code Style 1.对于java注释设置 点java&#xff0c;然后选择Code Generation,去掉Line comment at first column,选择Add a space at comment start 2.对于xml注释设置 点XML&#xff0c;然后选择Code Generation,去掉Line c…

java-数组(以及jvm的内存分布)

文章目录 数组的基本概念数组的作用数组的创建以及初始化数组的创建数组的初始化 数组的使用数组中元素的访问遍历打印数组 数组是引用类型初始jvm的内存分布基本类型变量和引用类型变量的区别引用变量 认识null 数组的基本概念 数组可以看作是一种类型的集合我们在内存空间上…

Go 命令行解析 flag 包之快速上手

本篇文章是 Go 标准库 flag 包的快速上手篇。 概述 开发一个命令行工具&#xff0c;视复杂程度&#xff0c;一般要选择一个合适的命令行解析库&#xff0c;简单的需求用 Go 标准库 flag 就够了&#xff0c;flag 的使用非常简单。 当然&#xff0c;除了标准库 flag 外&#x…

Mac网线上网绿联扩展坞连接网线直接上网-无脑操作

声明&#xff1a;博主使用的绿联扩展坞 以下为绿联扩展坞Mac网线使用方法 1.首先需要下载电脑对应版本的驱动 直接点击即可下载 2. 下载好以后 解压 点进去 对应版本 博主直接使用最新的12-14 3. 安装包好了以后 会提示重启电脑 此时拔掉扩展坞 再重启动 拔掉扩展坞 再重启…

【Tomcat与网络1】史前时代—没有Spring该如何写Web服务

在前面我们介绍了网络与Java相关的问题&#xff0c; 最近在调研的时候发现这块内容其实非常复杂&#xff0c;涉及的内容多而且零碎&#xff0c;想短时间梳理出整个体系是不太可能的&#xff0c;所以我们还是继续看Tomcat的问题&#xff0c;后面有网络的内容继续补充吧。 目录 …