lstm实践

news2024/10/2 15:16:43

今年华为杯研究生数学建模的C题第四问用到了lstm,这里配合代码简要地讲一下。

数据类型

磁通密度是一个时序数据,包含了一个周期内的磁通密度变化,我们需要对它进行降维,但PCA是不合适的,因为PCA主要关注数据的方差,无法有效捕捉周期性数据的重要特征,而磁通密度是周期性变化的。

在自然语言处理领域中,LSTM可以捕获序列内部元素之间的关联性,并且其隐藏层可以包含前序序列的信息。最后一层的隐藏层就包含了整个序列的信息,所以我们可以将最后一层的隐藏层作为降维后的向量。

我们选择LSTM对1024 维的磁通密度进行降维,具体做法是:训练时对一个周期进行切片,使用LSTM预测切片的下一时刻的磁通密度;降维时使用整个周期,获取最后一 层的hidden state作为该样本的磁通密度特征。

代码

1.数据处理

import pandas as pd
import torch as pt
import os
os.chdir('/home/burger/math/')


df1 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料1')
df2 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料2')
df3 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料3')
df4 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料4')

collom_name = [i for i in range(1,1024)]
B1 = df1[['0(磁通密度B,T)']+collom_name]
B2 = df2[['0(磁通密度,T)']+collom_name]
B3 = df3[['0(磁通密度B,T)']+collom_name]
B4 = df4[['0(磁通密度B,T)']+collom_name]
print(B1.head())

B1_t = pt.tensor(B1.values)
B2_t = pt.tensor(B2.values)
B3_t = pt.tensor(B3.values)
B4_t = pt.tensor(B4.values)
B = pt.cat((B1_t, B2_t, B3_t, B4_t), 0)
print(B.shape)

def create_dataset(data, time_step=64):  
    x, y = [], []  
    for i in range(0, data.shape[1] - time_step, 32):  
        a = data[:, i:(i + time_step)]  
        x.append(a)  
        y.append(data[:, i + time_step])  
    return pt.concat(x).float(), pt.concat(y).float()

X, Y = create_dataset(B)
print(X.shape, Y.shape)
X, Y = X.unsqueeze(-1), Y.unsqueeze(-1)
print(X.shape, Y.shape)

 2.模型定义

import torch.nn as nn 
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm


# LSTM模型定义
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        out = self.fc(lstm_out[:, -1, :])  # 只取最后一个时间步的输出
        return out
    
    def embedding(self, x):
        _, hid_cell = self.lstm(x)
        return hid_cell[0]

3.训练

input_size = 1
hidden_size = 1
output_size = 1
num_layers = 1
num_epochs = 5
batch_size = 2048
gpu = 6
train_dataset = TensorDataset(X.to(gpu), Y.to(gpu))
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

model = LSTMModel(input_size, hidden_size, output_size, num_layers).to(gpu)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


