在Colab上测试Mamba

news2025/2/28 15:43:29

我们在前面的文章介绍了研究人员推出了一种挑战Transformer的新架构Mamba

他们的研究表明,Mamba是一种状态空间模型(SSM),在不同的模式(如语言、音频和时间序列)中表现出卓越的性能。为了说明这一点,研究人员使用Mamba-3B模型进行了语言建模实验。该模型超越了基于相同大小的Transformer的其他模型,并且在预训练和下游评估期间,它的表现与大小为其两倍的Transformer模型一样好。

Mamba的独特之处在于它的快速处理能力,选择性SSM层,以及受FlashAttention启发的硬件友好设计。这些特点使Mamba超越Transformer(Transformer没有了传统的注意力和MLP块)。

有很多人希望自己测试Mamba的效果,所以本文整理了一个能够在Colab上完整运行Mamba代码,代码中还使用了Mamba官方的3B模型来进行实际运行测试。

首先我们安装依赖,这是官网介绍的:

 !pip install causal-conv1d==1.0.0
 !pip install mamba-ssm==1.0.1

然后直接使用transformers库读取预训练的Mamba-3B

 import torch
 import os
 from transformers import AutoTokenizer
 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
 tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
 model = MambaLMHeadModel.from_pretrained(os.path.expanduser("state-spaces/mamba-2.8b"), device="cuda", dtype=torch.bfloat16)

可以看到,3b的模型有11G

然后就是测试生成内容

 tokens = tokenizer("What is the meaning of life", return_tensors="pt")
 input_ids = tokens.input_ids.to(device="cuda")
 max_length = input_ids.shape[1] + 80
 fn = lambda: model.generate(
         input_ids=input_ids, max_length=max_length, cg=True,
         return_dict_in_generate=True, output_scores=True,
         enable_timing=False, temperature=0.1, top_k=10, top_p=0.1,)
 out = fn()
 print(tokenizer.decode(out[0][0]))

这里还有一个chat的示例

 import torch
 from transformers import AutoTokenizer
 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
 
 device = "cuda"
 tokenizer = AutoTokenizer.from_pretrained("havenhq/mamba-chat")
 tokenizer.eos_token = "<|endoftext|>"
 tokenizer.pad_token = tokenizer.eos_token
 tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template
 
 model = MambaLMHeadModel.from_pretrained("havenhq/mamba-chat", device="cuda", dtype=torch.float16)
 
 
 messages = []
 user_message = """
 What is the date for announcement
 On August 10 said that its arm JSW Neo Energy has agreed to buy a portfolio of 1753 mega watt renewable energy generation capacity from Mytrah Energy India Pvt Ltd for Rs 10,530 crore.
  """
 
 messages.append(dict(role="user",content=user_message))
 input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
 out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)
 decoded = tokenizer.batch_decode(out)
 messages.append(dict(role="assistant",content=decoded[0].split("<|assistant|>\n")[-1]))
 print("Model:", decoded[0].split("<|assistant|>\n")[-1])

这里我将所有代码整理成了Colab Notebook,有兴趣的可以直接使用:

https://avoid.overfit.cn/post/ed2d2cc2460d4e0683a270e2761e10ea

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1382531.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MyBatis第二课,灰度发布,@Results注解,使用xml书写mysql

目录 打印MyBatis的日志配置&#xff1a; 灰度发布:指发布环境&#xff0c;比如发布环境有200台机器&#xff0c;发布的时候是一批一批的机器的发布 2.删除与修改 使用Results注解&#xff0c;这样就和上面的别名一个意思&#xff0c;column是数据库的列 自动转驼峰&#…

ubuntu的动图截屏怎么做

在Ubuntu系统中&#xff0c;你可以通过以下步骤来截取动图&#xff08;即屏幕录制并转换为GIF格式&#xff09;&#xff1a; 1,首先&#xff0c;你需要安装一些必要的工具。打开终端并输入以下命令以安装gtk-recordmydesktop&#xff08;用于录制屏幕&#xff09;、mplayer&am…

【快速解决】保姆级Anaconda安装教程

目录 第一步 ​编辑第二步 ​编辑第三步 第四步 第五步 第六步 ​编辑 第七步 第八步 第九步 第一步 在anaconda清华大学开源软件镜像站下载anaconda。点击这里进入 我这里选的是windows-x86_64。 第二步 下载好以后进行安装 第三步 第四步 第五步 选择…

VR全景博物馆——让博物馆“火起来”

不管是十里洋场的繁华、还是红岩革命的英勇&#xff0c;博物馆一直都拥有着丰富的历史沉淀和文化底蕴&#xff0c;通过VR全景拍摄制作技术&#xff0c;我们可以随时随地穿越空间&#xff0c;去切身体验那些历史人物的经历。 传统的实体博物馆受限于地理位置和布局&#xff0c;使…

Google cloud认证必备

Google cloud认证 ​这个可以走代理合作 ​价格优美 ​通过保证

线上问题整理

JVM 案例 案例一&#xff1a;服务器内存不足&#xff0c;影响Java应用 问题&#xff1a; 收到报警&#xff0c;某Java应用集群中一台服务器可用内存不足&#xff0c;超过报警阈值。 排查过程&#xff1a; 首先&#xff0c;通过Hickwall查看该应用各项指标&#xff0c;发现无论…

如何创建一个pytorch的训练数据加载器(train_loader)用于批量加载训练数据

Talk is cheap,show me the code! 哈哈&#xff0c;先上几段常用的代码&#xff0c;以语义分割的DRIVE数据集加载为例&#xff1a; DRIVE数据集的目录结构如下&#xff0c;下载链接DRIVE,如果官网下不了&#xff0c;到Kaggle官网可以下到&#xff1a; 1. 定义DriveDataset类&…

