UniTS代码解读

news2024/9/9 5:35:13

除了时间序列的token是patch了再value embedding,其余的都是普通的扩维embedding,都是nn.parameter可训练的

在UniTS模型中,不同类型的token(如mask token、CLS token、prompt token和时间序列token)的embedding方式如下:

  1. Mask Token的Embedding方式
    • Mask token的embedding是通过一个线性层(nn.Linear)实现的,其权重参数在模型初始化时被随机初始化。
    • 具体代码在Model类的初始化函数中,通过nn.ParameterDict来存储不同数据集的mask token:
      self.mask_tokens = nn.ParameterDict({})
      
      这个应该就是mask token embedding的方式,就是这样
                self.mask_tokens[dataset_name] = torch.zeros(
                    1, configs_list[i][1]['enc_in'], 1, args.d_model)
  • prepare_prompt函数中,mask token被重复并应用到输入序列中:
    this_mask_prompt = task_prompt.repeat(x.shape[0], 1, task_prompt_num, 1)
    
  1. CLS Token的Embedding方式

    • CLS token的embedding方式与mask token类似,也是通过一个线性层实现的。
    • 具体代码在Model类的初始化函数中,通过nn.ParameterDict来存储不同任务的CLS token:
      self.cls_tokens = nn.ParameterDict({})
      
    • classification函数中,CLS token被重复并应用到输入序列中:
      this_function_prompt = task_prompt.repeat(x.shape[0], 1, 1, 1)
      
  2. Prompt Token的Embedding方式

    • Prompt token的embedding方式与CLS token类似,也是通过一个线性层实现的。
    • 具体代码在Model类的初始化函数中,通过nn.ParameterDict来存储不同数据集的prompt token:
      self.prompt_tokens = nn.ParameterDict({})
      
    • prepare_prompt函数中,prompt token被重复并应用到输入序列中:
      this_prompt = prefix_prompt.repeat(x.shape[0], 1, 1, 1)
      
  3. 时间序列Token的Embedding方式

    • 时间序列token的embedding方式是通过PatchEmbedding类实现的,该类将时间序列数据分割成固定长度的patch,并应用线性层进行embedding。
    • 具体代码在Model类的tokenize函数中:
      x = self.patch_embeddings(x)
      
    • PatchEmbedding类中的value_embedding线性层用于将patch转换为embedding向量:
      self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
      

总结来说,UniTS模型中不同类型的token的embedding方式都是通过线性层实现的,具体实现细节在Model类的初始化函数和prepare_prompt函数中有所体现。

在代码中,mask token 主要在 pretraining 方法中被使用。具体来说,mask token 在以下步骤中被用到:

  1. 随机掩码生成

    • choose_masking 方法根据 right_prob 决定使用哪种掩码方法(随机掩码或右掩码)。
    • random_maskingright_masking 方法生成掩码,其中掩码表示哪些位置应该被掩码(即用 mask token 替换)。
  2. 掩码应用

    • pretraining 方法中,掩码被应用到输入数据 x 上。
    • mask_repeat 是一个与 x 形状相同的掩码矩阵,其中掩码位置为1,非掩码位置为0。
    • x 被乘以 mask_repeat,使得掩码位置被 mask_token 替换。
  3. 掩码填充

    • init_full_input 是包含 prompt tokensx 的张量。
    • init_mask_prompt 是通过 prompt2forecat 线性层生成的掩码提示。
    • x 被乘以 mask_repeat,使得掩码位置被 init_mask_prompt 替换。
  4. 位置嵌入

    • x 被加上位置嵌入,以考虑序列的顺序信息。
  5. 前向传播

    • x 被传递到 backbone(即 Transformer 的基本块)进行前向传播。
  6. 预测输出

    • mask_dec_outcls_dec_out 是通过 forecast_headcls_head 生成的预测输出。
    • mask_dec_out 是掩码部分的预测输出。
    • cls_dec_out 是分类部分的预测输出。

以下是相关代码片段:

