2023-02-22干活小计

news2025/1/19 11:12:49

复现BERT:

只能说爷今天干了一上午一下午的代码

bert的输入:
batch_size * max_len * emb_num @768 * 768 
bert的输出:三维字符级别特征(NER可能就更适合) 二维篇章级别特征(比如文本分类可能就更适合)
batch_size * max_len * emb_num, batch_size * emb_num

绝对位置编码

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import pandas as pd
import sklearn
import random
import numpy as np

class BertEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.word_embeddings = nn.Embedding(config["vocab_size"], config["hidden_size"])
        self.word_embeddings.weight.requires_grad = True
        self.position_embeddings = nn.Embedding(config["max_len"], config["hidden_size"])
        self.position_embeddings.weight.requires_grad = True
        self.token_type_embeddings = nn.Embedding(config["type_vocab_size"], config["hidden_size"])
        self.token_type_embeddings.weight.requires_grad = True
        self.layernorm = nn.LayerNorm(config["hidden_size"])
        self.dropout = nn.Dropout(config["hidden_dropout_pro"])

    def forward(self, batch_index, batch_seg_idx):
        word_emb = self.word_embeddings(batch_index)
        pos_idx = torch.arange(0, self.position_embeddings.weight.data.shape[0])
        pos_idx = pos_idx.repeat(self.config["batch_size"], 1)
        pos_emb = self.position_embeddings(pos_idx)
        token_emb = self.token_type_embeddings(batch_seg_idx)
        emb = word_emb + pos_emb + token_emb
        layer_norm_emb = self.layernorm(emb)
        dropout_emb = self.dropout(layer_norm_emb)
        return dropout_emb



class BertModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embedding = BertEmbedding(config)
        self.bert_layer = nn.Linear(config["hidden_size"], config["hidden_size"])

    def forward(self, batch_index, batch_seg_idx):
        emb = self.embedding(batch_index, batch_seg_idx)
        bert_out1 = self.bert_layer(emb)
        return bert_out1

class Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.bert = BertModel(config)#batch_size * max_len * emb_num @768 * 768 = batch_size * max_len * emb_num, batch_size * emb_num
        self.cls_mask = nn.Linear(config["hidden_size"], config["vocab_size"])
        self.cls_nsp = nn.Linear(config["hidden_size"], 2)

    def forward(self, batch_index, batch_seg_idx):
        bert_out = self.bert(batch_index, batch_seg_idx)


def get_data(file_path):
    all_data = pd.read_csv(file_path)
    all_data = sklearn.utils.shuffle(all_data)
    t1 = all_data["text1"].tolist()
    t2 = all_data["text2"].tolist()
    l = all_data["label"].tolist()
    return t1, t2, l

class BertDataset(Dataset):
    def __init__(self, text1, text2, label, max_len, word_2_index):
        assert len(text1) == len(text2) == len(label), "NSP数据长度不一,复现个锤子!!!"
        self.text1 = text1
        self.text2 = text2
        self.label = label
        self.max_len = max_len
        self.word_2_index = word_2_index

    def __getitem__(self, index):
        #mask_id = [0] * self.max_len
        mask_v = [0] * self.max_len
        text1 = self.text1[index]
        text2 = self.text2[index]
        label = self.label[index]
        n = int((self.max_len-4) / 2)
        text1_id = [self.word_2_index.get(i, self.word_2_index["[UNK]"]) for i in text1][:n]
        text2_id = [self.word_2_index.get(i, self.word_2_index["[UNK]"]) for i in text2][:n]

        #text = text1 + text2
        text_id = [self.word_2_index["[CLS]"]] + text1_id + [self.word_2_index["[SEP]"]] + text2_id + [self.word_2_index["[SEP]"]]
        segment_id = [0] + [0] * len(text1_id) + [0] + [1] * len(text2_id) + [1] + [2] * (self.max_len - len(text_id))
        text_id = text_id + [self.word_2_index["[PAD]"]] * (self.max_len - len(text_id))
        for i, v in enumerate(text_id):
            if v in [self.word_2_index["[PAD]"], self.word_2_index["[SEP]"], self.word_2_index["[UNK]"]]:
                continue
            if random.random() < 0.15:
                r = random.random()
                if r < 0.8:
                    text_id[i] = self.word_2_index["[MASK]"]
                    mask_v[i] = v

                elif r > 0.9:
                    text_id[i] = random.randint(6, len(self.word_2_index)-1)
                    mask_v[i] = v



        return torch.tensor(text_id), torch.tensor(label), torch.tensor(mask_v), torch.tensor(segment_id)

    def __len__(self):
        return len(self.text1)




