大模型持续学习方案解析:灾难性遗忘的工业级解决方案

news2025/4/6 13:27:10

引言

随着大型语言模型(LLMs)如 GPT 系列、BERT 等在自然语言处理领域取得突破性进展,它们强大的理解和生成能力已经渗透到各行各业。然而,这些模型通常是在海量静态数据集上进行一次性预训练的。现实世界是动态变化的,新的知识、事件、术语层出不穷。如何让这些大模型能够像人一样,不断学习新知识,同时不忘记已经掌握的旧知识?这就是持续学习(Continual Learning, CL)领域的核心议题。

其中,灾难性遗忘(Catastrophic Forgetting)是持续学习面临的最大挑战——模型在学习新任务或新数据时,往往会显著降低在旧任务上的表现,仿佛“忘掉”了之前学过的内容。这对于需要模型保持长期服务、不断适应新环境的工业应用来说是致命的。

本文将深入探讨大模型持续学习的必要性,解析灾难性遗忘的根源,并重点介绍几种主流的、具有工业应用潜力的解决方案,结合理论分析和代码实践,为大家提供一份保姆级的指南。

一、 为什么大模型需要持续学习?

想象一下,一个部署在客服系统中的大模型,如果不能学习最新的产品信息、政策变动,它的价值会迅速下降。或者一个用于新闻推荐的模型,无法理解新近发生的重大事件,其推荐效果必然大打折扣。

大模型需要持续学习的主要原因包括:

  1. 知识更新:世界是动态的,新的事实、概念、事件不断涌现。
  2. 个性化需求:针对特定用户群体或特定领域(如医疗、法律)进行微调和知识增强。
  3. 适应性与鲁棒性:适应数据分布的变化(Domain Shift),提高模型在不同环境下的表现。
  4. 效率与成本:相比于完全重新训练(成本极高),持续学习提供了一种更经济、高效的模型迭代方式。

二、 灾难性遗忘:持续学习的“拦路虎”

什么是灾难性遗忘?

当一个神经网络模型(尤其是深度模型)顺序地学习一系列任务(Task 1, Task 2, ..., Task N)时,在学习新任务(如 Task k)的过程中,模型参数为了适应新任务而被修改,这些修改可能会严重破坏模型在旧任务(Task 1 to Task k-1)上学到的知识,导致其性能急剧下降。

为什么会发生?

神经网络的参数(权重)在学习过程中是共享和高度耦合的。当模型优化器(如 Adam、SGD)根据新任务的损失函数梯度更新参数时,它并没有机制去“保护”那些对旧任务至关重要的参数。新任务的梯度可能会将参数推向一个新的区域,这个区域虽然在新任务上表现良好,但在旧任务上却很糟糕。对于参数量巨大的 LLMs 来说,这个问题尤为突出。

三、对抗灾难性遗忘:主流策略通俗解析

1. 回放法:边学新边复习旧知识

核心思想:像学生复习笔记一样,在学新内容时不断回顾旧知识。

具体做法

  • 经验回放(ER)

    • 维护一个 “记忆库”,保存旧任务的典型样本(比如之前学过的图片和对应标签)。
    • 训练新任务时,每次随机抽取少量旧样本,和新数据一起训练模型,强制模型记住旧知识。
    • 样本选择策略
      • FIFO:先存的先被遗忘,可能丢失重要内容。
      • 随机替换:保持多样性,但可能漏掉关键样本。
      • 按错误率优先:优先复习模型容易错的样本。
    • 优缺点
      • ✅ 效果好,直接用真实数据复习。
      • ❌ 存储大量数据(尤其图像、文本)成本高,隐私风险大。
  • 生成式回放(GR)

    • 训练一个 “造假工厂”(如 GAN),专门生成类似旧数据的 “假货”。
    • 用生成的假数据代替真实旧数据进行复习。
    • 优缺点
      • ✅ 省存储,保护隐私。
      • ❌ 假货质量差会拖累效果,造假工厂本身也可能 “失忆”。

2. 正则化法:给重要参数加保护罩

核心思想:给旧知识的关键参数套上 “枷锁”,防止学习新任务时被过度修改。

具体做法

  • 弹性权重巩固(EWC)

    • 用 “重要性打分” 标记参数对旧任务的重要性(比如参数变化对旧任务影响越大,分数越高)。
    • 学习新任务时,总损失 = 新任务损失 + 重要性分数 × (新参数 - 旧参数)^2。
    • 优缺点
      • ✅ 无需存储旧数据,理论清晰。
      • ❌ 计算复杂度高,可能过度保护旧参数,导致新任务学不好。
  • 突触智能(SI)

    • 实时记录每个参数在训练中的 “贡献值”,动态调整保护强度。
    • 总损失 = 新任务损失 + 贡献值 × (参数变化)^2。
    • 优缺点
      • ✅ 动态计算,无需额外存储。
      • ❌ 理论不够直观,依赖训练路径。
  • 无遗忘学习(LwF)

    • 让旧模型(已冻结)和新模型同时 “做题”,新模型既要答对新题,也要模仿旧模型的思路。
    • 总损失 = 新任务正确率损失 + 模仿旧模型的 “教学损失”。
    • 优缺点
      • ✅ 简单易行,无需存储数据或计算参数重要性。
      • ❌ 旧模型可能在新任务上表现差,导致教学效果下降。

