HuggingFace学习笔记--Model的使用

news2024/11/15 12:08:39

1--Model介绍

        Transformer的 model 一般可以分为:编码器类型(自编码)、解码器类型(自回归)和编码器解码器类型(序列到序列);

        Model Head(任务头)是在base模型的基础上,根据不同任务而设置的模块;base模型只起到一个编码和建模特征的功能;

简单代码:

from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification

if __name__ == "__main__":
    # 数据处理
    sen = "弱小的我也有大梦想!"
    tokenizer = AutoTokenizer.from_pretrained("hfl/rbt3")
    inputs = tokenizer(sen, return_tensors="pt")
        
    # 不带model head的模型调用
    model = AutoModel.from_pretrained("hfl/rbt3", output_attentions=True)
    output1 = model(**inputs)
    print(output1.last_hidden_state.size()) # [1, 12, 768]
    
    # 带model head的模型调用
    clz_model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3", num_labels=10)
    output2 = clz_model(**inputs)
    print(output2.logits.shape) # [1, 10]

2--AutoModel的使用

官方文档

        AutoModel 用于加载模型;

2-1--简单Demo

测试代码:

from transformers import AutoTokenizer, AutoModel

if __name__ == "__main__":
    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
    tokenlizer = AutoTokenizer.from_pretrained(checkpoint) 
    
    raw_input = ["I love kobe bryant.", "Me too."]
    inputs = tokenlizer(raw_input, padding = "longest", truncation = True, max_length = 512, return_tensors = "pt")
    
    # 加载指定的模型
    model = AutoModel.from_pretrained(checkpoint)
    print("model: \n", model)
    
    outputs = model(**inputs)
    print("last_hidden_state: \n", outputs.last_hidden_state.shape) # 打印最后一个隐层的输出维度
    # [2 7 768] batch_size为2,7个token,每个token的维度为768

输出结果:

last_hidden_state: 
 torch.Size([2, 7, 768])

# 最后一个隐层的输出
# batchsize为2,表示两个句子
# 7表示token数,每一个句子有7个token
# 768表示特征大小,每一个token的维度为768

测试代码:

from transformers import AutoTokenizer, AutoModelForSequenceClassification

if __name__ == "__main__":
    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
    tokenlizer = AutoTokenizer.from_pretrained(checkpoint) 
    
    raw_input = ["I love kobe bryant.", "Me too."]
    inputs = tokenlizer(raw_input, padding = "longest", truncation = True, max_length = 512, return_tensors = "pt")

    model2 = AutoModelForSequenceClassification.from_pretrained(checkpoint) # 二分类任务
    print(model2)
    outputs2 = model2(**inputs)
    print(outputs2.logits.shape)

运行结果:

torch.Size([2, 2])
# 两个句子,每个句子二分类的概率

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1273149.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows11如何让桌面图标的箭头消失(去掉快捷键箭头)

在Windows 11中,桌面图标的箭头是快捷方式图标的一个标志,用来表示该图标是一个指向文件、文件夹或程序的快捷方式。如果要隐藏这些箭头,你需要修改Windows注册表或使用第三方软件。 在此之前,我需要提醒你,修改注册表…

【unity实战】如何更加规范的创建各种Rogue-Lite(肉鸽)风格的物品和BUFF效果(附项目源码)

