Mamba环境配置教程【自用】

news2025/1/13 8:39:44

1. 新建一个Conda虚拟环境

conda create -n mamba python=3.10

在这里插入图片描述

2. 进入该环境

conda activate mamba

在这里插入图片描述

3. 安装torch(建议2.3.1版本)以及相应的 torchvison、torchaudio
直接进入pytorch离线包下载网址,在里面寻找对应的pytorch以及torchvison、torchaudio
CSDN资源
在这里插入图片描述

下载完成后,进入这些文件的目录下,直接使用下面三个指令进行安装即可

pip install torch-2.3.1+cu118-cp310-cp310-linux_x86_64.whl 
pip install torchvision-0.18.1+cu118-cp310-cp310-linux_x86_64.whl 
pip install torchaudio-2.3.1+cu118-cp310-cp310-linux_x86_64.whl

4. 安装triton和transformers库

pip install triton==2.3.1
pip install transformers==4.43.3

5. 安装完这些我们最基本Pytorch环境以及配置完成,接下来就是Mamba所需的一些依赖了,由于Mamba需要底层的C++进行编译,所以还需要手动安装一下cuda-nvcc这个库,直接使用conda命令即可

conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc

6. 最后就是下载最重要的 causal-conv1d 和mamba-ssm库。在这里我们同样选择离线安装的方式,来避免大量奇葩的编译bug。首先进入下面各自的github网址种进行下载对应版本
causal-conv1d —— 1.4.0
在这里插入图片描述
mamba-ssm —— 2.2.2
在这里插入图片描述
和安装pytorch一样,进入下载的.whl文件所在文件夹,直接使用以下指令进行安装

pip install causal_conv1d-1.4.0+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install mamba_ssm-2.2.2+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

7. 安装好环境后,验证一下Mamba块能否成功运行,直接复制下面代码保存问mamba2_test.py,并运行

# Copyright (c) 2024, Tri Dao, Albert Gu.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
except ImportError:
    causal_conv1d_fn = None

try:
    from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
except ImportError:
    RMSNormGated, LayerNorm = None, None

from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined


