lit-llama代码解析

news2024/9/22 1:01:45

https://github.com/Lightning-AI/lit-llama/blob/main/README.md

下载的时候会报错误,因为网不行,一种方法就是多次尝试,另一种方法是终端连上代理下载

pycharm连接hugging face等网站_hugging face怎么连接-CSDN博客

根据指引下载权重

下载完权重运行:python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/open-llama/7B --model_size 7B

转化为.pth文件 

跟着readme/howto教程量化或进行其他操作

warning

UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at ..\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:455.)
  y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

https://github.com/comfyanonymous/ComfyUI/issues/3202

分析generate

# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import sys
import time
import warnings
from pathlib import Path
from typing import Optional

import lightning as L
import torch
print(torch.cuda.is_available())
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import lazy_load, llama_model_lookup, quantization


@torch.no_grad()
def generate(
    model: LLaMA,
    idx: torch.Tensor,
    max_new_tokens: int,
    *,
    max_seq_length: Optional[int] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        idx: Tensor of shape (T) with indices of the prompt sequence.
        max_new_tokens: The number of new tokens to generate.
        max_seq_length: The maximum sequence length allowed.
        temperature: Scales the predicted logits by 1 / temperature
        top_k: If specified, only sample among the tokens with the k highest probabilities
        eos_id: If specified, stop generating any more token once the <eos> token is triggered
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = idx.size(0)
    T_new = T + max_new_tokens
    if max_seq_length is None:
        max_seq_length = min(T_new, model.config.block_size)

    device, dtype = idx.device, idx.dtype
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = idx
    idx = empty
    input_pos = torch.arange(0, T, device=device)

    if idx.device.type == "xla":
        import torch_xla.core.xla_model as xm

        xm.mark_step()

    # generate max_new_tokens tokens
    for _ in range(max_new_tokens):
        x = idx.index_select(0, input_pos).view(1, -1)

        # forward
        logits = model(x, max_seq_length, input_pos)
        logits = logits[0, -1] / temperature

        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

        probs = torch.nn.functional.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

        # advance
        input_pos = input_pos[-1:] + 1

        if idx.device.type == "xla":
            xm.mark_step()

        # concatenate the new generation
        idx = idx.index_copy(0, input_pos, idx_next)

        # if <eos> token is triggered, return the output (stop generation)
        if idx_next == eos_id:
            return idx[:input_pos]  # include the EOS token

    return idx


def main(
    prompt: str = "Hello, my name is",
    *,
    num_samples: int = 1,
    max_new_tokens: int = 50,
    top_k: int = 200,
    temperature: float = 0.8,
    checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
    tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
    quantize: Optional[str] = None,
) -> None:
    """Generates text samples based on a pre-trained LLaMA model and tokenizer.

    Args:
        prompt: The prompt string to use for generating the samples.
        num_samples: The number of text samples to generate.(Its effect is overridden by `max_new_tokens`, if also set.)
        max_new_tokens: The number of generation steps to take.(number of generate tokens )
        top_k: The number of top most probable tokens to consider in the sampling process.
        temperature: A value controlling the randomness of the sampling process. Higher values result in more random
            samples.
        checkpoint_path: The checkpoint path to load.
        tokenizer_path: The tokenizer path to load.
        quantize: Whether to quantize the model and using which method:
            ``"llm.int8"``: LLM.int8() mode,
            ``"gptq.int4"``: GPTQ 4-bit mode.
    """
    assert checkpoint_path.is_file(), checkpoint_path
    assert tokenizer_path.is_file(), tokenizer_path

    precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
    fabric = L.Fabric(devices=1, precision=precision)

    print("Loading model ...", file=sys.stderr)
    t0 = time.time()
    with lazy_load(checkpoint_path) as checkpoint:
        name = llama_model_lookup(checkpoint)

        with fabric.init_module(empty_init=True), quantization(mode=quantize):
            model = LLaMA.from_name(name)

        model.load_state_dict(checkpoint)
    print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

    model.eval()
    model = fabric.setup(model)

    tokenizer = Tokenizer(tokenizer_path)
    encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
    prompt_length = encoded.size(0)

    L.seed_everything(1234)
    for i in range(num_samples):
        t0 = time.perf_counter()
        y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
        t = time.perf_counter() - t0

        model.reset_cache()
        print(tokenizer.decode(y))
        tokens_generated = y.size(0) - prompt_length
        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
    if fabric.device.type == "cuda":
        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)