def pretraining(self, x, x_mark, task_id, enable_mask=False):
    dataset_name = self.configs_list[task_id][1]['dataset']
    task_data_name = self.configs_list[task_id][0]
    prefix_prompt = self.prompt_tokens[dataset_name]
    mask_token = self.mask_tokens[dataset_name]
    cls_token = self.cls_tokens[task_data_name]

    seq_len = x.shape[1]
    x, means, stdev, n_vars, padding = self.tokenize(x)
    seq_token_len = x.shape[-2]

    # append prompt tokens
    x = torch.reshape(
        x, (-1, n_vars, x.shape[-2], x.shape[-1]))
    # prepare prompts
    this_prompt = prefix_prompt.repeat(x.shape[0], 1, 1, 1)

    if enable_mask:
        mask = self.choose_masking(x, self.right_prob,
                                   self.min_mask_ratio, self.max_mask_ratio)
        mask_repeat = mask.unsqueeze(dim=1).unsqueeze(dim=-1)
        mask_repeat = mask_repeat.repeat(1, x.shape[1], 1, x.shape[-1])
        x = x * (1-mask_repeat) + mask_token * mask_repeat  # todo

        init_full_input = torch.cat((this_prompt, x), dim=-2)
        init_mask_prompt = self.prompt2forecat(
            init_full_input.transpose(-1, -2), x.shape[2]).transpose(-1, -2)
        # keep the unmasked tokens and fill the masked ones with init_mask_prompt.
        x = x * (1-mask_repeat) + init_mask_prompt * mask_repeat
        x = x + self.position_embedding(x)
        mask_seq = self.get_mask_seq(mask, seq_len+padding)
        mask_seq = mask_seq[:, :seq_len]
    this_function_prompt = cls_token.repeat(x.shape[0], 1, 1, 1)
    x = torch.cat((this_prompt, x, this_function_prompt), dim=2)

    x = self.backbone(x, prefix_prompt.shape[2], seq_token_len)

    if enable_mask:
        mask_dec_out = self.forecast_head(
            x[:, :, :-1], seq_len+padding, seq_token_len)
        mask_dec_out = mask_dec_out[:, :seq_len]
        # De-Normalization from Non-stationary Transformer
        mask_dec_out = mask_dec_out * \
            (stdev[:, 0, :].unsqueeze(1).repeat(
                1, mask_dec_out.shape[1], 1))
        mask_dec_out = mask_dec_out + \
            (means[:, 0, :].unsqueeze(1).repeat(
                1, mask_dec_out.shape[1], 1))
        cls_dec_out = self.cls_head(x, return_feature=True)
        # detach grad of the forecasting on tokens
        fused_dec_out = torch.cat(
            (cls_dec_out, x[:, :, self.prompt_num:-1].detach()), dim=2)
        cls_dec_out = self.pretrain_head(
            fused_dec_out, seq_len+padding, seq_token_len)
        cls_dec_out = cls_dec_out[:, :seq_len]
        cls_dec_out = cls_dec_out * \
            (stdev[:, 0, :].unsqueeze(1).repeat(
                1, cls_dec_out.shape[1], 1))
        cls_dec_out = cls_dec_out + \
            (means[:, 0, :].unsqueeze(1).repeat(
                1, cls_dec_out.shape[1], 1))

        return cls_dec_out, mask_dec_out, mask_seq
    else:
        return cls_dec_out

总结来说,mask token 在预训练过程中用于生成掩码序列,并在模型训练时帮助模型学习如何预测被掩码的位置。

CLS token被重复并应用到输入序列中,这怎么理解,为什么要这么做

在BERT模型中,CLS(Classification)token是一个特殊的token,它在输入序列的开头添加。CLS token的主要目的是作为整个输入序列的表示,用于分类任务。在训练过程中,CLS token的输出会被用来预测任务相关的标签。

重复CLS token并应用到输入序列中