3. 架构法:给不同任务分配专属 “房间”

核心思想:为每个任务单独划分模型资源,避免不同任务的知识互相干扰。

具体做法

  • 适配器(Adapters)

    • 在大模型中插入小型 “插件”(类似电脑外接设备),只训练插件参数,基础模型保持不变。
    • 优缺点
      • ✅ 省内存,训练快(只改 1-5% 参数)。
      • ❌ 插件能力有限,可能不如完全重训。
  • 低秩适应(LoRA)

    • 给大模型参数矩阵添加 “补丁”(低秩分解矩阵),只训练补丁参数。
    • 优缺点
      • ✅ 比适配器更省参数,效果接近全量训练。
      • ❌ 需调整 “补丁” 大小(秩)。
  • 渐进式网络(PNNs)

    • 每学一个新任务,就给模型 “加一层楼”,每层楼独立处理对应任务,旧层完全冻结。
    • 优缺点
      • ✅ 完全不遗忘。
      • ❌ 模型体积随任务数爆炸式增长。

4、策略选择建议

  • 数据充足且无隐私限制 → 选回放法(效果最好)。
  • 追求轻量化部署 → 选 LoRA / 适配器(省资源)。
  • 任务差异极大 → 考虑架构法(隔离知识)。
  • 快速实验验证 → 先用 LwF(实现简单)。

四、 方案选择

在资源有限、效率至上、模型需长期服务的工业环境中,结合参数高效微调(PEFT)技术(尤其是 LoRA)和持续学习(CL)策略,是当前极具吸引力的方向。下面提供几个详细且可操作的方案:

方案一:独立 LoRA 适配器 + 任务路由

适用场景: 任务之间差异较大,需要严格隔离;或者运维简单性优先,允许为每个任务维护独立适配器。

核心思想: 为每个新任务训练一套独立的 LoRA 权重,推理时根据任务标识加载对应的适配器。

实施步骤:

  1. 基础模型: 选定一个预训练好的大模型(如 BERT, GPT, LLaMA 等)作为基础,并始终保持其主体参数冻结
  2. 任务 1 训练:
    • 定义 LoRA 配置 (LoraConfig),指定秩 r, alpha, target_modules 等。
    • 使用 get_peft_model 将 LoRA 应用到基础模型。
    • 在任务 1 数据上训练,只优化 LoRA 参数 (B 和 A 矩阵)
    • 训练完成后,使用 model.save_pretrained("./adapter_task1") 保存适配器权重。适配器文件夹通常只包含 adapter_model.binadapter_config.json,非常小。
  3. 任务 2 训练:
    • 重新加载原始的基础模型(确保是干净、冻结的状态)。
    • 定义(或复用)LoRA 配置。
    • 使用 get_peft_model 应用新的 LoRA 层。
    • 在任务 2 数据上训练,只优化新的 LoRA 参数。
    • 保存任务 2 适配器:model.save_pretrained("./adapter_task2")
  4. 后续任务: 重复步骤 3。
  5. 推理与部署:
    • 加载基础模型。
    • 根据需要执行的任务(例如,通过 API 请求中的任务标识符),使用 PeftModel.from_pretrained(base_model, "./adapter_task_k") 加载相应的适配器。
    • 执行推理。需要切换任务时,只需加载不同的适配器即可,基础模型不变。

管理: 需要维护一个映射关系(如字典或数据库表),将任务 ID/名称映射到其对应的适配器存储路径。

优点:

  • 强力抗遗忘: 任务知识物理隔离在不同适配器中。
  • 管理清晰: 每个任务对应一套独立权重,易于管理、版本控制和回滚。
  • 训练高效: 每次只训练少量参数。

缺点:

  • 无显式知识共享: 任务间的共性知识可能需要重复学习到各自的适配器中。
  • 适配器存储: 存储成本随任务数量线性增加(但每个适配器很小,通常可接受)。

方案二:共享基础模型 + LoRA + 轻量级经验回放

适用场景: 任务间有一定关联性,希望在隔离的同时促进知识巩固;或者需要进一步抵抗概念漂移。

核心思想: 依然为每个任务训练独立 LoRA 适配器,但在训练新任务时,混合少量来自旧任务的回放数据,以“提醒”当前适配器不要与旧知识产生冲突。