if __name__ == "__main__":
    from jsonargparse import CLI

    torch.set_float32_matmul_precision("high")
    warnings.filterwarnings(
        # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
        "ignore", 
        message="ComplexHalf support is experimental and many operators don't support it yet"
    )
    warnings.filterwarnings(
        # Triggered in bitsandbytes/autograd/_functions.py:298
        "ignore", 
        message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
    )
    CLI(main)

main()

"""Generates text samples based on a pre-trained LLaMA model and tokenizer.

Args:
    prompt: The prompt string to use for generating the samples.
    num_samples: The number of text samples to generate.(Its effect is overridden by `max_new_tokens`, if also set.)
    max_new_tokens: The number of generation steps to take.(number of generate tokens )
    top_k: The number of top most probable tokens to consider in the sampling process.
    temperature: A value controlling the randomness of the sampling process. Higher values result in more random samples.
    checkpoint_path: The checkpoint path to load.
    tokenizer_path: The tokenizer path to load.
    quantize: Whether to quantize the model and using which method:
        ``"llm.int8"``: LLM.int8() mode,
        ``"gptq.int4"``: GPTQ 4-bit mode.
"""


https://zhuanlan.zhihu.com/p/657886517

Fabric()

r"""Fabric accelerates your PyTorch training or inference code with minimal changes required.
    Fabric 加速你的 PyTorch 训练或推理代码,所需的更改最小。
    
    - Automatic placement of models and data onto the device.
    - 自动将模型和数据放置到设备上。
    
    - Automatic support for mixed and double precision (smaller memory footprint).
    - 自动支持混合精度和双精度(较小的内存占用)。
    
    - Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies
      (data-parallel training, sharded training, etc.).
    - 在硬件(CPU、GPU、TPU)和分布式训练策略(数据并行训练、分片训练等)之间无缝切换。
    
    - Automated spawning of processes, no launch utilities required.
    - 自动生成进程,无需启动工具。
    
    - Multi-node support.
    - 支持多节点训练。

    Args:
        accelerator: The hardware to run on. Possible choices are:
            ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
        accelerator: 运行的硬件。可能的选择有:
            ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``。
        
        strategy: Strategy for how to run across multiple devices. Possible choices are:
            ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
        strategy: 跨多个设备运行的策略。可能的选择有:
            ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``。
        
        devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
            The value applies per node.
        devices: 训练时使用的设备数量(``int``),或要训练的 GPU(``list`` 或 ``str``),或 ``"auto"``。
            该值适用于每个节点。
        
        num_nodes: Number of GPU nodes for distributed training.
        num_nodes: 用于分布式训练的 GPU 节点数量。
        
        precision: Double precision (``"64"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),
            or bfloat16 precision AMP (``"bf16-mixed"``).
        precision: 双精度(``"64"``),全精度(``"32"``),半精度 AMP(``"16-mixed"``),
            或 bfloat16 精度 AMP(``"bf16-mixed"``)。
        
        plugins: One or several custom plugins
        plugins: 一个或多个自定义插件
        
        callbacks: A single callback or a list of callbacks. A callback can contain any arbitrary methods that
            can be invoked through :meth:`~lightning.fabric.fabric.Fabric.call` by the user.
        callbacks: 单个回调或回调列表。回调可以包含任何用户可以通过 :meth:`~lightning.fabric.fabric.Fabric.call` 调用的任意方法。
        
        loggers: A single logger or a list of loggers. See :meth:`~lightning.fabric.fabric.Fabric.log` for more
            information.
        loggers: 单个日志记录器或日志记录器列表。有关更多信息,请参见 :meth:`~lightning.fabric.fabric.Fabric.log`。
"""