在训练过程中,BERT模型会接收一个包含多个输入序列的批次。每个输入序列都会添加一个CLS token,并且这个CLS token会被重复应用到整个输入序列中。这样做的原因是:

  1. 统一表示:通过将CLS token重复应用到整个输入序列中,我们可以确保每个输入序列都有一个统一的表示,这个表示可以用于分类任务。

  2. 简化模型结构:在模型的前向传播过程中,我们只需要处理一个输入序列,而不需要为每个输入序列单独处理CLS token。

  3. 提高模型性能:通过使用CLS token,我们可以更好地捕捉整个输入序列的信息,从而提高模型的性能。

##代码示例

以下是一个简单的代码示例,展示了如何在BERT模型中添加CLS token:

import torch
import torch.nn as nn
import torch.nn.functional as F

# 假设我们有一个简单的BERT模型
class SimpleBERT(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers, num_heads, dropout):
        super(SimpleBERT, self).__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden_size, num_heads, hidden_size, dropout),
            num_layers
        )
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # 添加CLS token
        cls_token = torch.tensor([0]).unsqueeze(0).to(x.device)
        x = torch.cat([cls_token, x], dim=0)
        x = self.embedding(x)
        x = self.dropout(x)
        x = self.encoder(x)
        cls_token_output = x[0]  # 获取CLS token的输出
        x = self.fc(x)
        return x, cls_token_output

# 创建模型实例
model = SimpleBERT(vocab_size=10000, hidden_size=512, num_layers=6, num_heads=8, dropout=0.1)

# 假设我们有一个输入序列
input_seq = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])  # 示例输入序列

# 前向传播
outputs, cls_token_output = model(input_seq)

# 计算损失
loss = F.cross_entropy(outputs.view(-1, outputs.size(-1)), input_seq.view(-1))

在这个例子中,我们首先定义了一个简单的BERT模型,然后创建了一个输入序列。在模型的前向传播过程中,我们添加了一个CLS token,并将其重复应用到整个输入序列中。最后,我们获取了CLS token的输出,并计算了交叉熵损失,用于训练模型。

mask token是在时间序列TokenEmbedding后添加时间序列上的还是没有embedding添加的

在代码中,mask token 是在时间序列 TokenEmbedding 后添加的。具体来说,mask token 是在 prepare_prompt 函数中添加的。这个函数根据不同的任务类型(如 forecast, classification, imputation, anomaly_detection)对输入的时间序列进行处理,并在适当的位置添加 mask token

以下是 prepare_prompt 函数的关键部分,展示了 mask token 是如何添加的:

def prepare_prompt(self, x, n_vars, prefix_prompt, task_prompt, task_prompt_num, task_name=None, mask=None):
    x = torch.reshape(
        x, (-1, n_vars, x.shape[-2], x.shape[-1]))
    # append prompt tokens
    this_prompt = prefix_prompt.repeat(x.shape[0], 1, 1, 1)

    if task_name == 'forecast':
        this_mask_prompt = task_prompt.repeat(
            x.shape[0], 1, task_prompt_num, 1)
        init_full_input = torch.cat(
            (this_prompt, x, this_mask_prompt), dim=-2)
        init_mask_prompt = self.prompt2forecat(init_full_input.transpose(
            -1, -2), init_full_input.shape[2]-prefix_prompt.shape[2]).transpose(-1, -2)
        this_function_prompt = init_mask_prompt[:, :, -task_prompt_num:]
        x = torch.cat((this_prompt, x, this_function_prompt), dim=2)
        x[:, :, self.prompt_num:] = x[:, :, self.prompt_num:] + \
            self.position_embedding(x[:, :, self.prompt_num:])
    elif task_name == 'classification':
        this_function_prompt = task_prompt.repeat(x.shape[0], 1, 1, 1)
        x = x + self.position_embedding(x)
        x = torch.cat((this_prompt, x, this_function_prompt), dim=2)
    elif task_name == 'imputation':
        # fill the masked parts with mask tokens
        # for imputation, masked is 0, unmasked is 1, so here to reverse mask
        mask = 1-mask
        mask = mask.permute(0, 2, 1)
        mask = self.mark2token(mask)
        mask_repeat = mask.unsqueeze(dim=-1)

        mask_token = task_prompt
        mask_repeat = mask_repeat.repeat(1, 1, 1, x.shape[-1])
        x = x * (1-mask_repeat) + mask_token * mask_repeat

        init_full_input = torch.cat((this_prompt, x), dim=-2)
        init_mask_prompt = self.prompt2forecat(
            init_full_input.transpose(-1, -2), x.shape[2]).transpose(-1, -2)
        # keep the unmasked tokens and fill the masked ones with init_mask_prompt.
        x = x * (1-mask_repeat) + init_mask_prompt * mask_repeat
        x = x + self.position_embedding(x)
        x = torch.cat((this_prompt, x), dim=2)
    elif task_name == 'anomaly_detection':
        x = x + self.position_embedding(x)
        x = torch.cat((this_prompt, x), dim=2)

    return x