实施步骤:

  1. 基础模型与适配器训练: 同方案一,为 Task 1...N 训练并保存各自的 Adapter_1...Adapter_N
  2. 维护回放缓冲区:
    • 在训练 Task K 时,将 Task K 的部分代表性样本(如 input_ids, attention_mask, label)存入一个全局的回放缓冲区(Replay Buffer)。
    • 缓冲区管理策略:使用固定大小的水库采样,确保存储的样本来自所有历史任务且分布相对均衡。
  3. 训练 Task N (N > 1):
    • 加载基础模型,应用新的 Adapter_N
    • 数据混合: 在每个训练 step,准备两个 batch:
      • batch_N: 来自当前任务 N 的数据。
      • batch_replay: 从回放缓冲区中采样得到的、来自 Task 1...N-1 的混合数据。
    • 合并训练: 将 batch_Nbatch_replay 合并(concatenate)成一个更大的 batch。
    • 计算损失: 将合并后的 batch 输入到当前正在训练的 Adapter_N 模型中,计算标准损失(例如交叉熵)。注意,需要确保模型能够处理来自不同任务的标签(如果标签体系不同,可能需要调整输出层或在计算损失时区分)。
    • 反向传播: 基于合并 batch 的总损失,更新 Adapter_N 的参数。
    • 更新缓冲区: 将 Task N 的新样本按策略加入回放缓冲区。
  4. 推理: 同方案一,按需加载特定任务的适配器。

优点:

  • 增强鲁棒性: 回放数据有助于新适配器“意识”到旧任务的存在,减少在新数据上的过拟合,并可能轻微提升旧任务在用新适配器推理时的表现(虽然主要还是靠加载旧适配器)。
  • 保留隔离性: 主要知识仍存储在独立适配器中。

缺点:

  • 增加训练开销: 每个 step 需要处理额外回放数据,计算量增大。
  • 缓冲区管理: 需要设计和维护缓冲区大小、采样策略。
  • 调参复杂: 需要调整回放数据占比或损失权重(如果分开计算损失)。
  • 负迁移风险: 如果旧任务与新任务差异极大或回放样本选择不当,可能干扰新任务学习。

方案三:任务适配器的组合与融合

适用场景: 任务间关联性强,希望实现更灵活的知识共享和组合。

核心思想: 不仅仅是独立加载适配器,而是探索如何组合或融合多个任务的适配器来处理混合任务或实现更平滑的知识迁移。

实施方式 (概念性,实现较复杂):

  • 适配器叠加/插值: 对于一个输入,可以尝试同时加载多个相关任务的适配器,并将其输出(或 LoRA 的 ΔW)进行加权平均或更复杂的组合。权重可以基于任务相似度或元学习得到。
  • 任务向量 + Adapter: 训练一个任务嵌入向量,该向量可以调制(例如,通过 FiLM 层)一个共享的适配器或基础模型层,使得模型行为适应特定任务。
  • Adapter Merging: 研究表明,多个 LoRA 适配器的权重可以直接进行(加权)平均,合并后的适配器有时能在多个任务上取得不错的综合性能,减少需要存储的适配器数量。

优点:

  • 潜力巨大: 可能实现更细粒度的知识控制和迁移。
  • 模型复用: 合并或组合适配器可以减少部署时的模型实例/加载次数。

缺点:

  • 技术前沿: 许多组合/融合技术仍在研究阶段,鲁棒性和通用性有待验证。
  • 实现复杂: 对 PEFT 库和模型架构需要更深入的理解和定制。
  • 优化困难: 寻找最佳组合方式或权重可能非常困难。

落地建议:

  • 起步: 从方案一(独立 LoRA 适配器)开始,它最简单、鲁棒,易于实现和管理。
  • 进阶: 如果需要进一步提升性能或处理任务关联性,可以考虑方案二(+轻量级回放),但要仔细评估其带来的额外开销和调参复杂度。
  • 探索: 方案三属于前沿探索,适合有较强研发能力、对性能有极致追求的团队进行尝试。

通用考量:

  • 评测: 建立完善的持续学习评测体系至关重要。每次学习新任务后,必须评估模型在所有历史任务以及当前任务上的性能,计算平均准确率、遗忘率(Backward Transfer)、新任务学习效果(Forward Transfer)等指标。
  • 监控: 线上部署后,持续监控模型在不同任务上的表现,及时发现性能衰退或遗忘问题。
  • 数据工程: 合理的数据划分、版本管理、以及(如果使用回放)有效的缓冲区构建和采样是成功的关键。

五、 实战演练:使用 LoRA 实现简单的持续学习

下面我们提供一个使用 transformers, datasets, 和 peft 库,基于 方案一(独立 LoRA 适配器) 实现的、可以直接运行的持续学习示例。我们将使用公开数据集 imdb (情感分类,2类) 作为 Task 1,以及 ag_news (新闻主题分类,4类) 的一个子集作为 Task 2。

(请确保你已经安装了必要的库,并在有 GPU 的环境运行以获得合理速度)