lazy_load()

定义了一个名为 lazy_load 的类,它用于延迟加载和管理一个 PyTorch 文件:

lazy_load 类
__init__ 方法
python
def __init__(self, fn):
    self.zf = torch._C.PyTorchFileReader(str(fn))
    with BytesIO(self.zf.get_record("data.pkl")) as pkl:
        mup = LazyLoadingUnpickler(pkl, self)
        self.sd = mup.load()
self.zf = torch._C.PyTorchFileReader(str(fn)):

创建一个 PyTorchFileReader 实例,用于读取指定文件 (fn) 的内容。这个文件是 PyTorch 保存的文件,通常是 .pt 或 .pth 文件。
str(fn) 确保文件路径被正确转换为字符串。
with BytesIO(self.zf.get_record("data.pkl")) as pkl::

从 PyTorchFileReader 中提取名为 "data.pkl" 的记录,并用 BytesIO 创建一个内存中的字节流对象 pkl。
BytesIO 用于在内存中读写二进制数据。
mup = LazyLoadingUnpickler(pkl, self):

创建一个 LazyLoadingUnpickler 实例 mup,它负责处理 pkl 中的数据。这里假设 LazyLoadingUnpickler 是自定义的类,用于延迟加载和解码 Pickle 数据。
self.sd = mup.load():

调用 mup.load() 方法来加载数据,并将结果存储在 self.sd 属性中。这个过程可能会涉及到数据的反序列化。
__enter__ 方法
python
def __enter__(self):
    return self.sd
这个方法允许 lazy_load 实例在上下文管理器(with 语句)中使用。__enter__ 返回 self.sd,使得 with 语句块内部可以直接访问加载的数据。
__exit__ 方法
python
def __exit__(self, exc_type, exc_val, exc_tb):
    del self.zf  # I don't think there is a way to force closing...
    self.zf = None
这个方法用于处理退出上下文管理器时的清理工作。
del self.zf: 尝试删除 self.zf 对象。由于 self.zf 是一个 PyTorchFileReader 实例,删除对象的作用是释放相关资源。
self.zf = None: 另一种释放资源的方式,将 self.zf 设置为 None,以确保它不再被引用。
总结
这个类的设计用于懒加载 PyTorch 文件中的数据。它实现了上下文管理协议,使得数据可以在 with 语句块中方便地访问,并且在退出时尝试释放相关资源。

LazyLoadingUnpickler()

定义了一个 LazyLoadingUnpickler 类,继承自 pickle.Unpickler,用于处理 PyTorch 对象的延迟加载。以下是对每个部分的详细解释:

__init__ 方法
python
def __init__(self, file, zipfile_context):
    super().__init__(file)
    self.zipfile_context = zipfile_context
file: 传入的文件对象(通常是一个字节流),用于反序列化。
zipfile_context: 额外的上下文信息,用于延迟加载的实现。这通常是一个包含 PyTorch 文件读取信息的对象。
super().__init__(file): 调用父类 pickle.Unpickler 的初始化方法,传入文件对象。
self.zipfile_context: 保存额外的上下文信息,用于稍后延迟加载。
find_class 方法
python
def find_class(self, module, name):
    res = super().find_class(module, name)
    if module == "torch._utils" and name == "_rebuild_tensor_v2":
        return functools.partial(
            NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self
        )
    elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
        return functools.partial(
            NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self
        )
    elif module == "torch._utils" and name == "_rebuild_parameter":
        return functools.partial(
            NotYetLoadedTensor.rebuild_parameter, archiveinfo=self
        )
    return res