文章目录 前言定义基类实现不同的BUFF效果一、回血BUFF1. 简单的回血效果实现2. BUFF层数控制回血量 二、攻击附带火焰伤害三、治疗领域1. 简单的治疗领域实现2. 添加技能冷却时间 通过拾取物品获取对应的BUFF参考源码完结 前言 当创建各种Rogue-Lite(肉鸽&#xf…

VS2022使用Vim按键

VS2022使用Vim按键 在插件管理里面搜索VsVim 点击安装,重启VS 工具->选项->VsVim 配置按键由谁处理,建议Ctrl C之类常用的使用VS处理,其它使用Vim处理

shell编程系列(7)-使用wc进行文本统计

文章目录 前言wc命令的使用wc命令的参数说明:统计字数统计行数打印文本行号 结语 前言 统计功能也是我们在shell编程中经常碰到的一个需求,wc命令可以适用于任何需要统计的数据,不只是统计文本,配合ls命令我们可以统计文件的个数…

electron调用dll问题总汇

通过一天的调试安装,electron调用dll成功,先列出当前的环境:node版本: 18.12.0,32位的(因为dll为32位的) VS2019 python node-gyp 1、首先要查看报错原因,通常在某一行会有提示,常…

在Linux上安装KVM虚拟机

一、搭建KVM环境 KVM(Kernel-based Virtual Machine)是一个基于内核的系统虚拟化模块,从Linux内核版本2.6.20开始,各大Linux发行版就已经将其集成于发行版中。KVM与Xen等虚拟化相比,需要硬件支持的完全虚拟化。KVM由内…

vue3 router-view 使用keep-alive报错parentcomponent.ctx.deactivate is not a function

问题 如下图,在component组件上添加v-if判断,会报错: parentcomponent.ctx.deactivate is not a function 解决方法 去除v-if,将key直接添加上。由于有的公用页面,需要刷新,不希望缓存,所以需要添加key…

2023/11/30JAVAweb学习

数组json形式 想切换实现类,只需要只在你需要的类上添加 Component 如果在同一层,可以更改扫描范围,但是不推荐这种方法 注入时存在多个同类型bean解决方式

C 中的结构 - 存储、指针、函数和自引用结构

0. 结构体的内存分配 当声明某种类型的结构变量时,结构成员被分配连续(相邻)的内存位置。 struct student{char name[20];int roll;char gender;int marks[5];} stu1; 此处,内存将分配给name[20]、roll、gender和marks[5]。st1这…

Redis学习文档

目录 一、概念1、特征2、关系型数据库和非关系型数据库的区别3、键的结构4、Redis的Java客户端5、缓存更新策略5.1、概念5.2、代码 6、缓存穿透6.1、含义6.2、解决办法6.3、缓存空值代码举例6.4、布隆过滤器代码举例 7、缓存击穿7.1、概念7.2、解决办法7.3、互斥锁代码举例7.4、…

卡码网语言基础课 | 17. 判断集合成员

目录 一、 set 集合 二、 创建集合 2.1 引入头文件 2.2 创建 2.3 插入元素 2.4 删除元素 三、 find的用法 四、 实现基本解题 五、 延伸拓展 题目:编写一个程序,判断给定的整数 n 是否存在于给定的集合中。 输入描述: 有多组测试…

Pycharm中使用matplotlib绘制动态图形

Pycharm中使用matplotlib绘制动态图形 最终效果 最近用pycharm学习D2L时发现官方在jupyter notebook交互式环境中能动态绘制图形,但是在pycharm脚本环境中只会在最终 plt.show() 后输出一张静态图像。于是有了下面这段自己折腾了一下午的代码,用来在pych…

jetson nano SSH远程连接(使用MobaXterm)

文章目录 SSH远程连接1.SSH介绍2.准备工作3.连接步骤3.1 IP查询3.2 新建会话和连接 SSH远程连接 本节课的实现,需要将Jetson Nano和电脑保持在同一个局域网内,也就是连接同一个路 由器,通过SSH的方式来实现远程登陆。 1.SSH介绍 SSH是一种网…

一文讲透Python机器学习特征选择之互信息法

1.互信息法的基本思想 互信息(Mutual Information,MI)的基本思想是计算每个特征变量与目标变量之间的互信息统计量,互信息统计量衡量变量之间的依赖关系。两个随机变量之间的互信息统计量肯定是非负值,当且仅当两个随…

带键扫的LED专用驱动方案

一、基本概述 TM1650 是一种带键盘扫描接口的LED(发光二极管显示器)驱动控制专用电路。内部集成有MCU输入输出控制数字接口、数据锁存器、LED 驱动、键盘扫描、辉度调节等电路。TM1650 性能稳定、质量可靠、抗干扰能力强,可适用于24 小时长期…

【强迫症患者必备】SpringBoot项目中Mybatis使用mybatis-redis开启三级缓存必须创建redis.properties优化方案

springboot项目中mybatis使用mybatis-redis开启三级缓存需要创建redis.properties优化方案 前言下载mybatis-redis源码分析RedisCache 代码RedisConfigurationBuilder的parseConfiguration方法 优化改造1.创建JedisConfig类2.复制RedisCache代码创建自定义的MyRedisCache3.指定…

分享超实用的软文撰写步骤!建议收藏

一想到写软文就头大,根本不知道从哪里下手,这是很多写手在创作过程中会遇到的问题。 一篇软文写得好不好,关键就要看你的创作步骤到不到位,软文创作是有一套可执行的具体方式的,跟着步骤来,你也能轻轻松松…

【java扫盲贴】final修饰变量

引用类型:地址不可变 //Java中的引用类型分为类(class)、接口(interface)、数组(array)和枚举(enum)。//string是特殊的引用类型,他的底层是被final修饰的字…

麒麟操作系统网桥配置

网桥概念: Bridge 是 Linux 上用来做 TCP/IP 二层协议交换的设备,其功能可 以简单的理解为是一个二层交换机或者 Hub;多个网络设备可以连接 到同一个 Bridge,当某个设备收到数据包时,Bridge 会将数据转发 给其他设备。…

osgFX扩展库-刻线特效、立方图镜面高光特效(2)

刻线特效 刻线特效(osgFX::Scribe)是一个双通道的特效,第一个通道以通常的方式渲染图形,第二个通道使用线框模式。用户设置好光照和材质之后,即可使用指定的颜色进行渲染。这个特效使用了PolygonOffset渲染属性类来避免多边形斑驳(Z-fighting…