lag-llama源码解读(Lag-Llama: Towards Foundation Models for Time Series Forecasting)

news2025/4/17 19:35:16

Lag-Llama: Towards Foundation Models for Time Series Forecasting
文章内容:
时间序列预测任务,单变量预测单变量,基于Llama大模型,在zero-shot场景下模型表现优异。创新点,引入滞后特征作为协变量来进行预测。

获得不同频率的lag,来自glunoTS库里面的源码

def _make_lags(middle: int, delta: int) -> np.ndarray:
    """
    Create a set of lags around a middle point including +/- delta.
    """
    return np.arange(middle - delta, middle + delta + 1).tolist()

def get_lags_for_frequency(
    freq_str: str,
    lag_ub: int = 1200,
    num_lags: Optional[int] = None,
    num_default_lags: int = 7,
) -> List[int]:
    """
    Generates a list of lags that that are appropriate for the given frequency
    string.

    By default all frequencies have the following lags: [1, 2, 3, 4, 5, 6, 7].
    Remaining lags correspond to the same `season` (+/- `delta`) in previous
    `k` cycles. Here `delta` and `k` are chosen according to the existing code.

    Parameters
    ----------

    freq_str
        Frequency string of the form [multiple][granularity] such as "12H",
        "5min", "1D" etc.

    lag_ub
        The maximum value for a lag.

    num_lags
        Maximum number of lags; by default all generated lags are returned.

    num_default_lags
        The number of default lags; by default it is 7.
    """

    # Lags are target values at the same `season` (+/- delta) but in the
    # previous cycle.
    def _make_lags_for_second(multiple, num_cycles=3):
        # We use previous ``num_cycles`` hours to generate lags
        return [
            _make_lags(k * 60 // multiple, 2) for k in range(1, num_cycles + 1)
        ]

    def _make_lags_for_minute(multiple, num_cycles=3):
        # We use previous ``num_cycles`` hours to generate lags
        return [
            _make_lags(k * 60 // multiple, 2) for k in range(1, num_cycles + 1)
        ]

    def _make_lags_for_hour(multiple, num_cycles=7):
        # We use previous ``num_cycles`` days to generate lags
        return [
            _make_lags(k * 24 // multiple, 1) for k in range(1, num_cycles + 1)
        ]

    def _make_lags_for_day(
        multiple, num_cycles=4, days_in_week=7, days_in_month=30
    ):
        # We use previous ``num_cycles`` weeks to generate lags
        # We use the last month (in addition to 4 weeks) to generate lag.
        return [
            _make_lags(k * days_in_week // multiple, 1)
            for k in range(1, num_cycles + 1)
        ] + [_make_lags(days_in_month // multiple, 1)]

    def _make_lags_for_week(multiple, num_cycles=3):
        # We use previous ``num_cycles`` years to generate lags
        # Additionally, we use previous 4, 8, 12 weeks
        return [
            _make_lags(k * 52 // multiple, 1) for k in range(1, num_cycles + 1)
        ] + [[4 // multiple, 8 // multiple, 12 // multiple]]

    def _make_lags_for_month(multiple, num_cycles=3):
        # We use previous ``num_cycles`` years to generate lags
        return [
            _make_lags(k * 12 // multiple, 1) for k in range(1, num_cycles + 1)
        ]

    # multiple, granularity = get_granularity(freq_str)
    offset = to_offset(freq_str)
    # normalize offset name, so that both `W` and `W-SUN` refer to `W`
    offset_name = norm_freq_str(offset.name)

    if offset_name == "A":
        lags = []
    elif offset_name == "Q":
        assert (
            offset.n == 1
        ), "Only multiple 1 is supported for quarterly. Use x month instead."
        lags = _make_lags_for_month(offset.n * 3.0)
    elif offset_name == "M":
        lags = _make_lags_for_month(offset.n)
    elif offset_name == "W":
        lags = _make_lags_for_week(offset.n)
    elif offset_name == "D":
        lags = _make_lags_for_day(offset.n) + _make_lags_for_week(
            offset.n / 7.0
        )
    elif offset_name == "B":
        lags = _make_lags_for_day(
            offset.n, days_in_week=5, days_in_month=22
        ) + _make_lags_for_week(offset.n / 5.0)
    elif offset_name == "H":
        lags = (
            _make_lags_for_hour(offset.n)
            + _make_lags_for_day(offset.n / 24)
            + _make_lags_for_week(offset.n / (24 * 7))
        )
    # minutes
    elif offset_name == "T":
        lags = (
            _make_lags_for_minute(offset.n)
            + _make_lags_for_hour(offset.n / 60)
            + _make_lags_for_day(offset.n / (60 * 24))
            + _make_lags_for_week(offset.n / (60 * 24 * 7))
        )
    # second
    elif offset_name == "S":
        lags = (
            _make_lags_for_second(offset.n)
            + _make_lags_for_minute(offset.n / 60)
            + _make_lags_for_hour(offset.n / (60 * 60))
        )
    else:
        raise Exception("invalid frequency")

    # flatten lags list and filter
    lags = [
        int(lag) for sub_list in lags for lag in sub_list if 7 < lag <= lag_ub
    ]
    lags = list(range(1, num_default_lags + 1)) + sorted(list(set(lags)))

    return lags[:num_lags]

第一部分,生成以middle为中心,以delta为半径的区间[middle-delta,middle+delta] ,这很好理解,比如一周的周期是7天,周期大小在7天附近波动很正常。
在这里插入图片描述

第二部分,对于年月日时分秒这些不同的采样频率,采用不同的具体的函数来确定lags,其中有一个参数num_cycle,进一步利用了周期性,我们考虑间隔1、2、3、…num个周期的时间点之间的联系
在这里插入图片描述
原理类似于这张图,这种周期性的重复性体现在邻近的多个周期上

在这里插入图片描述

lag的用途

计算各类窗口大小

计算采样窗口大小

window_size = estimator.context_length + max(estimator.lags_seq) + estimator.prediction_length
    # Here we make a window slightly bigger so that instance sampler can sample from each window
    # An alternative is to have exact size and use different instance sampler (e.g. ValidationSplitSampler)
window_size = 10 * window_size
# We change ValidationSplitSampler to add min_past
    estimator.validation_sampler = ValidationSplitSampler(
        min_past=estimator.context_length + max(estimator.lags_seq),
        min_future=estimator.prediction_length,
    )

  1. 构建静态特征
lags = lagged_sequence_values(self.lags_seq, prior_input, input, dim=-1)#构建一个包含给定序列的滞后值的数组

static_feat = torch.cat((loc.abs().log1p(), scale.log()), dim=-1)
expanded_static_feat = unsqueeze_expand(
    static_feat, dim=-2, size=lags.shape[-2]
)

return torch.cat((lags, expanded_static_feat, time_feat), dim=-1), loc, scale

数据集准备过程

对每个数据集采样,window_size=13500,也挺离谱的

 train_data, val_data = [], []
        for name in TRAIN_DATASET_NAMES:
            new_data = create_sliding_window_dataset(name, window_size)
            train_data.append(new_data)

            new_data = create_sliding_window_dataset(name, window_size, is_train=False)
            val_data.append(new_data)

采样的具体过程,这里有个问题,样本数量很小的数据集,实际采样窗口大小小于设定的window_size,后续会如何对齐呢?

文章设置单变量预测单变量,所以样本进行了通道分离,同一样本的不同特征被采样为不同的样本

def create_sliding_window_dataset(name, window_size, is_train=True):
    #划分非重叠的滑动窗口数据集,window_size是对数据集采样的数量,对每个数据集只取前windowsize个样本
    # Splits each time series into non-overlapping sliding windows
    global_id = 0

    freq = get_dataset(name, path=dataset_path).metadata.freq#从数据集中获取时间频率
    data = ListDataset([], freq=freq)#创建空数据集
    dataset = get_dataset(name, path=dataset_path).train if is_train else get_dataset(name, path=dataset_path).test
    #获取原始数据集

    for x in dataset:
        windows = []
        #划分滑动窗口
        #target:滑动窗口的目标值
        #start:滑动窗口的起始位置
        #item_id,唯一标识符
        #feat_static_cat:静态特征数组
        for i in range(0, len(x['target']), window_size):
            windows.append({
                'target': x['target'][i:i+window_size],
                'start': x['start'] + i,
                'item_id': str(global_id),
                'feat_static_cat': np.array([0]),
            })
            global_id += 1
        data += ListDataset(windows, freq=freq)
    return data

合并数据集

# Here weights are proportional to the number of time series (=sliding windows)
        weights = [len(x) for x in train_data]
        # Here weights are proportinal to the number of individual points in all time series
        # weights = [sum([len(x["target"]) for x in d]) for d in train_data]

        train_data = CombinedDataset(train_data, weights=weights)
        val_data = CombinedDataset(val_data, weights=weights)
class CombinedDataset:
    def __init__(self, datasets, seed=None, weights=None):
        self._seed = seed
        self._datasets = datasets
        self._weights = weights
        n_datasets = len(datasets)
        if weights is None:
            #如果未提供权重,默认平均分配权重
            self._weights = [1 / n_datasets] * n_datasets

    def __iter__(self):
        return CombinedDatasetIterator(self._datasets, self._seed, self._weights)

    def __len__(self):
        return sum([len(ds) for ds in self._datasets])

网络结构

lagllama

class LagLlamaModel(nn.Module):
    def __init__(
        self,
        max_context_length: int,
        scaling: str,
        input_size: int,
        n_layer: int,
        n_embd: int,
        n_head: int,
        lags_seq: List[int],
        rope_scaling=None,
        distr_output=StudentTOutput(),
        num_parallel_samples: int = 100,
    ) -> None:
        super().__init__()
        self.lags_seq = lags_seq

        config = LTSMConfig(
            n_layer=n_layer,
            n_embd=n_embd,
            n_head=n_head,
            block_size=max_context_length,
            feature_size=input_size * (len(self.lags_seq)) + 2 * input_size + 6,
            rope_scaling=rope_scaling,
        )
        self.num_parallel_samples = num_parallel_samples

        if scaling == "mean":
            self.scaler = MeanScaler(keepdim=True, dim=1)
        elif scaling == "std":
            self.scaler = StdScaler(keepdim=True, dim=1)
        else:
            self.scaler = NOPScaler(keepdim=True, dim=1)
        self.distr_output = distr_output
        self.param_proj = self.distr_output.get_args_proj(config.n_embd)

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Linear(config.feature_size, config.n_embd),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=RMSNorm(config.n_embd),
            )
        )

主要是transformer里面首先是一个线性层,然后加了n_layer个Block,最后是RMSNorm,接下来解析Block的代码

在这里插入图片描述

Block

class Block(nn.Module):
    def __init__(self, config: LTSMConfig) -> None:
        super().__init__()
        self.rms_1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.rms_2 = RMSNorm(config.n_embd)
        self.mlp = MLP(config)

        self.y_cache = None

    def forward(self, x: torch.Tensor, is_test: bool) -> torch.Tensor:
        if is_test and self.y_cache is not None:
            # Only use the most recent one, rest is in cache
            x = x[:, -1:]

        x = x + self.attn(self.rms_1(x), is_test)
        y = x + self.mlp(self.rms_2(x))

        if is_test:
            if self.y_cache is None:
                self.y_cache = y  # Build cache
            else:
                self.y_cache = torch.cat([self.y_cache, y], dim=1)[
                    :, 1:
                ]  # Update cache
        return y

代码看到这里不太想继续看了,太多glunoTS库里面的函数了,我完全不熟悉这个库,看起来太痛苦了,还有很多的困惑,最大的困惑就是数据是怎么对齐的,怎么输入到Llama里面的,慢慢看吧

其他

来源
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1342670.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【每天一个早下班技巧】NPM发包流程

发包流程 step1&#xff1a;设置包发布地址 参考资料 // 1.在package.json中设置发布地址 "publishConfig":{"registry":"http://registry.npm.xxx.com" }//2.设置别名 alias ynpm"npm --registryhttp://registry.npm.xxx.com" ynp…

Typora使用PicGo+Gitee上传图片

Typora使用PicGoGitee上传图片 1.下载PicGo(国内镜像) https://mirrors.sdu.edu.cn/github-release/Molunerfinn_PicGo/ 点击PicGo-Setup-2.3.0-x64.exe &#xff08;64位安装&#xff09; 然后打开gitee&#xff08;没注册先注册&#xff09; 2.下载node.js插件 https:/…

电脑系统坏了用U盘重装系统教程

我们平时办公、学习都会用到电脑&#xff0c;如果电脑系统坏了&#xff0c;就会影响自己正常使用电脑&#xff0c;这时候就可以通过U盘来重装一个正常的操作系统。如果您不知道具体的重装操作步骤&#xff0c;那么可以参考下面小编分享的利用U盘快速完成操作系统重装的步骤介绍…

4.14 构建onnx结构模型-Min

前言 构建onnx方式通常有两种&#xff1a; 1、通过代码转换成onnx结构&#xff0c;比如pytorch —> onnx 2、通过onnx 自定义结点&#xff0c;图&#xff0c;生成onnx结构 本文主要是简单学习和使用两种不同onnx结构&#xff0c; 下面以 Min 结点进行分析 方式 方法一&…

【网络安全 | XCTF】2017_Dating_in_Singapore

正文 题目描述&#xff1a; 01081522291516170310172431-050607132027262728-0102030209162330-02091623020310090910172423-02010814222930-0605041118252627-0203040310172431-0102030108152229151617-04050604111825181920-0108152229303124171003-261912052028211407-0405…

华为无线AC内三层漫游配置详解

重要说明 1、在一台ac中实现三层漫游 2、ac和核心的互联vlan和ap的管理vlan是同一个广播域&#xff0c;可以不用配option 43 3、直接转发模式&#xff0c;ac上可以不起业务vlan&#xff0c;ac和核心交换机上可以只放行一个互联vlan 10 4、ac上要启两个vap魔板&#xff0c;两个…

Git 浅入浅出

前提 最近和同事分模块联合开发代码&#xff0c;自然而然就要用到 Git 管理代码&#xff1b;借此机会&#xff0c;对 Git 进行简单介绍。 Git 的特征 文件系统 我们都知道 Git 是个版本控制系统&#xff0c;但是如果你深入了解其原理&#xff0c;就不难发现它更像一个文件管…

区间DP详解,思路分析,OJ详解

文章目录 前言问题引入暴力枚举自下而上状态设计状态转移方程 区间DP的分析状态设计状态转移时间复杂度翻译成递推 OJ详解P1880 [NOI1995] 石子合并记忆化搜索版本递推版本 HDU Dire WolfMultiplication PuzzlePolygon 总结 前言 区间dp属于动态规划中一类比较好理解的问题&…

概率论相关题型

文章目录 概率论的基本概念放杯子问题条件概率与重要公式的结合独立的运用 随机变量以及分布离散随机变量的分布函数特点连续随机变量的分布函数在某一点的值为0正态分布标准化随机变量函数的分布 多维随机变量以及分布条件概率max 与 min 函数的相关计算二维随机变量二维随机变…

超级详细的YOLOV8教程

超级详细的YOLOV8教程 YOLOV8介绍1. 数据标记1.1 第一种为在网站上下载&#xff0c;1.2 第二种为在CVAT上自定义数据 2. 制作数据集3. 部署YOLOV8的代码3.1 远程部署3.1.1 项目下载3.1.2 修改代码3.1.2.1 训练模型3.1.2.1 验证模型 3.2 本地部署3.2.1 YOLOV8项目部署3.2.2 cuda…

2023总结与展望--Empirefree

今年一篇博客都没写过了&#xff0c;好像完全在忙在工作和生活上面了&#xff0c;珍惜自我&#xff0c;保持热情&#xff0c;2024对我好点 文章目录 &#x1f525;1. 年终总结1.1.学习工作计划1.2. 生活计划1.3 个人总结 &#x1f525;2. 未来展望 &#x1f525;1. 年终总结 1…

基于Java学生成绩管理系统设计与实现(源码+部署文档+报告)

博主介绍&#xff1a; ✌至今服务客户已经1000、专注于Java技术领域、项目定制、技术答疑、开发工具、毕业项目实战 ✌ &#x1f345; 文末获取源码联系 &#x1f345; &#x1f447;&#x1f3fb; 精彩专栏 推荐订阅 &#x1f447;&#x1f3fb; 不然下次找不到 Java项目精品实…

两阶段提交协议

数据的强一致性 要么都修改 要么就都不修改。 不同的实体和过程 领导者和参与者、表决阶段和提交阶段 过程 一个不同意 提交就终止 存在的问题和解决的方案 如果一个领导者或者参与者的状态机中有阻塞状态&#xff0c;那么系统必须等他完成才能执行&#xff0c;这样就会…

【并发编程篇】线程安全问题_—_ConcurrentHashMap

文章目录 &#x1f354;情景引入&#x1f339;报错了&#xff0c;解决方案 &#x1f354;情景引入 我们运行下面的代码 package org.example.unsafe;import java.util.HashMap; import java.util.Map; import java.util.UUID;public class MapTest {public static void main(…

SpringBoot单点登录认证系统MaxKey(附开源项目地址)

1 项目介绍 MaxKey 单点登录认证系统&#xff0c;谐音马克思的钥匙寓意是最大钥匙&#xff0c;支持 OAuth 2.x/OpenID Connect、SAML 2.0、JWT、CAS、SCIM 等标准协议&#xff0c;提供简单、标准、安全和开放的用户身份管理(IDM)、身份认证(AM)、单点登录(SSO)、RBAC 权限管理…

如何把握品牌新五感,打造小红书品牌

随着社会经济的发展&#xff0c;市场的进步&#xff0c;以及人们思维方式的改变。年轻人面对市场&#xff0c;面对营销&#xff0c;关注点也在发生着改变。那什么是小红书品牌新五感&#xff0c;如何把握品牌新五感&#xff0c;打造小红书品牌&#xff01; 一、品牌五感是什么 …

查找dll的开放函数以及dll的依赖dll

1.进入一个vs的cmd窗口 2.dumpbin /exports XXX.dll&#xff0c;分析 XXX.dll 中有哪些函数。 例如查询: C:\Users\levi0\Desktop\testPro\NewCSDll\NewCSDll\bin\x64\Debug\GBRAnalyze.dll 3. dumpbin /dependents 文件名&#xff08;带路径&#xff09;命令&#xff0c;回车&…

servlet+jdbc实现用户注册功能

一、需求 在Servlet中可以使用JDBC技术访问数据库&#xff0c;常见功能如下&#xff1a; 查询DB数据&#xff0c;然后生成显示页面&#xff0c;例如&#xff1a;列表显示功能。接收请求参数&#xff0c;然后对DB操作&#xff0c;例如&#xff1a;注册、登录、修改密码等功能。…

Linux中磁盘管理与文件系统

目录 一.磁盘基础&#xff1a; 1.磁盘的结构&#xff1a; 2.硬盘的数据结构&#xff1a; 3.硬盘存储容量 &#xff1a; 4.硬盘接口类型&#xff1a; 二.MBR与磁盘分区&#xff1a; 1.MBR的概念&#xff1a; 2.硬盘的分区&#xff1a; 为什么分区&#xff1a; 2.表示&am…

【PHP】B/S手术室麻醉信息管理系统源码

手术麻醉临床信息系统全面覆盖从患者入院&#xff0c;经过术前、术中、术后&#xff0c;直至出院的全过程。通过与相关医疗仪器的设备集成&#xff0c;不但可以轻松集成手术室传统监护设备如监护仪、麻醉机、呼吸机&#xff0c;也能与血气分析仪等设备对接&#xff0c;快速获取…