论文辅助笔记:Tempo 之 model.py

news2025/1/10 15:17:22

 0 导入库

import math
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn

from src.modules.transformer import Block
from src.modules.prompt import Prompt
from src.modules.utils import (
    FlattenHead,
    PoolingHead,
    RevIN,
)


1TEMPOConfig

1.1 构造函数

class TEMPOConfig:
    """
    Configuration of a `TEMPO` model.

    Args:
        num_series: 时间序列的数量, N 
        input_len: 输入时间序列的长度, L
        pred_len: 预测时间序列的长度, Y
        block_size: 块的最大长度(openai gpt2 固定)
        n_layer: Transformer 层的数量
        n_head: 多头注意力机制中的头数量
        n_embd: 嵌入维度的数量
        patch_size: 块的大小,用于将输入时间序列分割成多个小块
        patch_stride: 块的步幅,用于指定块之间的重叠程度
        revin: 是否使用 RevIN(归一化和逆变换)
        affine: 在 RevIN 中是否使用仿射变换
        embd_pdrop:嵌入层的 dropout 率
        resid_pdrop: 残差连接的 dropout 率
        attn_pdrop: 注意力层的 dropout 率
        head_type: 输出层的类型,可以是 FlattenHead 或 PoolingHead
        head_pdtop: 输出层的 dropout 率
        individual: 是否为每个组件使用独立的输出层
        lora: 是否使用 LoRA(低秩近似)
        lora_config: LoRA 的配置
        model_type: 模型类型,默认为 gpt2
        interpret: 是否输出组件以便解释
    """

    num_series: int
    input_len: int
    pred_len: int
    patch_size: int
    patch_stride: int
    block_size: int = None
    n_layer: int = None
    n_head: int = None
    n_embd: int = None
    revin: bool = True
    affine: bool = True
    embd_pdrop: float = 0.1
    resid_pdrop: float = 0.1
    attn_pdrop: float = 0.1
    head_type: str = "flatten"
    head_pdtop: float = 0.1
    individual: bool = False
    lora: bool = False
    lora_config: dict = None
    prompt_config: dict = None
    #Prompt 模块的配置
    model_type: str = "gpt2"
    interpret: bool = False

1.2  todict

TEMPOConfig 类实例转换为一个字典

def todict(self):
    return asdict(self)

'''
asdict 是 Python 的 dataclasses 模块提供的一个函数,用于将数据类实例转换为字典。

这个方法将当前实例的所有属性转换为字典键值对,并返回这个字典。
'''

1.3 __contains__

重载了 Python 的 __contains__ 魔术方法,使得 TEMPOConfig 实例可以像字典一样使用 in 操作符来检查属性是否存在。

def __contains__(self, key):
    return key in self.todict()

1.4 __getitem__

重载了 __getitem__ 魔术方法,使得 TEMPOConfig 实例可以像字典一样通过键来获取属性值

def __getitem__(self, key):
    return getattr(self, key)

1.5__setitem__

重载了 __setitem__ 魔术方法,使得 TEMPOConfig 实例可以像字典一样通过键来设置属性值

def __setitem__(self, key, value):
    setattr(self, key, value)

1.6 update

通过一个字典 config 更新 TEMPOConfig 实例的属性

def update(self, config: dict):
    for k, v in config.items():
        setattr(self, k, v)

2 TEMPO

class TEMPO(nn.Module):
    """
    Notation:
        B: 批次大小
        N: 时间序列的数量
        E: 嵌入维度
        P: 块的数量
        PS: patch的大小
        L: 输入时间序列的长度
        Y: 预测时间序列的长度
    """

    models = ("gpt2",)
    #支持的模型类型列表

    head_types = ("flatten", "pooling")
    #支持的输出层类型

    params = {
        "gpt2": dict(block_size=1024, n_head=12, n_embd=768),
    }
    '''
    模型的参数,例如 "gpt2" 模型的块大小、注意力头数和嵌入维度等
    '''

2.1 __init__

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1643078.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LabVIEW鸡蛋品质智能分级系统

LabVIEW鸡蛋品质智能分级系统 随着现代农业技术的飞速发展,精确、高效的农产品质量控制已成为行业的重要需求。其中,鸡蛋作为日常膳食中不可或缺的重要组成部分,其品质直接关系到消费者的健康与满意度。本文设计并实现了一套基于LabVIEW的鸡…

docker私有仓库的registry

简介 Docker私有仓库的Registry是一个服务,主要用于存储、管理和分发Docker镜像。具体来说,Registry的功能包括: 存储镜像:Registry提供一个集中的地方来存储Docker镜像,包括镜像的层次结构和元数据。 版本控制&…

node应用部署运行案例

生产环境: 系统:linux centos 7.9 node版本:v16.14.0 npm版本:8.3.1 node应用程序结构 [rootRainYun-Q7c3pCXM wiki]# dir assets config.yml data LICENSE node_modules nohup.out output.log package.json server wiki.log [rootRainYun-Q7c…

使用MATLAB/Simulink点亮STM32开发板LED灯

