0 导入库
import math
from dataclasses import dataclass, asdict
import torch
import torch.nn as nn
from src.modules.transformer import Block
from src.modules.prompt import Prompt
from src.modules.utils import (
FlattenHead,
PoolingHead,
RevIN,
)
1TEMPOConfig
1.1 构造函数
class TEMPOConfig:
"""
Configuration of a `TEMPO` model.
Args:
num_series: 时间序列的数量, N
input_len: 输入时间序列的长度, L
pred_len: 预测时间序列的长度, Y
block_size: 块的最大长度(openai gpt2 固定)
n_layer: Transformer 层的数量
n_head: 多头注意力机制中的头数量
n_embd: 嵌入维度的数量
patch_size: 块的大小,用于将输入时间序列分割成多个小块
patch_stride: 块的步幅,用于指定块之间的重叠程度
revin: 是否使用 RevIN(归一化和逆变换)
affine: 在 RevIN 中是否使用仿射变换
embd_pdrop:嵌入层的 dropout 率
resid_pdrop: 残差连接的 dropout 率
attn_pdrop: 注意力层的 dropout 率
head_type: 输出层的类型,可以是 FlattenHead 或 PoolingHead
head_pdtop: 输出层的 dropout 率
individual: 是否为每个组件使用独立的输出层
lora: 是否使用 LoRA(低秩近似)
lora_config: LoRA 的配置
model_type: 模型类型,默认为 gpt2
interpret: 是否输出组件以便解释
"""
num_series: int
input_len: int
pred_len: int
patch_size: int
patch_stride: int
block_size: int = None
n_layer: int = None
n_head: int = None
n_embd: int = None
revin: bool = True
affine: bool = True
embd_pdrop: float = 0.1
resid_pdrop: float = 0.1
attn_pdrop: float = 0.1
head_type: str = "flatten"
head_pdtop: float = 0.1
individual: bool = False
lora: bool = False
lora_config: dict = None
prompt_config: dict = None
#Prompt 模块的配置
model_type: str = "gpt2"
interpret: bool = False
1.2 todict
TEMPOConfig
类实例转换为一个字典
def todict(self):
return asdict(self)
'''
asdict 是 Python 的 dataclasses 模块提供的一个函数,用于将数据类实例转换为字典。
这个方法将当前实例的所有属性转换为字典键值对,并返回这个字典。
'''
1.3 __contains__
重载了 Python 的 __contains__
魔术方法,使得 TEMPOConfig
实例可以像字典一样使用 in
操作符来检查属性是否存在。
def __contains__(self, key):
return key in self.todict()
1.4 __getitem__
重载了 __getitem__
魔术方法,使得 TEMPOConfig
实例可以像字典一样通过键来获取属性值
def __getitem__(self, key):
return getattr(self, key)
1.5__setitem__
重载了 __setitem__
魔术方法,使得 TEMPOConfig
实例可以像字典一样通过键来设置属性值
def __setitem__(self, key, value):
setattr(self, key, value)
1.6 update
通过一个字典 config
更新 TEMPOConfig
实例的属性
def update(self, config: dict):
for k, v in config.items():
setattr(self, k, v)
2 TEMPO
class TEMPO(nn.Module):
"""
Notation:
B: 批次大小
N: 时间序列的数量
E: 嵌入维度
P: 块的数量
PS: patch的大小
L: 输入时间序列的长度
Y: 预测时间序列的长度
"""
models = ("gpt2",)
#支持的模型类型列表
head_types = ("flatten", "pooling")
#支持的输出层类型
params = {
"gpt2": dict(block_size=1024, n_head=12, n_embd=768),
}
'''
模型的参数,例如 "gpt2" 模型的块大小、注意力头数和嵌入维度等
'''