Chronos:学习时间序列的大语言模型(代码解析)

news2024/11/27 11:44:07

前言

  • 《Chronos: Learning the Language of Time Series》原文地址,Github开源代码地址
  • Chronos:学习时间序列的大语言模型(论文解读)CSDN地址
  • GitHub项目地址Some-Paper-CN。本项目是译者在学习长时间序列预测、CV、NLP和机器学习过程中精读的一些论文,并对其进行了中文翻译。还有部分最佳示例教程
  • 如果有帮助到大家,请帮忙点亮Star,也是对译者莫大的鼓励,谢谢啦~
  • 本文代码已同步至项目Some-Paper-CN,后续可能会根据热度发布使用LoRA微调Chronos模型教程,浅浅期待一下吧~

先验知识

  • 建议先阅读Chronos论文解读篇,对大致原理有所了解,阅读代码效果会更好。
  • 在论文解读篇中,我们已经知道了Chronos是基于Google的开源模型T5(Huggingface)。因受篇幅影响,有关T5模型的解析不在本次讨论范围内,感兴趣的小伙伴可以去查询相关资料。
  • 论文基于Transformers框架,在阅读代码前,最好有一定Transformers库的基础知识。
  • 虽然本文模型为时间序列模型,但不管是在模型架构、训练方式还是数据组织上都与大语言模型几乎一致,在阅读代码前,最好有一定大语言模型领域的知识,比如术语tonkentokenizer

代码解析

  • 将开源代码从Github上下载到本地,关键文件在chronos-forecasting/src/chronos下,chronos.py文件。
  • ChronosConfig用于加载模型参数(注意!是参数不是权重),类ChronosTokenizer用于加载模型Tokenizer,类ChronosModel用于根据模型参数构建模型。上述类为Transformers库基础类,这里不多赘述。
  • 论文中的核心在类MeanScaleUniformBins用于数据均值缩放和量化分箱,类ChronosPipeline用于构架数据预测管道。

MeanScaleUniformBins

class MeanScaleUniformBins(ChronosTokenizer):
    def __init__(
        self, low_limit: float, high_limit: float, config: ChronosConfig
    ) -> None:
        self.config = config
        # 线性平分向量torch.linspace(start, end, steps)
        self.centers = torch.linspace(
            low_limit,
            high_limit,
            config.n_tokens - config.n_special_tokens - 1,
        )
        # 首尾元素分别为-1e20、1e20
        # self.centers[1:]除第1个元素外的所有元素
        # self.centers[:-1]除最后1个元素外的所有元素
        # (self.centers[1:] + self.centers[:-1]) / 2表示相邻元素平均值
        self.boundaries = torch.concat(
            (
                torch.tensor([-1e20], device=self.centers.device),
                (self.centers[1:] + self.centers[:-1]) / 2,
                torch.tensor([1e20], device=self.centers.device),
            )
        )

    def input_transform(
        self, context: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size, length = context.shape

        if length > self.config.context_length:
            # 保留最后context_length个元素
            context = context[..., -self.config.context_length :]

        # 空值的反向布尔值
        attention_mask = ~torch.isnan(context)
        # context绝对值和attention_mask的点积,除以attention_mask的和
        scale = torch.nansum(
            torch.abs(context) * attention_mask, dim=-1
        ) / torch.nansum(attention_mask, dim=-1)
        # scale是0或空值设为1.0
        scale[~(scale > 0)] = 1.0
        # 将context按scale缩放
        scaled_context = context / scale.unsqueeze(dim=-1)
        # torch.bucketize根据边界值将输入映射到相应bucket(桶)中
        token_ids = (
            torch.bucketize(
                input=scaled_context,
                boundaries=self.boundaries,
                right=True,
            )
            + self.config.n_special_tokens
        )
        # 不需要关注的地方使用padding
        token_ids[~attention_mask] = self.config.pad_token_id

        # 如果需要在末尾添加eos符
        if self.config.use_eos_token:
            eos_tokens = torch.full(
                (batch_size, 1), fill_value=self.config.eos_token_id
            )
            token_ids = torch.concat((token_ids, eos_tokens), dim=1)
            # mask置为true
            eos_mask = torch.full((batch_size, 1), fill_value=True)
            attention_mask = torch.concat((attention_mask, eos_mask), dim=1)

        return token_ids, attention_mask, scale

    def output_transform(
        self, samples: torch.Tensor, scale: torch.Tensor
    ) -> torch.Tensor:
        # 将scale扩展两个维度
        scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1)
        # 将值限制在0和centers长度间,确保索引值不超出centers
        indices = torch.clamp(
            samples - self.config.n_special_tokens,
            min=0,
            max=len(self.centers) - 1,
        )
        # 返回在原始context缩放级别下分桶值
        return self.centers[indices] * scale_unsqueezed
  • low_limithigh_limit包含在模型参数中,根据论文分别为-1515

  • input_transform函数中scale = torch.nansum(torch.abs(context) * attention_mask, dim=-1) / torch.nansum(attention_mask, dim=-1)看上去非常复杂,实际上在没有空值的情况下,相当于对序列求平均值。

  • input_transform函数中分箱函数torch.bucketize的使用可以参考官方文档。

  • input_transform函数中空值使用padding填充,并使用mask进行遮掩,是大语言模型训练的常用操作。

  • 在论文中,作者表示为了保持与大语言模型训练方式保持一致,会在序列结束后放置eos标识符,所以模型参数use_eos_token是为True的。

  • output_transform函数是input_transform函数的反操作,需要注意的是torch.clamp函数,确保token_id在词表中,否则就无法反归一化得到正常的值了。

