1 问题描述
按照官方的写法
import torch
from transformers import pipeline
import os
os.environ["HF_TOKEN"] = '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'
model_id = "meta-llama/Llama-3.2-3B"
pipe = pipeline(
"text-generation",
model=model_id,
torch_dtype=torch.bfloat16,
device_map="auto"
)
pipe("The key to life is")
2 解决方法
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
禁用 PyTorch 2.x 中默认启用的 Flash Attention 和 Memory-Efficient Attention 内核