if __name__ == "__main__":
    text1, text2, label = get_data("..//self_bert//data//self_task2.csv")
    epoch = 1024
    batch_size = 32
    max_len = 256
    with open("..//self_bert//data//index_2_word.text", "r", encoding="utf-8") as f:
        index_2_word = f.read().split("\n")
        word_2_index = {word: index for index, word in enumerate(index_2_word)}

    config ={
        "epoch": epoch,
        "batch_size": batch_size,
        "max_len": max_len,
        "vocab_size": len(word_2_index),
        "hidden_size": 768,
        "type_vocab_size": 3,
        "hidden_dropout_pro": 0.2,

    }






    train_dataset = BertDataset(text1, text2, label, max_len, word_2_index)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

    model = Model(config)


    for e in range(epoch):
        print(f"here is the {e}th epoch")
        for batch_text_index, batch_text_label, batch_mask_value , batch_segment_id in train_dataloader:
            model.forward(batch_text_index, batch_segment_id)

回家看看花书,也许还会谢谢代码,结束!

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/364642.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

亲身试验 Outlook防关联方法分享

Outlook在海外的用途是很广泛的&#xff0c;不仅可以用于收发邮件&#xff0c;还可以作为各类第三方网站的登录凭证。所以Microsoft对于Outlook的监管还是比较严格的&#xff0c;跨境卖家大量注册Outlook账号使用的话很容易被检测出关联然后被封号。龙哥针对Outlook防关联的问题…

35-Golang中的方法

Golang中的方法方法的介绍和使用方法的声明和调用方法的调用和传参机制原理方法的声明(定义)方法注意事项和细节讨论方法和函数的区别方法的介绍和使用 在某些情况下&#xff0c;我们需要声明(定义)方法。比如person结构体&#xff0c;除了有一些字段外(年龄&#xff0c;姓名……

unix高级编程-僵尸进程和孤儿进程

僵尸进程&#xff1a; 一个父进程利用fork创建子进程&#xff0c;如果子进程退出&#xff0c;而父进程没有利用wait 或者 waitpid 来获取子进程的状态信息&#xff0c;那么子进程的状态描述符依然保存在系统中。 孤儿进程&#xff1a;一个父进程退出&#xff0c; 而它的一个或…

java+Selenium+TestNg搭建自动化测试架构(3)实现POM(page+Object+modal)

1.Page Object是Selenium自动化测试项目开发实践的最佳设计模式之一&#xff0c;通过对界面元素的封装减少冗余代码&#xff0c;同时在后期维护中&#xff0c;若元素定位发生变化&#xff0c;只需要调整页面元素封装的代码&#xff0c;提高测试用例的可维护性。 PageObject设计…

软件测试,刚进入一个公司如何快速上手一个项目?

目录 前言 客观现状 主观能动性 总结感谢每一个认真阅读我文章的人&#xff01;&#xff01;&#xff01; 重点&#xff1a;配套学习资料和视频教学 前言 刚入职一家新公司&#xff0c;做的项目是之前很少接触的行业&#xff0c;该怎么快速的熟悉并上手自己的工作&#xf…

富文本编辑组件封装,tinymce、tinymce-vue

依赖&#xff1a;package.json yarn add tinymce tinymce/tinymce-vue {"dependencies": {"tinymce/tinymce-vue": "5.0.0","tinymce": "6.3.1","vue": "3.2.45",}, } 本地依赖&#xff1a; 在publ…

JIT-即时编译技术

VM&#xff08;HotSpot&#xff09;执行引擎中包含解释器与JIT编译器热点代码&#xff08;执行多次&#xff09;才有JIT编译的必要&#xff08;JIT编译阈值&#xff09;JVM&#xff08;HotSpot&#xff09;会有两个计数器&#xff08;次数/回边&#xff09;判断方法/代码块是否…

缺少IT人员的服装行业该如何进行数字化转型?

服装行业上、下游产业链长&#xff0c;产品属性复杂&#xff0c;是劳动密集型和技术密集型紧密结合的产物&#xff0c;是典型的实体经济代表。 近二十年是服装业发展的机遇和挑战之年&#xff0c;从“世界工厂”“中国制造”&#xff0c;逐渐向“中国设计”转变,中国服装产业经…

Kotlin新手教程九(协程)

一、协程 协程从Kotlin1.3开始引入&#xff0c;本质上协程就是轻量级的线程。协程的基本功能点有&#xff1a; 轻量&#xff1a;可以在单个线程上运行多个协程&#xff0c;因为协程支持挂起&#xff0c;不会使正在运行协程的线程阻塞。挂起比阻塞节省内存&#xff0c;且支持多…

扬帆优配|雷达供应商Arbe暴涨近50%;A股毫米波雷达概念异军突起