class Mamba2Simple(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=128,
        d_conv=4,
        conv_init=None,
        expand=2,
        headdim=64,
        ngroups=1,
        A_init_range=(1, 16),
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        dt_limit=(0.0, float("inf")),
        learnable_init_states=False,
        activation="swish",
        bias=False,
        conv_bias=True,
        # Fused kernel and sharding options
        chunk_size=256,
        use_mem_eff_path=True,
        layer_idx=None,  # Absorb kwarg for general module
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.conv_init = conv_init
        self.expand = expand
        self.d_inner = self.expand * self.d_model
        self.headdim = headdim
        self.ngroups = ngroups
        assert self.d_inner % self.headdim == 0
        self.nheads = self.d_inner // self.headdim
        self.dt_limit = dt_limit
        self.learnable_init_states = learnable_init_states
        self.activation = activation
        self.chunk_size = chunk_size
        self.use_mem_eff_path = use_mem_eff_path
        self.layer_idx = layer_idx

        # Order: [z, x, B, C, dt]
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)

        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=conv_dim,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        if self.conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
        # self.conv1d.weight._no_weight_decay = True

        if self.learnable_init_states:
            self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))
            self.init_states._no_weight_decay = True

        self.act = nn.SiLU()

        # Initialize log dt bias
        dt = torch.exp(
            torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        # A parameter
        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
        A_log = torch.log(A).to(dtype=dtype)
        self.A_log = nn.Parameter(A_log)
        # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.nheads, device=device))
        self.D._no_weight_decay = True

        # Extra normalization layer right before output projection
        assert RMSNormGated is not None
        self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, u, seq_idx=None):
        """
        u: (B, L, D)
        Returns: same shape as u
        """
        batch, seqlen, dim = u.shape

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
        A = -torch.exp(self.A_log)  # (nheads) or (d_inner, d_state)
        initial_states=repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None
        dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)

        if self.use_mem_eff_path:
            # Fully fused path
            out = mamba_split_conv1d_scan_combined(
                zxbcdt,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.dt_bias,
                A,
                D=self.D,
                chunk_size=self.chunk_size,
                seq_idx=seq_idx,
                activation=self.activation,
                rmsnorm_weight=self.norm.weight,
                rmsnorm_eps=self.norm.eps,
                outproj_weight=self.out_proj.weight,
                outproj_bias=self.out_proj.bias,
                headdim=self.headdim,
                ngroups=self.ngroups,
                norm_before_gate=False,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
        else:
            z, xBC, dt = torch.split(
                zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
            )
            dt = F.softplus(dt + self.dt_bias)  # (B, L, nheads)
            assert self.activation in ["silu", "swish"]

            # 1D Convolution
            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                xBC = self.act(
                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
                )  # (B, L, self.d_inner + 2 * ngroups * d_state)
                xBC = xBC[:, :seqlen, :]
            else:
                xBC = causal_conv1d_fn(
                    x=xBC.transpose(1, 2),
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                ).transpose(1, 2)

            # Split into 3 main branches: X, B, C
            # These correspond to V, K, Q respectively in the SSM/attention duality
            x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
            y = mamba_chunk_scan_combined(
                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                dt,
                A,
                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
                chunk_size=self.chunk_size,
                D=self.D,
                z=None,
                seq_idx=seq_idx,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
            y = rearrange(y, "b l h p -> b l (h p)")

            # Multiply "gate" branch and apply extra normalization layer
            y = self.norm(y, z)
            out = self.out_proj(y)
        return out

if __name__ == '__main__':
    model = Mamba2Simple(256).cuda()
    inputs = torch.randn(2, 128, 256).cuda()
    pred = model(inputs)
    print(pred.size())      

在这里插入图片描述

参考文献

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2142537.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VTD激光雷达(5)——05_OptiX_GPU

文章目录 前言一、总结 前言 一、 1 2 3 总结

随着访问范围的扩大 OpenAI o1-mini 现已向免费用户开放

上周,OpenAI 展示了其最新的大型语言模型(LLM)–OpenAI o1及其小兄弟 OpenAI o1-mini。该公司在公告中称,Plus 和 Team 用户可在公告发布之日起访问该模型。企业和教育用户将在本周获得该模型,而免费用户最终将获得 o1…

线性代数之QR分解和SVD分解

文章目录 1.QR分解Schmidt正交化Householder变换QR分解的应用 2. 求矩阵特征值、特征向量的基本方法3.SVD分解SVD分解的应用 参考文献 1.QR分解 矩阵的正交分解又称为QR分解,是将矩阵分解为一个正交矩阵Q和一个上三角矩阵R的乘积的形式。 任意实数方阵A&#xff0c…

2022高教社杯全国大学生数学建模竞赛C题 问题一(1) Python代码

目录 问题 11.1 对这些玻璃文物的表面风化与其玻璃类型、纹饰和颜色的关系进行分析数据探索 -- 单个分类变量的绘图树形图条形图扇形图雷达图Cramer’s V 相关分析统计检验列联表分析卡方检验Fisher检验绘图堆积条形图分组条形图分类模型Logistic回归随机森林import matplotlib…

SpaceX实现人类首次商业太空行走:航天历史新篇章

导语 2023年9月,SpaceX成功完成了人类历史上首次商业太空行走,这不仅是航天领域的重要突破,也是商业航天的一次重大胜利。这一事件标志着普通人离太空更近了一步,为未来的太空探索和火星移民奠定了基础。 一、背景介绍&#xff1a…

【MySQL】数据类型【mysql当中各自经典的数据类型的学习和使用】

目录 数据类型1数据类型分类2.数值类型2.1tinyint类型2.2bit类型2.3小数类型2.3.1float2.3.2decimal 3.字符串类型3.1char3.2varchar3.3char和varchar的对比 4.日期和时间类型5.enum和set5.1对enum和set进行插入5.2对enum和set进行查询 数据类型 数据类型本身就是mysql当中天然…

YOLOv8 的安装与训练

YOLOv8 是 YOLO 系列实时目标检测器中的较新迭代版本,在准确性和速度方面提供了前沿性能。基于之前 YOLO 版本的进步,YOLOv8 引入了新的特性和优化,使其成为各种应用中各种目标检测任务的理想选择。 一、安装显卡驱动与CUDA: 这个…

成都院干翻华东院成第一水电设计院!

注:文章来源于百度。版权归原作者所有。。 昨天中国电建发布了2024年中报,一般中报我是不怎么研究的,除非利益相关。 但今年电建的中报亮点很多,其中最显眼的就要数电建成都院在上半年干翻了传统龙头老大——华东院。 在净利润…

【C++】一次rustdesk-server编译记录

RustDesk Server 是一个开源的远程桌面解决方案,允许用户自托管自己的远程桌面服务器。该项目是免费且开源的,支持多种平台和环境。RustDesk Server 提供了 ID/Rendezvous 服务器和 Relay 服务器,以及一些 CLI 工具,方便用户进行远…

[Redis] Redis中的set和zset类型

🌸个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 🏵️热门专栏: 🧊 Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 🍕 Collection与…

动手学深度学习8.5. 循环神经网络的从零开始实现-笔记练习(PyTorch)

本节课程地址:从零开始实现_哔哩哔哩_bilibili 本节教材地址:8.5. 循环神经网络的从零开始实现 — 动手学深度学习 2.0.0 documentation (d2l.ai) 本节开源代码:...>d2l-zh>pytorch>chapter_multilayer-perceptrons>rnn-scratc…

【优化器】Optimizer——深度学习中的优化器是什么作用呢?

【优化器】Optimizer——深度学习中的优化器是什么作用呢? 【优化器】Optimizer——深度学习中的优化器是什么作用呢? 文章目录 【优化器】Optimizer——深度学习中的优化器是什么作用呢?1.什么是优化器?梯度下降法3. 常见的优化…

数据结构易错整理1

目录 数据结构的基础概念 数据结构基础概念 数据结构的逻辑结构 数据结构的物理结构 算法分析 时间复杂度 例题 数据结构的基础概念 数据结构基础概念 设计存储结构时不仅要存储格数据元素的值,而且还要存储数据元素之间的关系 数据结构具有特定关系的数据…

C/C++语言基础--从C到C++的不同(下),15个部分说明C与C++的不同

本专栏目的 更新C/C的基础语法,包括C的一些新特性 前言 1-10在上篇C/C语言基础–从C到C的不同(上);当然C和C的不同还有很多,本人暂时只总结这些,其他的慢慢更新;上一篇C/C语言基础–从C到C的不同(上&…

Sass实现文字两侧横线及Sass常用方案

Sass常用方案及Sass实现文字两侧横线 1.Sass实现文字两侧横线2.用Sass简化媒体查询3.使用继承占位符实现样式复用4.Sass 模块化5.lighten 和 darken 自我记录 1.Sass实现文字两侧横线 mixin 的基本作用: 代码复用:把常用的样式封装在一起,…

C++和OpenGL实现3D游戏编程【目录】

欢迎来到zhooyu的专栏。 个人主页:【zhooyu】 文章专栏:【OpenGL实现3D游戏编程】 贝塞尔曲面演示: 贝塞尔曲面演示zhooyu 本专栏内容: 我们从游戏的角度出发,用C去了解一下游戏中的功能都是怎么实现的。这一切还是要…

【刷题日记】螺旋矩阵

54. 螺旋矩阵 这个是一道模拟题,但我记得我大一第一次做这道题的时候真的就是纯按步骤模拟,没有对代码就行优化,导致代码写的很臃肿。 有这么几个地方可以改进。 看题目可以知道最终的结果一定是rows*cols个结点,所以只需要遍历rows*cols次…

Docker部署镜像 发布容器 容器网络互联 前端打包

准备工作 导入相关依赖 <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-spring-boot3-starter</artifactId><version>3.5.7</version></dependency><dependency><groupId>com.baomidou<…

CLIP论文中关键信息记录

由于clip论文过长&#xff0c;一直无法完整的阅读该论文&#xff0c;故而抽取论文中的关键信息进行记录。主要记录clip是如何实现的的&#xff08;提出背景、训练数据、设计模式、训练超参数、prompt的作用&#xff09;&#xff0c;clip的能力&#xff08;clip的模型版本、clip…

【Python机器学习】序列到序列建模——对序列到序列模型的增强

有两种增强训练序列到序列模型的方法&#xff0c;可以提高模型的精确率和可扩展性。 使用装桶法降低训练复杂度 输入序列可以有不同的长度&#xff0c;这使短序列的训练数据添加了大量填充词条。过多的填充会使计算成本高昂&#xff0c;特别是当大多数序列都很短&#xff0c;…