def train(model, dataloader, criterion, optimizer, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        for inputs, labels in tqdm(dataloader, unit='batch'):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
train(model, train_loader, criterion, optimizer, num_epochs)

4.降维

df_san = pd.read_excel('./data/附件三(测试集).xlsx')
B_san = df_san[['0(磁通密度B,T)']+collom_name]
B_san_t = pt.tensor(B_san.values)
B_san_t = B_san_t.unsqueeze(-1).float().to(gpu)
print(B_san_t.shape)

emb_dataset = TensorDataset(B_san_t)
emb_loader = DataLoader(dataset=emb_dataset, batch_size=400, shuffle=False)

def embedding(model, dataloader):
    model.eval()
    embeddings = []
    for inputs in tqdm(dataloader, unit='batch'):
        outputs = model.embedding(inputs[0])
        embeddings.append(outputs)
    return pt.cat(embeddings).cpu().detach().numpy()

embeddings = embedding(model, emb_loader)
print(embeddings.shape)
embeddings = embeddings.reshape(-1)
print(embeddings.shape)

embeddings_df = pd.DataFrame({
    '磁通密度编码': embeddings,
})
embeddings_df.to_excel('./data/磁通密度编码1.xlsx', index=False)

有问题欢迎在评论区讨论!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2184615.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何对大模型的回答置信度做出判断

大模型的回答置信度,特别是像 GPT 模型这类基于生成式预训练模型的系统,是一个高度复杂的概念。置信度(confidence)通常指模型在给定输出上有多大的确定性,反映的是模型对其生成的答案有多“确信”。这种置信度既可以被…

【STM32-HAL库】自发电型风速传感器(使用STM32F407ZGT6)(附带工程下载链接)

一、自发电型风速传感器介绍 自发电型风速传感器,也称为风力发电型风速传感器或无源风速传感器,是一种不需要外部电源即可工作的风速测量设备。这种传感器通常利用风力来驱动内部的发电机构,从而产生电能来供电测量风速的传感器部分。以下是自…

从u盘直接删除的文件能找回吗 U盘文件误删除如何恢复

U盘上的文件被删除并不意味着它们立即消失。事实上,删除操作只是将文件从文件系统的目录中移除,并标记可用空间。这意味着在文件被覆盖之前,它们仍然存在于存储介质上。因此,只要文件没有被新的数据覆盖,我们就有机会恢…

一本应用《软件方法》的书《软件需求分析和设计实践指南》

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 昨天看到了韩雪燕、李楠等老师写的《软件需求分析和设计实践指南》,前言提到了我。特别说明的是,这个书我自己看到的,韩老师等之前也未和我提过--这…

电子采购招投标比价供应商在线询价定标审批管理系统(源码)

前言: 随着互联网和数字技术的不断发展,企业采购管理逐渐走向数字化和智能化。数字化采购平台作为企业采购管理的新模式,能够提高采购效率、降低采购成本、优化供应商合作效率,已成为企业实现效益提升的关键手段。系统获取在文末…

前端组件化开发

假设这个页面是vue开发的,如果一整个页面都是编写在一个vue文件里面,后期不好维护,会特别的庞大,那么如何这个时候需要进行组件化开发。组件化开发后必然会带来一个问题需要进行组件之间的通信。组要是父子组件之间通信&#xff0…

[Linux]从零开始的网站搭建教程

一、谁适合本次教程 学习Linux已经有一阵子了,相信大家对LInux都有一定的认识。本次教程会教大家如何在Linux中搭建一个自己的网站并且实现内网访问。这里我们会演示在Windows中和在Linux中如何搭建自己的网站。当然,如果你没有Linux的基础,这…

【一篇文章理解Java中多级缓存的设计与实现】

文章目录 一.什么是多级缓存?1.本地缓存2.远程缓存3.缓存层级4.加载策略 二.适合/不适合的业务场景1.适合的业务场景2.不适合的业务场景 三.Redis与Caffine的对比1. 序列化2. 进程关系 四.各本地缓存性能测试对比报告(官方)五.本地缓存Caffine如何使用1. 引入maven依…

陶瓷4D打印有挑战,水凝胶助力新突破,复杂结构轻松造

大家好!今天要和大家聊聊一项超酷的技术突破——《Direct 4D printing of ceramics driven by hydrogel dehydration》发表于《Nature Communications》。我们都知道4D打印很神奇,能让物体随环境变化而改变形状。但陶瓷因为太脆太硬,4D打印一…

java中创建不可变集合

一.应用场景 二.创建不可变集合的书写格式(List,Set,Map) List集合 package com.njau.d9_immutable;import java.util.Iterator; import java.util.List;/*** 创建不可变集合:List.of()方法* "张三","李四","王五…

鸿蒙开发选择表情

鸿蒙开发选择表情 动态评论和聊天信息都需要用到表情,鸿蒙是没有提供的,得自己做 一、思路: 用表情字符显示表情,类似0x1F600代表笑脸 二、效果图: 三、关键代码: // 联系:893151960 Colum…

蓝桥杯【物联网】零基础到国奖之路:十五. 扩展模块之双路ADC

蓝桥杯【物联网】零基础到国奖之路:十五. 扩展模块之双路ADC 第一节 硬件解读第二节 CubeMX配置第三节 代码编写 第一节 硬件解读 STM32的ADC是12位,通过硬件过采样扩展到16位,模数转换器嵌入到STM32L071xx器件中。有16个外部通道和2个内部通道&#xf…

PDF阅读器工具集萃:满足你的多样需求

现在阅读书籍大部分都喜欢电子书的形式了吧,因为小小的一个设备就能存下上万本书。从流传程度来说PDF无疑是一个使用最广的格式。除了福昕PDF阅读器阅读之外还有哪些好用的阅读工具呢/?今天我们一起来探讨一下吧。 1.福昕阅读器 链接一下>>www.f…

css3-----2D转换、动画

2D 转换(transform) 转换(transform)是CSS3中具有颠覆性的特征之一,可以实现元素的位移、旋转、缩放等效果 移动:translate旋转:rotate缩放:scale 二维坐标系 2D 转换之移动 trans…

SysML案例-清朝、火星人入侵地球

DDD领域驱动设计批评文集>> 《软件方法》强化自测题集>> 《软件方法》各章合集>> 以下图形摘自Jon Holt和Simon Perry的SysML for Systems Engineering。 案例素材来自H. G. Wells在1898年(没错,清朝)出版的The War of…

Netty系列-7 Netty编解码器

背景 netty框架中,自定义解码器的起点是ByteBuf类型的消息, 自定义编码器的终点是ByteBuf类型。 1.解码器 业务解码器的起点是ByteBuf类型 netty中可以通过继承MessageToMessageEncoder类自定义解码器类。MessageToMessageEncoder继承自ChannelInboundHandlerAdap…

用于高频交易预测的最优输出LSTM

用于高频交易预测的最优输出LSTM J.P.Morgan的python教程 Content 本文提出了一种改进的长短期记忆(LSTM)单元,称为最优输出LSTM(OPTM-LSTM),用于实时选择最佳门或状态作为最终输出。这种单元采用浅层拓…

CSS 盒子属性

1. 盒子模型组成 1.1 边框属性 1.1.1 四边分开写 1.1.2 合并线框 1.1.3 边框影响盒子大小 1.2 内边距 注意: 1.3 外边距 1.3.1 嵌套块元素垂直外边距的塌陷 1.4 清除内外边距 1.5 总结

使用YOLO11训练自己的数据集【下载模型】-【导入数据集】-【训练模型】-【评估模型】-【导出模型】

目录 前言:一、下载模型二、导入数据集三、训练自己的数据集四、验证数据集五、测试数据集 前言: YOLO11于2024年9月30日由YOLOv8团队正式发布,为了让我们能够趁热打铁早发论文,接下来让我们仔细研究一下如何使用YOLO11训练自己的…

通信协议感悟

本文结合个人所学,简要讲述SPI,I2C,UART通信的特点,限制。 1.同步通信 UART,SPI,I2C三种串行通讯方式,SPI功能引脚为CS,CLK,MOSI,MISO;I2C功能引…