今日早盘&#xff0c;A股全体低开高走&#xff0c;上证指数围绕3300点重复抢夺&#xff0c;两市成交呈现大幅萎缩的趋势&#xff0c;显示市场谨慎情绪较为浓厚。 盘面上&#xff0c;白酒、国防军工、新能源、医药等板块涨幅居前&#xff0c;电信运营、网络游戏、稳妥、房地产等…

Sqoop导出hive/hdfs数据到mysql中---大数据之Apache Sqoop工作笔记006

然后我们看看数据利用sqoop,从hdfs hbase中导出到mysql中去 看看命令可以看到上面这个 这里上面还是mysql的部分,然后看看 下面--num-mappers 这个是指定mapper数 然后下面这个export-dir这里是,指定hdfs中导出数据的目录 比如这里指定的是hive的一个表/user/hive/warehouse…

IOS开发中遇到的问题总结【持续更新】

目录 知识点补给站 1. SwiftUI中的Image控件使用系统图标 知识点补给站 【Swift学习】关于 Swift | Swift 编程语言中文教程&#xff08;The Swift Programming Language&#xff09;【SwiftUI学习】不要惊慌! SwiftUI Example【SwiftUI学习】https://goswiftui.com【AppIcon…

C#、JAVA读写PLC物联网Modbus

Modbus协议是一种常用于工业自动化领域的通信协议&#xff0c;它使用简单、易实现、可靠的特点得到了广泛应用。物联网中的设备也需要使用Modbus协议进行通信。本文将介绍物联网Modbus通信的相关内容。一、Modbus协议简介Modbus协议是一种串行通信协议&#xff0c;它最初由Modi…

浅谈ThreadLocal的原理

文章目录1.ThreadLocal初识2.ThreadLocal底层原理3.ThreadLocal核心API3.1.get()方法3.2.set()方法3.3.remove()方法3.4.核心代码及流程4.ThreadLocalMap5.Hash冲突怎么解决6.ThreadLocal内存泄漏问题及解决办法7.应用场景8.总结1.ThreadLocal初识 ThreadLocal概念&#xff1a…

RPC(2)------ Netty(NIO) + 多种序列化协议 + JDK动态代理实现

依赖包解释 Guava 包含了若干被Google的 Java项目广泛依赖 的核心库&#xff0c;例如&#xff1a;集合 [collections] 、缓存 [caching] 、原生类型支持 [primitives support] 、并发库 [concurrency libraries] 、通用注解 [common annotations] 、字符串处理 [string process…

Windows部署Jar包的三种方式

文章目录1、cmd命令启动2、bat脚本启动2.1 启动jar包2.2 关闭服务3、使用WinSW3.1 重命名3.2 xml配置3.3 安装服务3.4 卸载服务3.5 启动和停止服务1、cmd命令启动 这种方式比较简单&#xff0c;但是窗口关闭后服务也就被杀死了&#xff0c;命令如下 java -jar xxx.jar2、bat脚…

nignx(安装,正反代理,安装tomcat设置反向代理,ip透传)

1安装nginx 安装wget Yum install -y wget 下载(链接从官网找到右键获取) 以下过程root 安装gcc Yum -y install gcc c 安装pcre Yum install -y pcre pcre-devel Openssl Yum install -y openssl openssl-devel 安装zlib Yum install -y zlib zlib-devel 安装make Yum inst…

纯手动搭建hadoop集群记录001_搭建虚拟机_调通网络_配置静态IP_安装JDK---大数据之Hadoop3.x工作笔记0162

1.首先准备机器,172.19.126.115 172.19.126.116 172.19.126.117 我准备了3台 Windows机器 2.然后我打算在Windows机器上使用虚拟机,搭建3台Centos虚拟机来进行安装hadoop 3.这里我们的3台windows机器中的,3台linux虚拟机也使用了3个IP,分别是 172.19.126.120 172.19.126.1…

Redis 删除策略和内存淘汰策略

文章目录一、过期数据二、数据删除策略2-1 定时删除2-2 惰性删除2-3 定期删除三、内存淘汰策略3-1 新数据进入检测3-2 影响数据逐出的相关配置3-3 八种数据逐出策略提示&#xff1a;以下是本篇文章正文内容&#xff0c;Redis系列学习将会持续更新 一、过期数据 Redis中的数据特…

jvm知识点

jvm面试总结 类加载机制? 如何把类加载到jvm中 ? 装载–>链接–>初始化–>使用–>卸载 装载: ClassFile–>字节流–>类加载器将字节流所代表的静态结构转化为方法区的运行时数据结构在我们的堆中生成一个代表这个类的java.lang.Class对象 链接: 验证–…