pip install torch transformers datasets peft accelerate bitsandbytes numpy tqdm scikit-learn # Added sklearn for metrics
import torch
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AdamW, get_linear_schedule_with_warmup, BitsAndBytesConfig
from datasets import load_dataset, Dataset, concatenate_datasets
from torch.utils.data import DataLoader
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training, PeftModel
from tqdm.notebook import tqdm # 如果在 Jupyter/Colab 中,使用 notebook tqdm,否则 from tqdm import tqdm
import os
from sklearn.metrics import accuracy_score # 使用 sklearn 计算准确率

# --- 1. 配置 ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 为了更快的演示和更低的资源消耗,使用一个较小的模型。
MODEL_NAME = "prajjwal1/bert-tiny"
# 如果资源允许,可以替换为 "bert-base-uncased" 或更大的模型。
# MODEL_NAME = "bert-base-uncased" # 取消注释以使用更大的模型
TASK1_NAME = "IMDB 情感分析"
TASK2_NAME = "AG News 主题分类 (子集)"
NUM_EPOCHS_PER_TASK = 2 # 可以增加轮数以获得可能更好的结果
BATCH_SIZE = 16
LEARNING_RATE = 1e-3 # PEFT 通常在较高的学习率下表现良好
MAX_LENGTH = 128
ADAPTER_SAVE_DIR = "./cl_adapters" # 适配器保存目录
TASK1_ADAPTER_PATH = os.path.join(ADAPTER_SAVE_DIR, "adapter_task1")
TASK2_ADAPTER_PATH = os.path.join(ADAPTER_SAVE_DIR, "adapter_task2")

# 可选:量化配置 (如果使用大模型且显存有限)
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16
# )

print(f"使用设备: {DEVICE}")
os.makedirs(ADAPTER_SAVE_DIR, exist_ok=True) # 确保目录存在

# --- 2. 加载 Tokenizer (两个任务共用) ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# --- 数据准备辅助函数 ---
def prepare_data(examples, text_col='text', label_col='label'):
    # 对文本进行分词
    tokenized_inputs = tokenizer(examples[text_col], padding="max_length", truncation=True, max_length=MAX_LENGTH)
    # 将标签列重命名为 'labels' 以兼容 HF Trainer/模型
    tokenized_inputs["labels"] = examples[label_col]
    return tokenized_inputs

# --- 评估辅助函数 ---
def evaluate_model(model, dataloader, device):
    model.eval()
    all_preds = []
    all_labels = []
    total_eval_loss = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="评估中", leave=False):
            # 将批次数据移动到指定设备
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits

            total_eval_loss += loss.item()
            preds = torch.argmax(logits, dim=-1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    avg_val_loss = total_eval_loss / len(dataloader)
    accuracy = accuracy_score(all_labels, all_preds) # 使用 sklearn 计算准确率
    print(f"  准确率: {accuracy:.4f}")
    print(f"  平均损失: {avg_val_loss:.4f}")
    return accuracy, avg_val_loss

# --- 3. 准备任务 1 数据 (IMDB) ---
print(f"\n--- 准备 {TASK1_NAME} 数据 ---")
imdb_dataset = load_dataset("imdb")
# 为了更快的演示,使用较小的数据子集
train_dataset_t1_raw = imdb_dataset['train'].shuffle(seed=42).select(range(2000)) # 2000 个训练样本
val_dataset_t1_raw = imdb_dataset['test'].shuffle(seed=42).select(range(500))   # 500 个验证样本

# 映射处理并设置格式
train_dataset_t1 = train_dataset_t1_raw.map(lambda x: prepare_data(x, text_col='text', label_col='label'), batched=True)
val_dataset_t1 = val_dataset_t1_raw.map(lambda x: prepare_data(x, text_col='text', label_col='label'), batched=True)
train_dataset_t1.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
val_dataset_t1.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

task1_train_dataloader = DataLoader(train_dataset_t1, batch_size=BATCH_SIZE, shuffle=True)
task1_val_dataloader = DataLoader(val_dataset_t1, batch_size=BATCH_SIZE)
NUM_LABELS_TASK1 = imdb_dataset['train'].features['label'].num_classes
print(f"任务 1: {len(train_dataset_t1)} 训练样本, {len(val_dataset_t1)} 验证样本。标签数量: {NUM_LABELS_TASK1}")

# --- 4. 准备任务 2 数据 (AG News 子集) ---
print(f"\n--- 准备 {TASK2_NAME} 数据 ---")
ag_news_dataset = load_dataset("ag_news")
# 使用较小的数据子集
train_dataset_t2_raw = ag_news_dataset['train'].shuffle(seed=42).select(range(2000)) # 2000 个训练样本
val_dataset_t2_raw = ag_news_dataset['test'].shuffle(seed=42).select(range(500))   # 500 个验证样本

# 映射处理并设置格式
train_dataset_t2 = train_dataset_t2_raw.map(lambda x: prepare_data(x, text_col='text', label_col='label'), batched=True)
val_dataset_t2 = val_dataset_t2_raw.map(lambda x: prepare_data(x, text_col='text', label_col='label'), batched=True)
train_dataset_t2.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
val_dataset_t2.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

task2_train_dataloader = DataLoader(train_dataset_t2, batch_size=BATCH_SIZE, shuffle=True)
task2_val_dataloader = DataLoader(val_dataset_t2, batch_size=BATCH_SIZE)
NUM_LABELS_TASK2 = ag_news_dataset['train'].features['label'].num_classes
print(f"任务 2: {len(train_dataset_t2)} 训练样本, {len(val_dataset_t2)} 验证样本。标签数量: {NUM_LABELS_TASK2}")


# --- 5. 定义 LoRA 配置 ---
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS, # 重要:设置任务类型为序列分类
    r=8,                        # 更新矩阵的秩
    lora_alpha=16,              # Alpha 缩放参数
    lora_dropout=0.1,           # LoRA 层的 Dropout 概率
    bias="none",                # 偏置类型 ('none', 'all', 或 'lora_only')
    # PEFT 会自动为 BERT 等常见模型找到 target_modules (例如 query, value 层)
    # target_modules=["query", "value"] # 如果需要,可以显式指定目标模块
)