Qt OpenGL - 网格式的直角坐标系

Qt OpenGL - 网格式的直角坐标系 引言一、绘制3D网格1.1 绘制平行于y轴的线段1.2 绘制平行于三个轴的线段1.3 绘制不同的3D网格 二、网格式的直角坐标系三、参考链接 引言 在OpenGL进行3D可视化&#xff0c;只绘制三条坐标轴略显单薄&#xff0c;而绘制网格形式的坐标系则能更清…

更换为mainwindow.ui更新工程架构

文章目录 前言一、新建带mainwindow.ui的工程1.新建工程2. 添加工程模块添加opencv的库3.添加资源3.1工程上添加资源3.2引用资源 4.添加曲线文件4.1 复制关键文件到新工程4.2 新进显示曲线的ui带.h的为了方面名字取一样4.3添加曲线显示控件4.4 添加工具 5. 添加曲线.h文件内容6…

大数据深度学习ResNet深度残差网络详解:网络结构解读与PyTorch实现教程

文章目录 大数据深度学习ResNet深度残差网络详解&#xff1a;网络结构解读与PyTorch实现教程一、深度残差网络&#xff08;Deep Residual Networks&#xff09;简介深度学习与网络深度的挑战残差学习的提出为什么ResNet有效&#xff1f; 二、深度学习与梯度消失问题梯度消失问题…

Apache-Common-Pool2中对象池的使用方式

最近在工作中&#xff0c;对几个产品的技术落地进行梳理。这个过程中发现一些朋友对如何使用Apache的对象池存在一些误解。所以在写作“业务抽象”专题的空闲时间里&#xff0c;本人觉得有必要做一个关于对象池的知识点和坑点讲解。Apache Common-Pool2 组件最重要的功能&#…

nvm安装高版本Nodejs报错

文章概叙 之前使用1.1.17版本的nvm&#xff0c;切换使用18的Nodejs的时候报错&#xff0c;经过短暂的思考&#xff0c;决定使用1.1.12的nvm的无聊故事。 吐槽 今天的故事比较无奈&#xff0c;由于某些原因&#xff0c;现在需要做rn的开发&#xff0c;至于为啥不是flutter&am…

《工具录》dig

工具录 1&#xff1a;dig2&#xff1a;选项介绍3&#xff1a;示例4&#xff1a;其他 本文以 kali-linux-2023.2-vmware-amd64 为例。 1&#xff1a;dig dig 是域名系统&#xff08;DNS&#xff09;查询工具&#xff0c;常用于域名解析和网络故障排除。比 nslookup 有更强大的功…

一张图总结架构设计的40个黄金法则

尼恩说在前面 在40岁老架构师 尼恩的读者交流群(50)中&#xff0c;很多小伙伴拿到非常优质的架构机会&#xff0c;常常找尼恩求助&#xff1a; 尼恩&#xff0c;我这边有一个部门技术负责人资深架构师的机会&#xff0c;非常难得&#xff0c; 但是有一个大厂高P在抢&#xff0…

为什么很多公司选择不升级JDK版本,仍然使用JDK8?

在讨论为什么许多公司选择不升级JDK版本&#xff0c;而继续使用JDK 8时&#xff0c;我们需要从多个角度来分析这个问题。以下是根据您提供的背景信息进行的一些分析和真实案例。 本文已收录于&#xff0c;我的技术网站 ddkk.com&#xff0c;有大厂完整面经&#xff0c;工作技术…

H5网站封装成App的高效转换之旅

在移动互联网时代&#xff0c;App&#xff08;应用程序&#xff09;和H5&#xff08;HTML5网站&#xff09;是两种常见的移动解决方案。App通常提供更流畅的用户体验和更丰富的功能&#xff0c;而H5网站则以其开发成本低、更新快捷和无需安装等优势受到青睐。尽管如此&#xff…

【java八股文】之Spring系列篇

1、你怎么理解Spring&#xff1f; Spring是个轻量级的框架&#xff0c;简化了应用的开发程序&#xff0c;提高开发人员的系统维护性&#xff0c;不过配置消息比较繁琐&#xff0c;所以后面才出选了SpringBoot的框架。 Spring的核心组件 &#xff1a; Spring Core 、 Spring Con…

Video接口介绍

屏库 https://m.panelook.cn/index_cn.php Open LDI, open lvds display interface OpenLDI and LVDS是兼容的&#xff0c; 是一种电平 https://www.ti2k.com/178597.html MIPI DSI/Camera crosLink FPD-LINK(Flat panel display link)是National(TI) LVDS技术&#xff0c; …

Openstack云计算(六)Openstack环境对接ceph

一、实施步骤&#xff1a; &#xff08;1&#xff09;客户端也要有cent用户&#xff1a; useradd cent && echo "123" | passwd --stdin cent echo -e Defaults:cent !requiretty\ncent ALL (root) NOPASSWD:ALL | tee /etc/sudoers.d/ceph chmod 440 /et…

[足式机器人]Part2 Dr. CAN学习笔记-Advanced控制理论 Ch04-12+13 不变性原理+非线性系统稳定设计

本文仅供学习使用 本文参考&#xff1a; B站&#xff1a;DR_CAN Dr. CAN学习笔记-Advanced控制理论 Ch04-1213 不变性原理非线性系统稳定设计 1. Invariance Princilpe-LaSalle;s Theorem不变性原理2. Nonlinear Basic Feedback Stabilization 非线性系统稳定设计 1. Invarianc…