Transformer - Positional Encoding 位置编码 代码实现

news2024/12/24 7:38:13

Transformer - Positional Encoding 位置编码 代码实现

flyfish

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

       
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x +  self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

# 词嵌⼊维度是64维
d_model = 64
# 置0⽐率为0.1
dropout = 0.1
# 句⼦最⼤⻓度
max_len=60

x = torch.zeros(1, max_len, d_model)
pe = PositionalEncoding(d_model, dropout, max_len)
                           
pe_result = pe(x)

print("pe_result:", pe_result)

绘图

import numpy as np
import matplotlib.pyplot as plt
# 创建⼀张15 x 5⼤⼩的画布
plt.figure(figsize=(15, 5))

pe = PositionalEncoding(d_model, 0, max_len)

y = pe(torch.zeros(1, max_len, d_model))


# 只查看3,4,5,6维的值.
plt.plot(np.arange(max_len), y[0, :, 3:7].data.numpy())

plt.legend(["dim %d"%p for p in [3,4,5,6]])

在这里插入图片描述

register_buffer 的测试

# -*- coding: utf-8 -*-
"""
@author: flyfish
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class MLPNet (nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1 * 28 * 28, 128)
        self.fc2 =nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)
        self.dropout1=nn.Dropout2d(0.2)
        self.dropout2=nn.Dropout2d(0.2)
    
        self.tmp = torch.randn(size=(1, 3))
        pe = torch.randn(size=(1, 3))
       
        
        self.register_buffer('pe', pe)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        return F.relu(self.fc3(x))
net = MLPNet()
print(net.tmp)
print(net.pe)

print(torch.__version__)

root="mydir/"

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

train_set = datasets.MNIST(root=root, train=True, transform=trans, download=True)
test_set = datasets.MNIST(root=root, train=False, transform=trans, download=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

train_loader = DataLoader(train_set, batch_size=100, shuffle=True)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False)


criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

epochs = 1
for epoch in range(epochs):
    train_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0

    net.train()
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.view(-1, 28*28*1).to(device), labels.to(device)
        
        optimizer.zero_grad()
 
        out = net(images)
      
        loss = criterion(out, labels)
       
        train_loss += loss.item()
        train_acc += (out.max(1)[1] == labels).sum().item()
      
        loss.backward()
    
        optimizer.step()
    
        avg_train_loss = train_loss / len(train_loader.dataset)
        avg_train_acc = train_acc / len(train_loader.dataset)

    net.eval()
    with torch.no_grad():
        for (images, labels) in test_loader:
            images, labels = images.view(-1, 28*28*1).to(device), labels.to(device)
            out = net(images)
            loss = criterion(out, labels)
            val_loss += loss.item()
            acc = (out.max(1)[1] == labels).sum()
            val_acc += acc.item()
    avg_val_loss = val_loss / len(test_loader.dataset)
    avg_val_acc = val_acc / len(test_loader.dataset)
    print ('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}'
                   .format(epoch+1, epochs, loss=avg_train_loss, val_loss=avg_val_loss, val_acc=avg_val_acc))
    
    
    


dir_name = 'output'
if not os.path.exists(dir_name):
    os.mkdir(dir_name)
model_save_path = os.path.join(dir_name, "model.pt")
torch.save(net.state_dict(), model_save_path)

model = MLPNet()
model.load_state_dict(torch.load(model_save_path))


print(model.tmp)
print(model.pe)
# -*- coding: utf-8 -*-
"""
@author: flyfish
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class MLPNet (nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1 * 28 * 28, 128)
        self.fc2 =nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)
        self.dropout1=nn.Dropout2d(0.2)
        self.dropout2=nn.Dropout2d(0.2)
    
        self.tmp = torch.randn(size=(1, 3))
        pe = torch.randn(size=(1, 3))
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        return F.relu(self.fc3(x))
net = MLPNet()
print(net.tmp)
print(net.pe)


    