ChronosPipeline

  • from_pretrained函数用于加载模型预训练权重,这里不在过多赘述,关键在于predict函数。
    def predict(
        self,
        context: Union[torch.Tensor, List[torch.Tensor]],
        prediction_length: Optional[int] = None,
        num_samples: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        limit_prediction_length: bool = True,
    ) -> torch.Tensor:
        """
        Get forecasts for the given time series.

        Parameters
        ----------
        context
            Input series. This is either a 1D tensor, or a list
            of 1D tensors, or a 2D tensor whose first dimension
            is batch. In the latter case, use left-padding with
            ``torch.nan`` to align series of different lengths.
        prediction_length
            Time steps to predict. Defaults to what specified
            in ``self.model.config``.
        num_samples
            Number of sample paths to predict. Defaults to what
            specified in ``self.model.config``.
        temperature
            Temperature to use for generating sample tokens.
            Defaults to what specified in ``self.model.config``.
        top_k
            Top-k parameter to use for generating sample tokens.
            Defaults to what specified in ``self.model.config``.
        top_p
            Top-p parameter to use for generating sample tokens.
            Defaults to what specified in ``self.model.config``.
        limit_prediction_length
            Force prediction length smaller or equal than the
            built-in prediction length from the model. True by
            default. When true, fail loudly if longer predictions
            are requested, otherwise longer predictions are allowed.

        Returns
        -------
        samples
            Tensor of sample forecasts, of shape
            (batch_size, num_samples, prediction_length).
        """
        context_tensor = self._prepare_and_validate_context(context=context)

        if prediction_length is None:
            prediction_length = self.model.config.prediction_length

        if prediction_length > self.model.config.prediction_length:
            msg = (
                f"We recommend keeping prediction length <= {self.model.config.prediction_length}. "
                "The quality of longer predictions may degrade since the model is not optimized for it. "
            )
            if limit_prediction_length:
                msg += "You can turn off this check by setting `limit_prediction_length=False`."
                raise ValueError(msg)
            warnings.warn(msg)

        predictions = []
        remaining = prediction_length

        while remaining > 0:
            # 根据MeanScaleUniformBins类对数据进行缩放和分箱
            token_ids, attention_mask, scale = self.tokenizer.input_transform(
                context_tensor
            )
            # 输入模型得到结果
            samples = self.model(
                token_ids.to(self.model.device),
                attention_mask.to(self.model.device),
                min(remaining, self.model.config.prediction_length),
                num_samples,
                temperature,
                top_k,
                top_p,
            )
            prediction = self.tokenizer.output_transform(
                samples.to(scale.device), scale
            )

            predictions.append(prediction)
            remaining -= prediction.shape[-1]
			
            # 判断是否预测完
            if remaining <= 0:
                break
			# 拼接操作
            context_tensor = torch.cat(
                [context_tensor, prediction.median(dim=1).values], dim=-1
            )

        return torch.cat(predictions, dim=-1)
  • 作者建议将prediction length保持在64以下,因为模型没有针对较长的预测长度进行优化,因此预测质量可能会下降。
  • 预测过程为:根据MeanScaleUniformBins类中input_transform函数对数据进行缩放和分箱,得到token_id、掩码矩阵attention_mask, 均值scale;将token_id和掩码矩阵attention_mask输入模型,得到输出samples。根据MeanScaleUniformBins类中output_transform函数和均值scale将输出samples反归一化得到实际值。
  • remaining变量用于检验prediction length是否全部预测完。

left_pad_and_stack_1D

  • 上述代码中函数predict调用了_prepare_and_validate_context函数,本质是left_pad_and_stack_1D函数。