prepare_prompt 函数中,根据不同的任务类型,mask token 被添加到时间序列 TokenEmbedding 的不同位置。例如,在 forecast 任务中,mask token 被添加到时间序列的末尾;在 imputation 任务中,mask token 被添加到被掩码的位置。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1961654.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

这几个高级爬虫软件和插件真的强!

亮数据(Bright Data) 亮数据是一款强大的数据采集工具,以其全球代理IP网络和强大数据采集技术而闻名。它能够轻松采集各种网页数据,包括产品信息、价格、评论和社交媒体数据等。 网站:https://get.brightdata.com/we…

ubuntu安装并配置flameshot截图软件

参考:flameshot key-bindins 安装 sudo apt install flameshot自定义快捷键 Settings->Keyboard->View and Customize Shortcuts->Custom Shortcuts,输入该快捷键名称(自定义),然后输入command(…

RFID物流智能锁在物流锁控领域的意义与应用

在当今全球化和电子商务迅速发展的时代,物流行业作为经济的重要支撑,面临着日益增长的安全、效率和管理需求。物流锁控作为保障货物在运输过程中安全与完整的关键环节,传统的机械锁和简单电子锁已经难以满足现代物流复杂多变的业务场景。 一、…

前缀表达式(波兰式)和后缀表达式(逆波兰式)的计算方式

缀是指操作符。 1. 前缀表达式(波兰式) (1)不需用括号; (2)不用考虑运算符的优先级; (3)操作符置于操作数的前面。(如 3 2 ) 1.1 中…

3.5.3、查找和排序算法-插入类排序和选择类排序

术语说明 稳定:如果a原本在b前面,而ab,排序之后a仍然在b的前面; 不稳定:如果a原本在b的前面,而ab,排序之后a可能会出现在b的后面; 例如:数组{1,2,3,3,4,7,6}。如果排序后,两个3的位…

【嵌入式之RTOS】死锁问题详解

目录 一、什么是死锁 二、产生死锁的四个必要条件 三、避免死锁的方法 四、实际应用中的考虑 一、什么是死锁 死锁(Deadlock)是多任务或多线程环境中一个常见的问题,尤其是在实时操作系统(RTOS)中,如果…

kvm虚拟化平台部署

kvm虚拟化平台部署 kvm概念简介 kvm自linux2.6版本以后就整合到内核中,因此可以看做是一个原生架构. kvm虚拟化架构 硬件底层提供物理层面的硬件支持 linux(host),就相当于这个架构中的宿主机,上面运行了多个虚拟机。…

替换后端国外身份目录服务,宁盾身份域管接管FileNet助力国产化升级

IBM FileNet 是一款优秀的企业内容管理解决方案,为客户提供了领先的文档管理和流程管理集成环境,被大量企业所采用。FileNet 需要使用企业级的目录服务器(LDAP)作为其用户管理系统,满足其认证和授权的需求。对于 LDAP …

最高200万!苏州成都杭州的这些AI政策补贴,你拿到了吗?

随着全球人工智能技术的迅猛发展,地方政府纷纷出台相关政策以抢占未来科技的制高点。苏州 成都 杭州这三个城市更是推出了一系列AI政策补贴,旨在通过多方面支持,推动本地AI产业的发展。本文将带你了解目前不完全统计到的苏州 成都 杭州三地AI…

【Vulnhub系列】Vulnhub_pipe 靶场渗透(原创)

【Vulnhub系列靶场】Vulnhub-pipe 靶场渗透 原文转载已经过授权 原文链接:Lusen的小窝 - 学无止尽,不进则退 (lusensec.github.io) 一、环境配置 1、解决IP扫描不到问题 2、打开虚拟机,并修改网络连接模式为【NAT】即可 二、信息收集 1…

Python实战——轻松实现动态网页爬虫(附详细源码)

大家好&#xff0c;我是东眠的鱼&#xff0c;专注原创&#xff0c;致力于用浅显易懂的语言分享爬虫、数据分析及可视化等干货&#xff0c;希望人人都能学到新知识。<文末附带精品籽料哦&#xff0c;也可以和博主一起学Python呀&#xff01;> 项目背景 有同学自学爬虫时…

前端vue3 巧妙的checkbox 选中框样式

我们 做前端页面交互效果的时候 我们会使用到 checkbox 复选框 做一些交互的效果 我是用的是 nut-ui 组件库中的 checkbox 组件 类似于这样的选中效果 假如 二选一的那种 可以 这样写 交互好看 而不是单纯的 checkbox 框 这里我就不使用 gif 图片了 大家应该都可以看懂的 …

A股继续震荡下行,成交量继续一蹶不振。

A股继续震荡下行&#xff0c;成交量继续一蹶不振。今天的A股&#xff0c;让人揪心不已&#xff0c;你们知道是为什么吗&#xff1f;盘面上出现1个重要信号&#xff0c;一起来看看&#xff1a; 1、今天两市低开低走&#xff0c;向下回补了2867点的缺口&#xff0c;让人揪心不已。…

计算机毕业设计选题推荐-基于司机信用评价的货运管理系统-Java/Python项目实战

✨作者主页&#xff1a;IT毕设梦工厂✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

CUDA_Occupancy_Calculator计算公式

CUDA_Occupancy_Calculator计算公式

6 Java的基本程序设计结构(基本语法5)- 面向对象进阶

文章目录 面向对象进阶一、 static 静态1 静态变量(1)基本定义和用法(2)静态变量内存图2 静态方法(1)基本定义和用法(2)工具类练习:按下面需求写一个工具类3 static注意事项4 重新认识main方法二、继承1 继承的概念2 继承的特点3 继承到底能继承父类中的哪些内容?4 继…

leetcode日记(63)颜色分类

感觉就是排序问题&#xff1f;我使用的是时间复杂度比较高的简单粗暴排序法&#xff0c;时间复杂度O&#xff08;n^2&#xff09;。 class Solution { public:void sortColors(vector<int>& nums) {int nnums.size();for(int i0;i<n;i){for(int ji1;j<n;j){if…

泛微OA BPM 全程数字化业务介绍、管理、财务一体化 数据业务架构图 上帝视角 02

III.泛微业务、管理、财务一体化过程介绍 IV.低代码平台及典型场景搭建过程 V.全程数字化运营平台价值总结 档案管理 档案接收,四性检测,快速可查找 重要:档案管理:架构总图 业务应用都在一个平台,确保档案实现100%归档 自动化档案采集:自动接收各类档案,如文书档案、合…

速通JS模块化规范

目录 1模块化概述 1.1什么是模块化&#xff1f; 1.2为什么需要模块化&#xff1f; 2有哪些模块化规范&#xff1f; 3导入与导出的概念 4CommonJS 规范 4.1初步体验 4.2导出数据 4.3导入数据 4.4扩展理解 4.5浏览器端运行 5ES6 模块化规范 5.1初步体验 5.2Node 中运…

操作系统课程设计:(JAVA)进程管理系统(附源码zip,jdk11,IDEA Ultimate2024 )

一.题目要求描述 本设计的目的是加深对进程概念及进程管理各部分内容的理解&#xff1b;熟悉进程管理中主要数据结构的设计及进程调度算法、进程控制机构、同步机构及通讯机构的实施。要求设计一个允许n个进程并发运行的进程管理模拟系统。 该系统包括有简单的进程控制、同步与…