# --- 获取全新基础模型的函数 ---
def get_base_model(num_labels):
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels,
        # quantization_config=bnb_config, # 如果使用量化,则启用此行
    )
    # 仅在未使用 PEFT 的 prepare_model_for_kbit_training 进行量化时,
    # 或者在 PEFT 之外手动管理设备放置时才需要下面这行
    # model = model.to(DEVICE)
    return model

# --- 6. 训练任务 1 (IMDB) ---
print(f"\n--- 训练 {TASK1_NAME} ---")
# 加载为任务 1 标签配置的全新基础模型
base_model_t1 = get_base_model(NUM_LABELS_TASK1)

# 应用 LoRA
# 如果使用量化,PEFT 会在此处处理设备放置
if 'bnb_config' in locals():
     base_model_t1 = prepare_model_for_kbit_training(base_model_t1)

lora_model_t1 = get_peft_model(base_model_t1, lora_config)
lora_model_t1.print_trainable_parameters() # 打印可训练参数量,你会看到这个数字非常小
lora_model_t1.to(DEVICE) # 如果未使用量化,请确保模型位于正确的设备上

# 优化器和学习率调度器
optimizer_t1 = AdamW(lora_model_t1.parameters(), lr=LEARNING_RATE)
total_steps_t1 = len(task1_train_dataloader) * NUM_EPOCHS_PER_TASK
scheduler_t1 = get_linear_schedule_with_warmup(optimizer_t1, num_warmup_steps=0, num_training_steps=total_steps_t1)