super().find_class(module, name): 调用父类的 find_class 方法,查找并返回指定模块和类名的类。
模块和类名检查:
当模块是 "torch._utils" 且类名是 "_rebuild_tensor_v2" 时,返回一个 functools.partial 对象,部分应用 NotYetLoadedTensor.rebuild_tensor_v2 方法,并传入 archiveinfo=self。
当模块是 "torch._tensor" 且类名是 "_rebuild_from_type_v2" 时,返回一个 functools.partial 对象,部分应用 NotYetLoadedTensor.rebuild_from_type_v2 方法。
当模块是 "torch._utils" 且类名是 "_rebuild_parameter" 时,返回一个 functools.partial 对象,部分应用 NotYetLoadedTensor.rebuild_parameter 方法。
functools.partial: 允许创建一个新的函数,其中一些参数已经预先指定,这里是为了在实际调用时延迟具体的处理逻辑。
返回值: 如果模块和类名不匹配,返回父类的结果。
persistent_load 方法
python
def persistent_load(self, pid):
    name, cls, fn, device, size = pid
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")
    s.archiveinfo = pid
    return s
pid: 一个包含多个信息的元组 (name, cls, fn, device, size),用于标识持久化数据的加载信息。
warnings.catch_warnings(): 捕获并管理警告信息。
warnings.simplefilter("ignore"): 忽略警告信息,以便在加载过程中不会产生干扰。
torch.storage.TypedStorage(dtype=cls().dtype, device="meta"): 创建一个 TypedStorage 对象,指定数据类型和设备。device="meta" 表示数据存储在元数据设备中,实际上并没有分配真实的存储空间。
s.archiveinfo = pid: 将持久化标识信息存储到 TypedStorage 对象中。
返回值: 返回创建的 TypedStorage 对象。
总结
LazyLoadingUnpickler 主要用于在反序列化 PyTorch 对象时实现延迟加载。这种方法使得在加载大数据文件时可以更高效地管理内存和计算资源。find_class 方法用于动态创建用于延迟加载的对象,而 persistent_load 方法则用于处理持久化存储数据的加载。

llama_model_lookup() 

init_module() 

def init_module(self, empty_init: Optional[bool] = None) -> ContextManager:
    """Instantiate the model and its parameters under this context manager to reduce peak memory usage.
在这个上下文管理器下实例化模型及其参数,以减少峰值内存使用。

The parameters get created on the device and with the right data type right away without wasting memory being allocated unnecessarily.
参数会直接在设备上创建,并且使用正确的数据类型,从而避免了不必要的内存分配浪费。

Args:
参数:

empty_init: Whether to initialize the model with empty weights (uninitialized memory).
empty_init: 是否使用空权重(未初始化的内存)来初始化模型。

If ``None``, the strategy will decide. Some strategies may not support all options.
如果``None``,则策略将决定。一些策略可能不支持所有选项。

Set this to ``True`` if you are loading a checkpoint into a large model.
如果你正在将检查点加载到大型模型中,将其设置为``True``。
    """
    self._validate_launched()
    return self._strategy.module_init_context(empty_init=empty_init)
module_init_context()  
 def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
        """A context manager wrapping the model instantiation.
一个包装模型实例化的上下文管理器。

Here, the strategy can control how the parameters of the model get created (device, dtype) and or apply other patches to the model.
在这里,策略可以控制模型参数的创建方式(设备、数据类型)或对模型应用其他修补。

Args:
参数:

empty_init: Whether to initialize the model with empty weights (uninitialized memory).
empty_init: 是否使用空权重(未初始化的内存)来初始化模型。

If ``None``, the strategy will decide. Some strategies may not support all options.
如果``None``,则策略将决定。一些策略可能不支持所有选项。
        """
        precision_module_ctx = self.precision.module_init_context()
        stack = ExitStack()
        stack.enter_context(self.root_device)
        stack.enter_context(_EmptyInit(enabled=bool(empty_init)))
        stack.enter_context(precision_module_ctx)
        return stack

quantization() 