def left_pad_and_stack_1D(tensors: List[torch.Tensor]):
    # tensors中最长元素的长度
    max_len = max(len(c) for c in tensors)
    padded = []
    # 遍历tensors中元素
    for c in tensors:
        assert isinstance(c, torch.Tensor)
        # c为一维张量
        assert c.ndim == 1
        # 填充torch.nan
        padding = torch.full(
            size=(max_len - len(c),), fill_value=torch.nan, device=c.device
        )
        # 拼接(c长度被扩展为max_len),并添加到列表padded中
        padded.append(torch.concat((padding, c), dim=-1))
    # 将padded列表中的所有元素沿着新维度折叠,形成二维张量
    return torch.stack(padded)
  • 该函数是大语言模型训练过程中为了补齐长度做的操作,如果不理解也没事,只要明白在干什么就行。

测试Demo

  • 如果想要进一步了解代码,还是希望大家用一个轻量的测试Demo从头到尾Debug一下。
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-tiny",
    device_map="cpu",
    torch_dtype=torch.float16,
)

df = pd.read_csv("AirPassengers.csv")

# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
context = torch.tensor(df["#Passengers"])
prediction_length = 12
forecast = pipeline.predict(
    context,
    prediction_length,
    num_samples=20,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
) # forecast shape: [num_series, num_samples, prediction_length]

# visualize the forecast
forecast_index = range(len(df), len(df) + prediction_length)
low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
plt.plot(forecast_index, median, color="tomato", label="median forecast")
plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")
plt.legend()
plt.grid()
plt.show()
  • 预测结果效果图

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1659586.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue多选功能

废话不多说&#xff0c;直接上代码&#xff01;&#xff01;&#xff01; <template><div class"duo-xuan-page"><liv-for"(item, index) in list":key"index"click"toggleSelection(item)":class"{ active: sel…

加密“发射台”:未来通信的新模式

随着区块链技术的飞速发展&#xff0c;加密“发射台”作为一种新兴的安全通信工具&#xff0c;正逐渐受到关注。本文将从专业角度深入探讨加密“发射台”的概念、原理、应用场景及其未来发展趋势&#xff0c;以期为读者提供有深度和逻辑性的思考。 一、加密“发射台”的概念与…

Chromium 调试指南2024 Windows11篇-准备调试工具(二)

1. 前言 在上篇文章《Chromium 调试指南 2024 &#xff08;一&#xff09;》中&#xff0c;我们简单的概述了下调试的作用。 工欲善其事必先利其器&#xff0c;现在我们开始准备你的调试工具。 2. 准备你的调试工具 这里根据Chromium官方开发指南的推荐使用Visual Studio 2…

360极速浏览器X全新Chromium内核极致顺滑,绿色便携版 v22.3.1002.64

01 软件介绍 360极速浏览器X是一款基于Chromium 95的高级双核浏览器&#xff0c;支持IE内核&#xff0c;并优化了用户体验与性能。包括无广告弹窗&#xff0c;新增的阅读模式&#xff0c;个性化标签页壁纸&#xff0c;以及专业导航功能&#xff0c;旨在提供更快、更高效的浏览…

防火墙技术基础篇:解析防火墙应用层代理概念及功能

防火墙技术基础篇&#xff1a;解析防火墙应用层代理概念及功能 1 应用层代理的概念 应用层代理&#xff08;Application Proxy&#xff09;&#xff1a;防火墙应用层代理是网络安全领域中的一种重要技术&#xff0c;工作在OSI模型的第七层&#xff0c;即应用层。它通过代理服…

初步理解ECC

理解一下公钥加密的概念 公钥加密需要一对钥匙&#xff0c;公钥和私钥&#xff0c;公钥可以公开&#xff0c;而私钥不能泄露。 那就可以用公钥给明文加密 而只有私钥才能进行解密 而要想实现这个过程&#xff0c;就需要满足两个点 1&#xff1a;加密简单易行 2&#xff1a;解…

经常睡不好觉?试试用上华为手环9新升级的睡眠监测功能

睡眠问题是不是经常困扰着你呢&#xff1f;听说&#xff0c;华为手环9的睡眠监测功能升级了&#xff0c;无论是入睡前、睡眠中还是睡醒后&#xff0c;都能够帮助我们改善睡眠&#xff0c;让我们告别糟糕的睡眠质量&#xff01; 睡觉前&#xff0c;打开华为手环9的睡眠模式&…

Linux 操作系统TCP、UDP

1、TCP服务器编写流程 头文件&#xff1a; #include <sys/socket.h> 1.1 创建套接字 函数原型&#xff1a; int socket(int domain, int type, int protocol); 参数&#xff1a; domain: 网域 AF_INET &#xff1a; IPv4 AF_INET6 &a…

RockChip Android13 添加/删除ListPreference方法

概述: 本章将讲述在Android添加或删除ListPreference的几种方法,并以EthernetSettingsActivity为例,添加/删除一项ListPreference: 默认效果图: 添加后效果图: 方法一: 1、全部添加xml 在Activity类中使用addPreferencesFromResource()方法解析XML文件并添加Prefere…