dir_name = 'output'
if not os.path.exists(dir_name):
    os.mkdir(dir_name)


model_save_path = os.path.join(dir_name, "model.pt")



model = MLPNet()
model.load_state_dict(torch.load(model_save_path))


print(model.tmp)
print(model.pe)

从模型加载的pe值,从未改变

tensor([[0.0566, 0.8944, 0.0873]])
tensor([[ 0.2529,  0.5227, -0.2610]])
tensor([[ 0.4632, -0.2602, -1.0032]])
tensor([[-0.3486,  0.8183, -1.3838]])
tensor([[ 0.7163,  0.5574, -0.0848]])
tensor([[-0.3415, -0.9013, -1.6136]])
tensor([[ 0.5490,  1.7691, -1.1375]])
tensor([[-0.3486,  0.8183, -1.3838]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1564267.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

加速科技高性能数模混合信号测试设备ST2500EX精彩亮相SEMICON China 2024

芯片是现代信息技术发展的重要支柱,半导体设备则是芯片产业发展的重要基石。近年来,半导体设备领域开启了国产自研的黄金浪潮,其中,测试机作为芯片测试中至关重要的核心设备之一,国产自研率较低,一直是国内…

基于深度学习的商品标签识别系统(网页版+YOLOv8/v7/v6/v5代码+训练数据集)

摘要:本文深入研究了基于YOLOv8/v7/v6/v5的商品标签识别,核心采用YOLOv8并整合了YOLOv7、YOLOv6、YOLOv5算法,进行性能指标对比;详述了国内外研究现状、数据集处理、算法原理、模型构建与训练代码,及基于Streamlit的交…

安装部署 ESXI 5.5版本

1.什么是虚拟化 虚拟化就是把硬件资源从物理方式转变为逻辑方式,打破原有物理结构,使用户可以灵活管理这些资源,并且允许1台物理机上同时运行多个操作系统,以实现资源利用率最大化和灵活管理的一项技术。 2.虚拟化的优势 (1)减少服…

OpenAI 宣布, ChatGPT 网页端无需注册就能立即使用(2024年4月1日)

今天,OpenAI宣布,为了让更多人轻松体验人工智能的强大功能,现在无需注册账户即可立即使用 ChatGPT。这一变化是他们使命的核心部分,即让像 ChatGPT 这样的工具广泛可用,让世界各地的人们都能享受到 AI 带来的好处。 网…

车载以太网AVB交换机 gPTP透明时钟 6口 DB9接口 千兆车载以太网交换机

SW1100千兆车载以太网交换机 一、设备简要分析 8端口千兆和百兆混合车载以太网交换机,其中包含2个通道的1000BASE-T1接口,5通道100BASE-T1接口和1个通道1000BASE-T标准以太网(RJ45接口),可以实现车载以太网多通道交换,千兆和百兆…

人工智能+的广泛应用,已渗透到生活的方方面面

引言 随着科技的不断进步和人工智能技术的快速发展,我们正处于一个人工智能时代。人工智能不仅仅是一种技术,更是一种革命性的变革力量,它正在以前所未有的方式改变着我们的生活和工作方式。 人工智能(AI)指的是人工…

57 npm run build 和 npm run serve 的差异

前言 npm run serve 和 npm run build 的差异 这里主要是从 vue-cli 的流程 来看一下 我们经常用到的这两个命令, 他到传递给 webpack 打包的时候, 的一个具体的差异, 大致是配置了那些东西? 经过了那些流程 ? vue-cli 的 vue-plugin 的加载 内置的 plugin 列表如下, 依次…

RFID:锂电池自动化产线的智能监护者

RFID:锂电池自动化产线的智能监护者 一个拥有尖端工业科技的黑灯工厂里,自动化技术已经代替大部分的人工,在每天的自动化生产中会有大量的产品问世。但是人员少,自动化多的工厂怎么做生产管理,产品溯源呢?…

FebHost:人工智能时代的新宠儿.AI域名

近年来,人工智能技术在各行各业迅猛发展,正在深刻改变着我们的生活。作为AI领域的专属域名,.AI域名正成为越来越多企业和个人的首选。 那么,.AI域名到底是什么呢?它是一种特殊的顶级域名(Top-Level Domain, TLD),于2013年由 安哥拉政府正式退出。与其他通用顶级域名如.com、.…

【Angular】什么是Angular中的APP_BASE_HREF

1 概述: 在这篇文章中,我们将看到Angular 10中的APP_BASE_HREF是什么以及如何使用它。 APP_BASE_HREF为当前页面的基础href返回一个预定义的DI标记。 APP_BASE_HREF是应该被保留的URL前缀。 2 语法: provide: APP_BASE_HREF, useValue: /gfgapp3 步骤: 在app.m…

dataloader numworkers

numworkers是加载数据的额外cpu数量(也可以看成额外的进程)。可以理解是: dataset中的getitem只能得到单个数据, 而numworker设置后是同时加载numwork个数据到RAM中,当需要数据时,不会重新执行getiem的方法…

鸿蒙OS元服务开发:【(Stage模型)设置应用主窗口】

一、设置应用主窗口说明 在Stage模型下,应用主窗口由UIAbility创建并维护生命周期。在UIAbility的onWindowStageCreate回调中,通过WindowStage获取应用主窗口,即可对其进行属性设置等操作。还可以在应用配置文件中设置应用主窗口的属性&…

使用docker-tc对host容器进行限流

docker-tc是一个github开源项目,项目地址是https://github.com/lukaszlach/docker-tc。 运行docker-tc docker run -d \ --name docker-tc \ --network host \ --cap-add NET_ADMIN \ --restart always \ -v /var/run/docker.sock:/var/run/docker.sock \ -v /var…

上位机图像处理和嵌入式模块部署(qmacviusal边缘宽度测量)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 前面有一篇文章,我们了解了测量标定是怎么做的。即,我们需要提前知道测量的方向,灰度的方向,实际的…

【Error】log依赖冲突

启动项目报错: 原因: web模块存在两个log依赖,存在冲突 解决方案: 使用依赖分析插件删除多出的依赖:

蓝桥杯 - 走迷宫

解题思路: 经典dfs题目,需要重点掌握。 养成好习惯,静态方法都要用到的变量提前想到定义为静态常量。 import java.util.Scanner;public class Main {//注意加static,经常忘记导致编译错误static int N, M, x1, x2, y1, y2, mi…

总结jvm中GC机制(垃圾回收)

前言 本篇博客博主将介绍jvm中的GC机制,坐好板凳发车啦~~ 一.GC相关 1.1回收栈内存 对于虚拟机栈,本地方法栈这部分区域而言,其生命周期与相关线程相关,随线程而生,随线程而灭。并且这三个区域的内存分配与回收具有…

房间预定小程序怎么做_打造用户的专属空间预定小程序

在这个快节奏的时代,人们对于便捷、高效的生活方式有着越来越高的追求。无论是出差、旅行还是日常生活,一个好的住宿环境都是必不可少的。然而,传统的房间预定方式往往让人头疼不已,电话沟通、排队等待、繁琐的手续……这些问题不…

Flutter开发之图片选择器

使用FLutter开发了一个图片选择的组件,功能如下: 1、支持设置最大可选图片的个数; 2、根据选择的图片个数自适应容器组件的高度; 3、可设置容器的最大高度; 4、支持点击放大和删除功能; 具体效果如下 …

Java解析实体类的属性和属性注释

前言 获取某个类的属性(字段)是我们经常都会碰到的,通常我们是通过反射来获取的。 但是有些特殊情况下,我们不仅要获取类的属性,还需要获取属性注释。这种情况下,我们只能通过注解去获取注释。可以自己定…