@contextmanager
def quantization(mode: str = None):
    quantized_linear_cls = None
    if mode == 'llm.int8':
        from .quantization import Linear8bitLt
        quantized_linear_cls = Linear8bitLt
    elif mode == 'gptq.int4':
        from .quantization import ColBlockQuantizedLinear
        quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
    elif mode == 'gptq.int8':
        from .quantization import ColBlockQuantizedLinear
        quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
    elif mode is not None:
        raise ValueError(f"Unknown quantization mode: {mode}")

    enabled = mode is not None
    torch_linear_cls = torch.nn.Linear
    if enabled:
        torch.nn.Linear = quantized_linear_cls
    yield
    if enabled:
        torch.nn.Linear = torch_linear_cls

model 

setup() 

    def setup(
        self,
        module: nn.Module,
        *optimizers: Optimizer,
        move_to_device: bool = True,
        _reapply_compile: bool = True,
    ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
        r"""Set up a model and its optimizers for accelerated training.
为加速训练设置模型及其优化器。

Args:
参数:

module: A :class:`torch.nn.Module` to set up
module: 要设置的 :class:`torch.nn.Module`

*optimizers: The optimizer(s) to set up (no optimizers is also possible)
*optimizers: 要设置的优化器(也可以不设置优化器)

move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
move_to_device: 如果设置为``True``(默认值),则将模型移动到正确的设备。设置为``False`` 
    and alternatively use :meth:`to_device` manually.
    并可以手动使用 :meth:`to_device`。

_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
_reapply_compile: 如果``True``(默认值),且模型之前已``torch.compile``,则
    corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
    相应的 :class:`~torch._dynamo.OptimizedModule` 包装器将被移除,并在模型被策略设置好后重新应用
    same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
    相同的设置(例如,模型被 DDP、FSDP 等包装之后)。如果编译 DDP/FSDP 造成问题,设置为``False``。

Returns:
返回:

The tuple containing wrapped module and the optimizers, in the same order they were passed in.
一个包含包装的模块和优化器的元组,顺序与传入时相同。

        """

tokenizer

 

    def encode(
        self,
        string: str,
        bos: bool = True,
        eos: bool = False,
        max_length: int = -1,
        pad: bool = False,
        device: Optional[torch.device] = None
    ) -> torch.Tensor:
        tokens = self.processor.encode(string)
        if bos:
            tokens = [self.bos_id] + tokens
        if eos:
            tokens = tokens + [self.eos_id]
        if max_length > 0:
            tokens = tokens[:max_length]
        if pad and len(tokens) < max_length:
            tokens += [self.pad_id] * (max_length - len(tokens))

        return torch.tensor(tokens, dtype=torch.int, device=device)

    def decode(self, tokens: torch.Tensor) -> str:
        return self.processor.decode(tokens.tolist())

 generate()

@torch.no_grad()
def generate(
    model: LLaMA,
    idx: torch.Tensor,
    max_new_tokens: int,
    *,
    max_seq_length: Optional[int] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
接收一个条件序列(提示)作为输入,并继续生成所请求的数量的标记。

The implementation of this function is modified from A. Karpathy's nanoGPT.
此函数的实现改编自 A. Karpathy 的 nanoGPT。

Args:
参数:

model: The model to use.
model: 要使用的模型。

idx: Tensor of shape (T) with indices of the prompt sequence.
idx: 形状为 (T) 的张量,其中包含提示序列的索引。

max_new_tokens: The number of new tokens to generate.
max_new_tokens: 要生成的新分词数量。

max_seq_length: The maximum sequence length allowed.
max_seq_length: 允许的最大序列长度。

temperature: Scales the predicted logits by 1 / temperature
temperature: 通过 1 / temperature 对预测的 logits 进行缩放。

top_k: If specified, only sample among the tokens with the k highest probabilities
top_k: 如果指定,只从概率最高的 k 个标记中进行采样。

eos_id: If specified, stop generating any more token once the <eos> token is triggered
eos_id: 如果指定,一旦触发 <eos> 标记,停止生成更多标记。

    """

 https://pytorch.ac.cn/xla/release/2.1/index.htmlXLA 设备上的 PyTorch

model

    def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:
        return build_rope_cache(
            seq_len=self.config.block_size,
            n_elem=self.config.n_embd // self.config.n_head,
            dtype=idx.dtype,
            device=idx.device,
        )

temperature

温度越低,结果的差距越大,会使概率分布更加尖锐,从而使得模型更倾向于选择最高概率的类别。

topk()  

def topk(input: Tensor, k: Union[_int, SymInt], dim: _int = -1, largest: _bool = True, sorted: _bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.topk: 
    r"""
    topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor)

返回给定 input 张量在指定维度上最大的 k 个元素。

如果没有给定 dim,则选择 input 张量的最后一个维度。

如果 largest 设置为 False,则返回 k 个最小元素。

函数返回一个命名元组 (values, indices),其中 values 和 indices 分别是输入张量在指定维度 dim 上最大的 k 个元素及其索引。

布尔选项 sorted 如果为 True,则确保返回的 k 个元素按顺序排列。

参数:

input (Tensor): 输入张量。
k (int): "top-k" 中的 k 值。
dim (int, optional): 排序的维度。
largest (bool, optional): 控制是否返回最大还是最小元素。
sorted (bool, optional): 控制是否返回排序后的元素。
关键字参数:

out (tuple, optional): 可选的输出元组 (Tensor, LongTensor),可以作为输出缓冲区使用。
示例:

python
>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))
    """

torch.multinomial

def multinomial(input: Tensor, num_samples: _int, replacement: _bool = False, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor: 
    r"""
    def multinomial(input: Tensor, num_samples: _int, replacement: _bool = False, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor:
    r"""
    multinomial(input, num_samples, replacement=False, *, generator=None, out=None) -> LongTensor
    
    返回一个张量,其中每一行包含 :attr:`num_samples` 个从对应行的多项分布中采样的索引。
    更严格地说,是从多元分布中采样,更多细节请参考 torch.distributions.multinomial.Multinomial。
    
    .. note::
        :attr:`input` 的行不需要和为 1(在这种情况下,我们使用值作为权重),
        但必须是非负的、有穷的,并且和不为零。
    
    索引按从左到右的顺序排列,依据每个索引被采样的顺序(第一个样本放在第一列)。
    
    如果 :attr:`input` 是一个向量,:attr:`out` 是一个大小为 :attr:`num_samples` 的向量。
    
    如果 :attr:`input` 是一个有 `m` 行的矩阵,则 :attr:`out` 是一个形状为
    :math:`(m \times \text{num\_samples})` 的矩阵。
    
    如果 `replacement` 为 ``True``,则样本是有放回的。
    
    如果不是,则样本是无放回的,这意味着一旦为某行绘制了一个样本索引,
    在该行中不能再次绘制相同的索引。
    
    .. note::
        当无放回采样时,:attr:`num_samples` 必须小于 :attr:`input` 中非零元素的数量
        (如果 `input` 是矩阵,则为每行的非零元素的最小数量)。
    
    Args:
        input (Tensor): 包含概率的输入张量
        num_samples (int): 要绘制的样本数量
        replacement (bool, optional): 是否允许重复抽样
    
    关键字参数:
        generator (:class:`torch.Generator`, optional): 用于采样的伪随机数生成器
        out (Tensor, optional): 输出张量。
    
    示例::
    
        >>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # 创建一个权重张量
        >>> torch.multinomial(weights, 2)
        tensor([1, 2])
        >>> torch.multinomial(weights, 4) # 错误!
        RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False,
        not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320
        >>> torch.multinomial(weights, 4, replacement=True)
        tensor([ 2,  1,  1,  1])
    """
    """

 model.reset_cache()

 

Pytorch清空显存缓冲区(torch.cuda.empty_cache)_pytorch 清空显存-CSDN博客 
Pytorch 如何在使用模型后清除GPU内存|极客教程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2101824.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

springboot,maven多模块开发,子模块获取不到父模块添加的依赖,有多个root模块问题解决

错误示范 我以为放进去然后重载一下就是子模块了 导致后续在外层加的依赖&#xff0c;其article都接收不到 解决方案 需要在父模块的modules注册子模块 修改前后对比 此时子模块也能获取父模块的依赖

DDD设计方法-2-聚合、实体、值对象

前情提要&#xff1a;一共包含 如下六篇文章&#xff08;篇幅精简&#xff0c;快速入门&#xff09; 1、初识DDD 2、聚合、实体、值对象 3、仓储&#xff0c;封装持久化数据 4、端口和适配器 5、领域事件 6、领域服务&#xff0c;实现约定 DDD设计方法-2-聚合、实体、值对象&a…

基于mspm0g3507的智能送药小车(21年电赛f题,openmv寻迹,k210数字识别,并行pid调制)项目实验报告

2024年全国大学生电子设计竞赛&#xff08;TI杯&#xff09; 2024年7月17日 摘要&#xff1a;本项目由微处理器MSPM0G3507&#xff0c;编码器电机驱动&#xff0c;OPENMV、K210视觉处理单元&#xff0c;红外药品检测单元&#xff0c;ZIGBEE无限透传单元&#xff0c;OLED显示&am…

Docker数据卷和Dockerfile

1、什么是Docker数据卷 前言&#xff1a; 在下载的镜像中&#xff0c;我们不能够去改变它内部的一些配置&#xff0c;因为docker的镜像文件是已经配置好的&#xff0c;无法改变&#xff0c;我们只能改变镜像启动后的容器里面的内容&#xff0c;但是又因为&#xff0c;容器本来…

Java框架第四课(对Spring的补充Spring web)

目录 一.Spring web的认识 (1)Spring Web概念 (2)Spring web的特点 (3)Springweb运行的流程 (4)Springweb运行的流程图 二.搭建Spring web 三.自定义处理器类搭建 (1)处理器类配置 (2)处理器类接受请求 (3)获得请求数据 四.拦截器 (1)关于拦截器&#xff1a; (2)拦截器的…

【VMware】麒麟系统网络连接配置

在VMware配置页面点击编辑&#xff0c;进入虚拟网络编辑器将默认的 VMnet0删除&#xff0c;新建网络&#xff0c;设置桥接模式为Intel 打开主机cmd,查看主机IP地址&#xff0c;获取子网掩码&#xff0c;默认网关及DNS服务器 4.在主机寻找可用IP地址&#xff0c;ping不通的为未…

探秘发酵过程:酵母菌如何为白酒赋予不同风味?

在白酒酿造的神秘世界里&#xff0c;发酵过程如同一位隐形的艺术家&#xff0c;用其不同的笔触为白酒勾勒出千变万化的风味。而在这背后&#xff0c;酵母菌作为发酵的主角&#xff0c;发挥着至关重要的作用。今天&#xff0c;就让我们一起探秘发酵过程&#xff0c;了解酵母菌如…

shell 学习笔记:变量、字符串、注释

目录 1. 变量 1.1 定义使用变量 1.2 变量命名规则 1.3 只读变量 1.4 删除变量 1.5 变量类型 1.5.1 字符串变量 1.5.2 整数变量 1.5.3 数组变量 1.5.3.1 整数索引数组 1.5.3.2 关联数组 1.4 环境变量 1.5 特殊变量 2. 字符串 2.1 单引号字符串 2.2 双引…

erlang学习:用OTP构建系统23.12练习题

练习要求 制作一个名为prime_tester_server的gen_server&#xff0c;让它测试给定的数字是否是质数。 你可以使用lib_primes.erl里的is_prime/2函数来处理&#xff08;或者自己实现一个更好的质数测试函 数&#xff09;。把它添加到sellaprime_supervisor.erl的监控树里。 质…

图论(2)

一、度 度统计的是一个节点上又多少条边 度出度入度 出度&#xff1a;统计以该节点为起始点箭头指向外面的边的条数 入度&#xff1a;统计箭头指向该节点的边数 度为1的节点为悬挂节点&#xff0c;边为悬挂边 用矩阵计算节点的度 二、握手定理 比如这里第一个集合里面有三…

ARP协议(原理,特点,报文格式,具体过程),ARP缓存(有效时间,为什么),ARP欺骗(定向断网,成为中间人),RARP简单介绍

目录 ARP协议 引入 介绍 原理 arp请求/响应 特点 报文格式 硬件类型 协议类型 硬件/协议地址长度 op(操作码) 过程 发送请求并处理 返回响应并处理 总结 arp缓存 介绍 arp表项的有效时间 解释 arp欺骗 介绍 定向断网 基于arp的成为中间人的方式 多向…

跟李沐学AI:序列模型

目录 序列数据 自回归模型 马尔可夫假设 潜变量模型 序列模型总结 序列数据 实际中很多数据是时序结构的&#xff0c;如&#xff1a;电影的评价随时间的变化而变化&#xff1a;拿奖后评分上升、电影整体质量提升&#xff0c;人们要求变高。。。等等 除此之外&#xff0c;音…

比特币网络和支付

1. 比特币网络 比特币网络是一个去中心化的点对点网络&#xff0c;节点之间可以直接进行交易。网络上有不同类型的节点。 1.1 比特币网络的节点 比特币网络的节点有两种主要类型&#xff1a;全节点也称为完整节点和简单支付验证&#xff08;Simple Payment Verification,SPV)节…

档案|基于SprinBoot+vue的档案管理系统(源码+数据库+文档)

档案管理系统 基于SprinBootvue的档案管理系统 一、前言 二、系统设计 三、系统功能设计 管理员功能模块实现 学生功能模块实现 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取&#xff1a; 博主介绍&#xff1a;✌️大厂码农…

【数据库|第11期】深入掌握 SQL Server、Access 与 SQLite 中的 `UNION` 与 `UNION ALL`:从理论到实践

日期&#xff1a;2024年9月3日 作者&#xff1a;Commas 签名&#xff1a;(ง •_•)ง 积跬步以致千里,积小流以成江海…… 注释&#xff1a;如果您觉得有所帮助&#xff0c;帮忙点个赞&#xff0c;也可以关注我&#xff0c;我们一起成长&#xff1b;如果有不对的地方&#xff…

EMC整改问题

定位问题: 1.控制变量比较法:连和不连&#xff0c;接和不接来判断 2.频率判断法:低频一般是电源&#xff0c;高频一般是信号或者无线通信问题&#xff0c;还有倍频问题 3.解决方法: a.加器件&#xff0c;滤波&#xff0c;EMI共模电感&#xff0c;磁环 b.电源&#xff0c;高速信…

App推广新篇章:Xinstall带你走出数据迷雾,实现高效推广!

在如今的移动互联网时代&#xff0c;App推广已成为每个应用开发者必须面对的重要课题。然而&#xff0c;推广过程中往往伴随着诸多痛点&#xff0c;如数据混乱、投放盲目、决策滞后以及作弊困扰等。这些问题不仅影响了推广效果&#xff0c;还可能导致资源的浪费和投入产出不均衡…

Java版本的扫雷游戏程序

一、开发环境 开发工具:eclipse2021-12 JDK版本:JDK15.0.1 二、运行效果展示 这张图是游戏刚开始的画面,重置以后也是这个画面 此图是写代码的过程调试用的画面,方便查找问题。 此图是运行过程中的图片

实习的一点回顾Webhook的执行

1.Webhook流程 1.Bass外的部分 比如我通过控制台或者js脚本去调用curl命令call指定的webhook的地址的功能脚本 命令发送到网关&#xff0c;网关通过注册中心之类的发送到服务实体上。 这些是微服务的东西 2.OpenAPI到Controller阶段 先看之前openAPI的那篇前置 请求进来之…

Anaconda的环境管理操作命令详解-学习篇

一、通过命令方式管理环境 1. 查看环境 使用以下命令查看当前所有环境的命令conda env list可以看到目前电脑的base环境情况&#xff0c;我的本机只有一个base环境。是anaconda3在安装的时候所选的根目录信息。命令前的(base) 代表目前执行处于base环境&#xff0c;* 代表目前…