OpenHarmony 实战开发——南向统一编译的docker镜像来了

由于我自己的南向设备开发平台的需求&#xff0c;我将当前几个不同的 docker 镜像版本进行了整合&#xff0c;经过一段时间的攻关和验证&#xff0c;目前整合已完成&#xff0c;新版本的 Dockerfile 如下&#xff0c;这个不是公共需求&#xff0c;所以没有提交主干&#xff0c;…

Linux学习笔记4---点亮LED灯(汇编裸机)

本系统学习利用的是正点原子的阿尔法mini开发板&#xff0c;本系列的学习笔记也是按照正点原子的教程进行学习&#xff0c;但并不是利用虚拟机进行开发&#xff0c;而是使用Windows下的子系统WSL进行学习。 因为 Cortex-A 芯片一上电 SP 指针还没初始化&#xff0c;C 环境还没准…

网络安全之弱口令与命令爆破(下篇)(技术进阶)

目录 一&#xff0c;什么是弱口令&#xff1f; 二&#xff0c;为什么会产生弱口令呢&#xff1f; 三&#xff0c;字典的生成 四&#xff0c;九头蛇&#xff08;hydra&#xff09;弱口令爆破工具 1&#xff0c;破解ssh登录密码 2&#xff0c;破解windows登录密码 3&#xf…

新增分类——后端

实现功能&#xff1a; 代码开发逻辑&#xff1a; 页面发送ajax请求&#xff0c;将新增分类窗口输入的数据以json形式提交到服务端服务端Controller接收页面提交的数据并调用Service将数据进行保存Service调用Mapper操作数据库&#xff0c;保存数据 代码实现&#xff1a; Con…

测试项目实战——安享理财1(测试用例)

说明&#xff1a; 1.访问地址&#xff1a; 本项目实战使用的是传智播客的安享理财项目&#xff08;找了半天这个项目能免费用且能够满足测试实战需求&#xff09; 前台&#xff1a;http://121.43.169.97:8081/ 后台&#xff1a;http://121.43.169.97:8082/ &#xff08;点赞收藏…

电脑ip地址设置成什么比较好

随着信息技术的快速发展&#xff0c;IP地址已成为电脑在网络世界中的“身份证”。它不仅是电脑在网络中进行通信的基础&#xff0c;也直接关系到网络连接的稳定性、安全性和效率。然而&#xff0c;面对众多IP地址设置选项&#xff0c;许多用户可能会感到困惑。那么&#xff0c;…

C#进阶-OleDb操作Excel和数据库

在C#编程中&#xff0c;使用OleDb可以方便地实现对Excel文件和数据库的操作。本文探讨了在C#中使用OleDb技术操作Excel和数据库的策略。文章详述了OleDb的定义、配置环境的步骤&#xff0c;并通过实际代码示例演示了如何高效读写Excel文件和交互数据库。文中还评估了OleDb技术的…

5到15秒片头音乐200款,30秒片头音效音乐大全

一、素材描述 本套音乐音效素材&#xff0c;大小2.88G&#xff0c;13个压缩文件。 二、素材目录 200个5到15秒的片头音乐.zip 30秒片头-1.zip 30秒片头-2.zip 30秒片头-3.zip 30秒片头-4.zip 30秒片头-5.zip 30秒片头-6.zip 30秒片头-7.zip 30秒片头-8.zip 30秒片头…

武汉星起航电子商务有限公司助力打造一站式跨境电商服务

在全球电商市场蓬勃发展的当下&#xff0c;武汉星起航电子商务有限公司以其独特的经营模式和全方位的服务体系&#xff0c;成为亚马逊自营领域的佼佼者。自2017年专注于亚马逊自营以来&#xff0c;武汉星起航凭借丰富的经验和深入的市场洞察&#xff0c;为合作伙伴提供了一站式…

生信软件17 - 基于fasta文件的捕获探针设计工具catch

catch是broad研究所开发的一款用于设计捕获探针的python软件。 1. 软件安装 适用于Linux / windows等&#xff0c;安装要求Python≥3.8 | NumPy≥1.22 | SciPy≥1.8.0 # github安装 git clone https://github.com/broadinstitute/catch.git cd catch pip install -e .# coon…

react项目中封装一个通用的边界Boundary

# Boundary 通用的边界,同时是一个Suspense 和一个 ErrorBoundary 正常情况不直接用,使用一下几个封装好的: -Boundary.FullSizeLoading: 占满父容器全部高度,居中显示等待动画; -Boundary.Loading: 占满一行,显示一个普通尺寸的等待动画; -Boundary.Blank: 什么都不显示…