# 任务 1 训练循环
for epoch in range(NUM_EPOCHS_PER_TASK):
    lora_model_t1.train()
    total_loss = 0
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS_PER_TASK} - 任务 1")
    for batch in tqdm(task1_train_dataloader, desc="训练任务 1"):
        lora_model_t1.zero_grad()
        # 将批次数据移动到指定设备
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        outputs = lora_model_t1(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer_t1.step()
        scheduler_t1.step()

    avg_train_loss = total_loss / len(task1_train_dataloader)
    print(f"  平均训练损失: {avg_train_loss:.4f}")
    print("  在任务 1 验证集上评估...")
    evaluate_model(lora_model_t1, task1_val_dataloader, DEVICE)

# 保存任务 1 适配器
lora_model_t1.save_pretrained(TASK1_ADAPTER_PATH)
print(f"任务 1 LoRA 适配器已保存至 {TASK1_ADAPTER_PATH}")

# --- 记录学习任务 2 之前的性能 ---
print("\n--- 评估任务 1 性能 (使用任务 1 适配器, 训练任务 2 之前) ---")
# 我们需要正确加载适配器以进行评估
eval_base_model_t1 = get_base_model(NUM_LABELS_TASK1) # 全新的基础模型
eval_lora_model_t1 = PeftModel.from_pretrained(eval_base_model_t1, TASK1_ADAPTER_PATH)
eval_lora_model_t1.to(DEVICE)
accuracy_t1_initial, _ = evaluate_model(eval_lora_model_t1, task1_val_dataloader, DEVICE)
# 清理 GPU 显存
del eval_base_model_t1, eval_lora_model_t1, lora_model_t1, base_model_t1, optimizer_t1, scheduler_t1
torch.cuda.empty_cache()


# --- 7. 训练任务 2 (AG News) ---
print(f"\n--- 训练 {TASK2_NAME} ---")
# 加载为任务 2 标签配置的全新基础模型
# 重要提示:由于标签数量不同 (IMDB 2类, AG News 4类),我们必须加载新的基础模型实例
# 或者如果原地修改,则需显式调整分类头的大小。加载全新的更安全。
base_model_t2 = get_base_model(NUM_LABELS_TASK2)

# 应用 LoRA (为任务 2 创建新的适配器)
if 'bnb_config' in locals():
     base_model_t2 = prepare_model_for_kbit_training(base_model_t2)

lora_model_t2 = get_peft_model(base_model_t2, lora_config) # 获取新的 LoRA 层
lora_model_t2.print_trainable_parameters()
lora_model_t2.to(DEVICE) # 如果未使用量化,请确保模型位于正确的设备上

# 任务 2 的优化器和学习率调度器
optimizer_t2 = AdamW(lora_model_t2.parameters(), lr=LEARNING_RATE)
total_steps_t2 = len(task2_train_dataloader) * NUM_EPOCHS_PER_TASK
scheduler_t2 = get_linear_schedule_with_warmup(optimizer_t2, num_warmup_steps=0, num_training_steps=total_steps_t2)

# 任务 2 训练循环
for epoch in range(NUM_EPOCHS_PER_TASK):
    lora_model_t2.train()
    total_loss = 0
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS_PER_TASK} - 任务 2")
    for batch in tqdm(task2_train_dataloader, desc="训练任务 2"):
        lora_model_t2.zero_grad()
        # 将批次数据移动到指定设备
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        outputs = lora_model_t2(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer_t2.step()
        scheduler_t2.step()

    avg_train_loss = total_loss / len(task2_train_dataloader)
    print(f"  平均训练损失: {avg_train_loss:.4f}")
    print("  在任务 2 验证集上评估...")
    evaluate_model(lora_model_t2, task2_val_dataloader, DEVICE)

# 保存任务 2 适配器
lora_model_t2.save_pretrained(TASK2_ADAPTER_PATH)
print(f"任务 2 LoRA 适配器已保存至 {TASK2_ADAPTER_PATH}")
# 评估最终的任务 2 性能
print("  评估最终任务 2 性能...")
accuracy_t2_final, _ = evaluate_model(lora_model_t2, task2_val_dataloader, DEVICE)

# 清理 GPU 显存
del lora_model_t2, base_model_t2, optimizer_t2, scheduler_t2
torch.cuda.empty_cache()


# --- 8. 评估任务 1 的遗忘情况 ---
print("\n--- 再次评估任务 1 性能 (使用任务 1 适配器, 训练任务 2 之后) ---")
# 加载全新的基础模型并加载 *任务 1* 的适配器
final_eval_base_model_t1 = get_base_model(NUM_LABELS_TASK1)
final_eval_lora_model_t1 = PeftModel.from_pretrained(final_eval_base_model_t1, TASK1_ADAPTER_PATH)
final_eval_lora_model_t1.to(DEVICE)

print("在任务 1 验证集上评估加载的任务 1 模型...")
accuracy_t1_final, _ = evaluate_model(final_eval_lora_model_t1, task1_val_dataloader, DEVICE)

# 清理 GPU 显存
del final_eval_base_model_t1, final_eval_lora_model_t1
torch.cuda.empty_cache()


# --- 9. 结果分析 ---
print("\n--- 持续学习性能总结 ---")
print(f"任务 1 ({TASK1_NAME}) 初始准确率: {accuracy_t1_initial:.4f}")
print(f"任务 1 ({TASK1_NAME}) 最终准确率 (使用 T1 适配器): {accuracy_t1_final:.4f}")
forgetting_t1 = accuracy_t1_initial - accuracy_t1_final
# 添加一个小的 epsilon 防止除零错误(虽然准确率不太可能为零)
epsilon = 1e-6
print(f"任务 1 遗忘率: {forgetting_t1:.4f} (相对遗忘: {forgetting_t1 / (accuracy_t1_initial + epsilon):.2%})")
print("-" * 30)
print(f"任务 2 ({TASK2_NAME}) 最终准确率 (使用 T2 适配器): {accuracy_t2_final:.4f}")
print("-" * 30)
print("\n分析:")
print("使用“独立 LoRA 适配器”策略:")
print("1. 我们成功地在冻结的基础模型之上,为两个不同的任务(IMDB 情感分析 & AG News 主题分类)训练了独立的适配器。")
print("2. 通过在训练任务 2 *之后*加载任务 1 的特定适配器,我们能够恢复任务 1 的性能。")
print(f"3. 观察到的遗忘率 ({forgetting_t1:.4f}) 预期非常低(理想情况下接近零,由于数值精度或微小环境变化可能存在细微差异)。这证明了 PEFT 通过参数隔离在缓解灾难性遗忘方面的有效性。")
print("4. 我们也使用其专用的适配器在任务 2 上取得了良好的性能。")
print("5. 对于在无需昂贵重训练或显著遗忘的情况下,为大型预训练模型添加新任务能力而言,这种方法具有高度的可扩展性和效率。")

代码说明:

  1. 真实数据集: 使用 load_dataset 加载了 imdbag_news。为了快速演示,只选取了部分数据 (select(range(...)))。你可以调整样本数量或移除 select 来使用完整数据集。
  2. 模型选择: 默认使用 prajjwal1/bert-tiny,这是一个非常小的 BERT 模型,便于快速运行和在资源有限的环境下测试。你可以取消注释 bert-base-uncased 行来使用更大的模型(需要更多时间和 VRAM)。
  3. 动态标签数: get_base_model 函数现在接受 num_labels 参数,确保为每个任务加载具有正确输出维度分类头的基础模型。这对于任务标签数不同的情况至关重要。
  4. 清晰的适配器管理: 代码明确地为 Task 1 和 Task 2 加载独立的基础模型实例并应用新的 LoRA 层,然后分别保存适配器到 TASK1_ADAPTER_PATHTASK2_ADAPTER_PATH
  5. 正确的评估流程:
    • 在训练 Task 2 之前,加载 T1 适配器评估 T1 性能 (accuracy_t1_initial)。
    • 在训练 Task 2 之后,再次加载 T1 适配器评估 T1 性能 (accuracy_t1_final),以计算遗忘。
    • 加载 T2 适配器评估 T2 性能 (accuracy_t2_final)。
  6. 依赖库: 添加了 scikit-learn 用于更方便地计算准确率 (accuracy_score)。
  7. 内存管理: 在切换任务或评估阶段之间,使用 del 删除不再需要的模型和优化器变量,并调用 torch.cuda.empty_cache() 尝试释放 GPU 显存,这对于在有限 VRAM 下运行多个阶段很重要。
  8. 注释与说明: 添加了更多注释来解释代码逻辑,特别是关于独立适配器策略、模型加载和评估步骤。
  9. Quantization (Optional): 添加了 BitsAndBytesConfigprepare_model_for_kbit_training 的注释和示例用法,如果用户想在更大模型上尝试 4/8 位量化,可以取消注释相关行。
  10. TaskType: 明确设置 LoraConfig 中的 task_type=TaskType.SEQ_CLS,这有助于 PEFT 正确配置适配器。

六、 挑战与展望

尽管持续学习已经取得了显著进展,尤其是在 PEFT 技术的加持下,但仍面临挑战:

  1. 可扩展性:如何处理成百上千个连续任务?任务越多,所需存储的 adapters 或模型变体也越多。
  2. 知识迁移与融合:当前方法更侧重于防止遗忘,如何促进任务间的正向知识迁移(学新帮旧)和知识融合仍需探索。
  3. 更复杂的场景:如开放域、无明确任务边界的持续学习。
  4. 评测标准:需要更全面、贴近实际应用的评测基准和指标。
  5. 理论理解:对遗忘和持续学习的内在机制仍需更深入的理论支撑。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2329253.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

医疗思维图与数智云融合:从私有云到思维图的AI架构迭代(代码版)

医疗思维图作为AI架构演进的重要方向,其发展路径从传统云计算向融合时空智能、大模型及生态开放的“思维图”架构迭代,体现了技术与场景深度融合的趋势。 以下是其架构迭代的核心路径与关键特征分析: 一、从“智慧云”到“思维图”的架构演进逻辑 以下是针对医疗信息化领域…

【JS】接雨水题解

题目 思路 首先我们要明确如何计算每条柱子的接水量: 每条柱子对应接到的雨水量该柱子左边最大值和右边最大值中的较小值-该柱子本身的高度。举例:第二条柱子自身高度为0,左边最大值为1,右边最大值为3,取较小值1-自身…

线代[12]|《高等几何》陈绍菱(1984.9)(文末有对三大空间的分析及一个合格数学系毕业生的要求)

文章目录 一、概述二、平面仿射几何的基本概念三、平面射影几何的基本概念四、变换群和几何学五、二次曲线的射影理论、仿射理论和度量理论六、射影几何公理基础七、非欧几里得几何概要八、自我测试题九、欧氏解析几何、仿射解析几何、射影解析几何与其他(博主借助A…

第3课:状态管理与事件处理

第3课:状态管理与事件处理 学习目标 掌握useState Hook的使用理解组件事件处理机制实现表单输入与状态绑定完成任务添加功能原型 一、useState基础 1. 创建第一个状态 新建src/Counter.js: import { useState } from react;function Counter() {co…

【速写】Transformer-encoder-decoder深度解析

文章目录 一、理论分析1. Transformers概述2. Transformer的输入部分具体是如何构成?2.1 单词 Embedding2.2 位置 Embedding 3 自注意力原理3.1 自注意力结构3.2 QKV的计算3.3 自注意力的输出3.4 多头注意力 4 Encoder结构4.1 AddNorm4.2 前馈4.3 组成Encoder 二、代…

MyBatis八股文-执行流程、延迟加载、一级与二级缓存

(一)执行流程 mybatis-config.xml核心配置文件的作用: 在MyBatis框架的核心配置文件中需要去指定当前的环境配置、指定需要操作的是哪个数据库,并且输入当前的用户名与密码,只有配置了他才能真正操作数据库。同时还去加载了SQL映射文件&#…

基于Spark的哔哩哔哩舆情数据分析系统

【Spark】基于Spark的哔哩哔哩舆情数据分析系统 (完整系统源码开发笔记详细部署教程)✅ 目录 一、项目简介二、项目界面展示三、项目视频展示 一、项目简介 本项目基于Python和Django框架进行开发,为了便于广大用户针对舆情进行个性化分析处…

【Linux】日志模块实现详解

📢博客主页:https://blog.csdn.net/2301_779549673 📢博客仓库:https://gitee.com/JohnKingW/linux_test/tree/master/lesson 📢欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正! &…

Java基础:面向对象高级(四)

内部类(类中五大成分之一) 四种形式 成员内部类【了解】 静态内部类【了解】 局部内部类【了解】 匿名内部类【重点】 枚举 泛型 什么是泛型 泛型类-模拟ArrayList 泛型接口-操作学生,老师增删改查 泛型方法 泛型擦除和注意事项

easy-poi 一对多导出

1. 需求: 某一列上下两行单元格A,B值一样且这两个单元格, 前面所有列对应单元格值一样的话, 就对A,B 两个单元格进行纵向合并单元格 1. 核心思路: 先对数据集的国家,省份,城市...... id 身份证进行排序…

python通过调用海康SDK打开工业相机(全流程)

首先打开海康机器人-机器视觉-下载中心 下载最新版的 MVS 安装后打开目录找到 ...\MVS\Development\Samples\Python 将MvImport内所有文件拷贝至工作目录 然后到 C:\Program Files (x86)\Common Files\MVS\Runtime 找到适合自己系统的版本,将整个文件夹拷贝至工…

manim,制作专业的数学公式动画

manim是一个Python第三方库,全称是mathematical animation engine(数学动画引擎)。manim用于解说线性代数、微积分、神经网络、黎曼猜想、傅里叶变换以及四元数等数学概念。 manim使你能够以编程的方式创建精确的数学图形、动画和场景。与传统的几何画板等绘图软件不同,man…

小刚说C语言刷题——第15讲 多分支结构

1.多分支结构 所谓多分支结构是指在选择的时候有多种选择。根据条件满足哪个分支,就走对应分支的语句。 2.语法格式 if(条件1) 语句1; else if(条件2) 语句2; else if(条件3) 语句3; ....... else 语句n; 3.示例代码 从键盘输入三条边的长度,…

[ctfshow web入门] web6

前置知识 入口点(目录)爆破 还记得之前说过网站的入口的吗,我们输入url/xxx,其中如果url/xxx存在,那么访问成功,证明存在这样一个入口点;如果访问失败则证明不存在此入口点。所以我们可以通过遍历url/xxx,…

简单程序语言理论与编译技术·22 实现一个从AST到RISCV的编译器

本文是记录专业课“程序语言理论与编译技术”的部分笔记。 LECTURE 22(实现一个从AST到RISCV的编译器) 一、问题分析 1、完整的编译器(如LLVM)需先完成AST到IR的转换,并进行代码优化,再到汇编&#xff0…

lua和C的交互

1.C调用lua例子 #include <iostream> #include <lua.hpp>int main() {//用于创建一个新的lua虚拟机lua_State* L luaL_newstate();luaL_openlibs(L);//打开标准库/*if (luaL_dofile(L, "test.lua") ! LUA_OK) {std::cerr << "Lua error: &…

Css:如何解决绝对定位子元素内容被父级元素overflow:hidden属性剪裁

一、问题描述 今天小伙伴提了一个bug&#xff0c;在点击列表项的“…”按钮应该出现的悬浮菜单显示不完整&#xff1a; 二、问题排查 一般这种问题&#xff0c;是由于悬浮菜单采用的是绝对定位&#xff0c;而父级采用了overflow:hidden属性。但需要注意的是&#xff0c;这里的…

RoMo: Robust Motion Segmentation Improves Structure from Motion

前言 看起来像是一篇投稿CVPR的文章&#xff0c;不知道被哪个瞎眼审稿人拒了。同期还有一篇CVPR被接收的工作Segment Any Motion in Videos&#xff0c;看起来不如这篇直白&#xff08;也可能是因为我先看过spotlesssplats的缘故&#xff09;&#xff0c;后面也应该一并介绍了…

MCP 极简入门 - 三分钟 Cline + Smithery 运行 time 服务

文章目录 一、&#x1f680; 初识Smithery&#xff1a;AI服务的新大陆找到心仪的服务 二、Cline 编辑配置文件&#x1f527;1、打开配置文件2. 添加Time Server配置3. 验证配置效果 三、&#x1f4ac; 实战对话&#xff1a;让AI告诉你时间四、服务管理小技巧&#x1f504;&…

基本机动飞行性能

机动飞行时描述飞机在给定构型和发动机工作状态下改变飞行速度、飞行高度和飞行方向的能力 1. 水平加&#xff08;减&#xff09;速 水平加&#xff08;减&#xff09;速性能反映飞机在水平面内改变直线飞行速度的能力。描述水平加&#xff08;减&#xff09;速性能的参数包括…