使用MATLAB/Simulink点亮STM32开发板LED灯-笔记 一、STM32CubeMX新建工程二、Simulink 新建工程三、MDK导入生成的代码 一、STM32CubeMX新建工程 1. 打开 STM32CubeMX 软件,点击“新建工程”,选择中对应的型号 2. RCC 设置,选择 HSE(外部高…

单链表式并查集

如果用暴力算法的话&#xff0c;那么会直接超时&#xff0c;我们要学会用并查集去记录下一个空闲的位置 #include<bits/stdc.h> using namespace std;const int N 100005;int n; int fa[N]; int a[N];int find(int x) {if (fa[x] x) {return x;}fa[x] find(fa[x]);re…

ChatGPT DALL-E绘图,制作各种表情包,实现穿衣风格的自由切换

DALL-E绘图功能探索&#xff1a; 1、保持人物形象一致&#xff0c;适配更多的表情、动作 2、改变穿衣风格 3、小女孩的不同年龄段展示 4、不同社交平台的个性头像创作 如果不会写代码&#xff0c;可以问GPT。使用地址&#xff1a;我的GPT4 视频&#xff0c;B站会发&#…

Leetcode—422. 有效的单词方块【简单】Plus

2024每日刷题&#xff08;126&#xff09; Leetcode—422. 有效的单词方块 实现代码 class Solution { public:bool validWordSquare(vector<string>& words) {int row words.size();for(int i 0; i < row; i) {// 当前这一行的列数int col words[i].length(…

网络基础-网络设备介绍

本系列文章主要介绍思科、华为、华三三大厂商的网络设备 网络设备 网络设备是指用于构建和管理计算机网络的各种硬件设备和设备组件。以下是常见的网络设备类型&#xff1a; 路由器&#xff08;Router&#xff09;&#xff1a;用于连接不同网络并在它们之间转发数据包的设备…

k8s调度原理以及自定义调度器

kube-scheduler 是 kubernetes 的核心组件之一&#xff0c;主要负责整个集群资源的调度功能&#xff0c;根据特定的调度算法和策略&#xff0c;将 Pod 调度到最优的工作节点上面去&#xff0c;从而更加合理、更加充分的利用集群的资源&#xff0c;这也是我们选择使用 kubernete…

「Node.js」ESModule 与 CommonJS 的 区别

前言 Node.js支持两种模块系统&#xff1a;CommonJS 和 ESModules&#xff08;ESM&#xff09;&#xff0c;它们在语法和功能上有一些不同。 CommonJS (CJS) CommonJS 是 Node.js 最早支持的模块规范&#xff0c;由于它的出现在ES6之前&#xff0c;因此采取的是同步加载模块…

Linux Ubuntu 开机自启动浏览器

终端输入命令&#xff1a;gnome-session-properties 打开启动设置 如果提示&#xff1a;Command ‘gnome-session-properties’ not found, but can be installed with: apt install gnome-startup-applications 则执行&#xff1a;apt install gnome-startup-applications安装…

用pyecharts完成综合案例之全球GDP动态可视化统计图

综合案例之全球GDP 所用csv文档下载链接如下&#xff1a;https://download.csdn.net/download/qq_42707739/12621102?ops_request_misc%257B%2522request%255Fid%2522%253A%2522171488482816800184124883%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fdownloa…

机器学习周报第40周

目录 摘要Abstract一、文献阅读1.1 摘要1.2 论文背景1.3 论文模型1.3.1 模型概述1.3.2 模型细节 1.4 模型精度 二、论文代码2.1 rtdetr.py2.2 backbone模块2.3 AIFI2.4 CCFM 总结 摘要 本周&#xff0c;我深入研读了RT-DETR&#xff08;实时目标检测变换器&#xff09;论文&am…

【数据结构】初识数据结构

引入&#xff1a; 哈喽大家好&#xff0c;我是野生的编程萌新&#xff0c;首先感谢大家的观看。数据结构的学习者大多有这样的想法&#xff1a;数据结构很重要&#xff0c;一定要学好&#xff0c;但数据结构比较抽象&#xff0c;有些算法理解起来很困难&#xff0c;学的很累。我…

中仕公考:哪些情况不能考公务员?

1.年龄不符合 主要分两类【一类是未成年人&#xff0c;另一类是超龄人员】 具体来讲:年龄一般为18周岁以上、35周岁以下 (2024国考标准是1987年10月至2005年10月期间出生&#xff09; 对于2024年应届硕士、博士研究生(非在职人员)放宽到40周岁以下(2024国考标准是1982年10月以后…

【Conda】解决使用清华源创建虚拟环境不成功问题

文章目录 问题描述&#xff1a;清华源创建不成功解决步骤1 添加官方源步骤2 删除C:/user/你的用户名/的 .condarc 文件步骤3 再次创建 问题描述&#xff1a;清华源创建不成功 本地配置了清华源&#xff0c;但是在创建虚拟环境时不成功&#xff0c;报错如下。 图片若看不清&…

Docker使用进阶篇

文章目录 1 前言2 使用Docker安装常用镜像示例2.1 Docker安装RabbitMQ2.2 Docker安装Nacos2.3 Docker安装xxl-job&#xff08;推荐该方式构建&#xff09;2.4 Docker安装redis2.5 Docker安装mysql 1 前言 上一篇介绍了Docker的基础概念&#xff0c;带你 入门Docker&#xff0c…

初识webpack项目

新建一个空的工程 -> % mkdir webpack-project 为了方便追踪执行每一个命令&#xff0c;最终产生了哪些变更&#xff0c;将这个空工程初始化成git项目 -> % cd webpack-project/-> % git init Initialized empty Git repository in /Users/lixiang/frontworkspace/…

STM32微秒级别延时--F407--TIM1

基本配置&#xff1a; TIM1挂载在APB2总线上&#xff0c;150MHz经过15分频&#xff0c;得到10MHz计数频率&#xff0c;由于disable了自动重装载&#xff0c;所以只需要看下一次计数值是多少即可。 void TIM1_Delay_us(uint16_t us) //使用阻塞方式进行延时&#xff0c;ARR值不…