使用 Hugging Face 的 Transformers 库加载预训练模型遇到的问题

news2024/12/28 21:28:50

题意:

Size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint - Huggingface PyTorch

这个错误信息 "Size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint - Huggingface PyTorch" 通常出现在使用 Hugging Face 的 Transformers 库加载预训练模型时,模型的某些参数与预训练模型检查点(checkpoint)中的参数形状不匹配。

问题背景:

I want to finetune an LLM. I am able to successfully finetune LLM. But when reload the model after save, gets error. Below is the code

import argparse
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from trl import DPOTrainer, DPOConfig
def preprocess_data(item):
    return {
        'prompt': 'Instruct: ' + item['prompt'] + '\n',
        'chosen': 'Output: ' + item['chosen'],
        'rejected': 'Output: ' + item['rejected']
    }        

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--beta", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-6)
    parser.add_argument("--seed", type=int, default=2003)
    parser.add_argument("--model_name", type=str, default="EleutherAI/pythia-14m")
    parser.add_argument("--dataset_name", type=str, default="jondurbin/truthy-dpo-v0.1")
    parser.add_argument("--local_rank", type=int, default=0)

    args = parser.parse_args()

    # Determine device based on local_rank
    device = torch.device("cuda", args.local_rank) if torch.cuda.is_available() else torch.device("cpu")


    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)
    ref_model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)

    dataset = load_dataset(args.dataset_name, split="train")
    dataset = dataset.map(preprocess_data)

    # Split the dataset into training and validation sets
    dataset = dataset.train_test_split(test_size=0.1, seed=args.seed)
    train_dataset = dataset['train']
    val_dataset = dataset['test']

    training_args = DPOConfig(
        learning_rate=args.lr,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        logging_steps=10,
        remove_unused_columns=False,
        max_length=1024,
        max_prompt_length=512,
        fp16=True        
    )

    

    # Verify and print embedding dimensions before finetuning
    print("Base model embedding dimension:", model.config.hidden_size)

    model.train()
    ref_model.eval()

    dpo_trainer = DPOTrainer(
        model,
        ref_model,
        beta=args.beta,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        args=training_args,
    )

    dpo_trainer.train()
    # Evaluate
    evaluation_results = dpo_trainer.evaluate()
    print("Evaluation Results:", evaluation_results)

    save_model_name = 'finetuned_model'
    model.save_pretrained(save_model_name)

if __name__ == "__main__":
    main()

Error I was getting as below

return model_class.from_pretrained(
    File "/.local/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3838, in from_pretrained
        ) = cls._load_pretrained_model(
    File "/.local/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4349, in _load_pretrained_model
        raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
        RuntimeError: Error(s) in loading state_dict for GPTNeoXForCausalLM:
            size mismatch for gpt_neox.embed_in.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([50304, 128]).
            size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([50304, 128]).
            You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

After finetuning, model works perfectly. But after reloading the saved trained model its not working. Any idea why gets this error when reloading the model ?

问题解决:

Instead of

model.save_pretrained(save_model_name)

try this

dpo_trainer.save_model(save_model_name)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1912711.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis管理禁用命令

在redis数据量比较大时,执行 keys * ,fluashdb 这些命令,会导致redis长时间阻塞,大量请求被阻塞,cpu飙升,严重可能导致redis宕机,数据库雪崩。所以一些命令在生产环境禁止使用。 Redis 禁用命令…

开始尝试从0写一个项目--前端(二)

修改请求路径的位置 将后续以及之前的所有请求全都放在同一个文件夹里面 定义axios全局拦截器 为了后端每次请求都需要向后端传递jwt令牌检验 ps:愁死了,翻阅各种资料,可算是搞定了,哭死~~ src\utils\request.js import axio…

【QML之·基础语法概述】

系列文章目录 文章目录 前言一、QML基础语法二、属性三、脚本四、核心元素类型4.1 元素可以分为视觉元素和非视觉元素。4.2 Item4.2.1 几何属性(Geometry):4.2.2 布局处理:4.2.3 键处理:4.2.4 变换4.2.5 视觉4.2.6 状态定义 4.3 Rectangle4.3.1 颜色 4.4…

互联网3.0时代的变革者:华贝甄选大模型创新之道

在当今竞争激烈的商业世界中,华贝甄选犹如一颗璀璨的明星,闪耀着独特的光芒。 华贝甄选始终将技术创新与研发视为发展的核心驱动力。拥有先进的研发团队和一流设施,积极探索人工智能、大数据、区块链等前沿技术,为用户提供高性能…

Knife4j的介绍与使用

目录 一、简单介绍1.1 简介1.2 主要特点和功能: 二、使用步骤:2.1 添加依赖:2.2 yml数据源配置2.3 创建knife4j配置类2.4 注解的作用 最后 一、简单介绍 1.1 简介 Knife4j 是一款基于Swagger的开源文档管理工具,主要用于生成和管…

【PTA天梯赛】L1-003 个位数统计(15分)

作者:指针不指南吗 专栏:算法刷题 🐾或许会很慢,但是不可以停下来🐾 文章目录 题目题解总结 题目 题目链接 题解 使用string把长度达1000位的数字存起来开一个代表个位数的数组 a[11]倒序计算最后一位,…

第16章 主成分分析:四个案例及课后习题

1.假设 x x x为 m m m 维随机变量,其均值为 μ \mu μ,协方差矩阵为 Σ \Sigma Σ。 考虑由 m m m维随机变量 x x x到 m m m维随机变量 y y y的线性变换 y i α i T x ∑ k 1 m α k i x k , i 1 , 2 , ⋯ , m y _ { i } \alpha _ { i } ^ { T } …

从微软 Word 中提取数据

从 Microsoft Word 文档中提取数据可以通过编程来实现,有几种常见的方法,其中之一是使用 Python 和 python-docx 库。python-docx 是一个处理 .docx 文件(Microsoft Word 文档)的 Python 库,可以读取和操作 Word 文档的…

泛微开发修炼之旅--36通过js控制明细表中同一列中多个浏览框的显示控制逻辑(明细表列中多字段显示逻辑控制)

文章链接:36通过js控制明细表中同一列中多个浏览框的显示控制逻辑(明细表列中多字段显示逻辑控制)

谷粒商城学习笔记-22-分布式组件-SpringCloud-OpenFeign测试远程调用

文章目录 一,OpenFeign的简介二,OpenFeign的使用步骤1,场景说明2,引入依赖2,开启OpenFeign3,编写Feign接口4,使用feign调用远程接口5,验证 错误记录 上一节学习了注册中心&#xff0…

变长输入神经网络设计

我对使用 PyTorch 可以轻松构建动态神经网络的想法很感兴趣,因此我决定尝试一下。 我脑海中的应用程序具有可变数量的相同类型的输入。对于可变数量的输入,已经使用了循环或递归神经网络。但是,这些结构在给定行的输入之间施加了一些顺序或层…

前端面试题31(TCP与UDP区别)

TCP (Transmission Control Protocol) 和 UDP (User Datagram Protocol) 是两种在网络通信中常用的传输层协议,它们在多个方面存在显著差异,主要体现在以下几个方面: 连接方式: TCP 是面向连接的协议。在数据传输开始之前&#xf…

STM32学习历程(day6)

EXTI外部中断使用教程 首先先看下EXTI的框图 看这个框图就能知道要先初始化GPIO外设 那么和前面一样 1、先RCC使能时钟 2、配置GPIO 选择端口为输入模式, 3、配置AFIO,选择我们用的GPIO连接到后面的EXTI 4、配置EXTI,选择边沿触发方式…

前端javascript中的排序算法之选择排序

选择排序(Selection Sort)基本思想: 是一种原址排序法; 将数组分为两个区间:左侧为已排序区间,右侧为未排序区间。每趟从未排序区间中选择一个值最小的元素,放到已排序区间的末尾,从…

从Helm到 Operator:Kubernetes应用管理的进化

🧰Helm 的作用 在开始前需要先对 kubernetes Operator 有个简单的认识。 以为我们在编写部署一些简单 Deployment 的时候只需要自己编写一个 yaml 文件然后 kubectl apply 即可。 apiVersion: apps/v1 kind: Deployment metadata: labels: app: k8s-combat …

Camera Raw:常规工具

在 Camera Raw 窗口右下角提供了四个常用的工具,它们分别是:缩放工具、抓手工具、切换取样器叠加以及切换网格叠加工具。 ◆ ◆ ◆ 缩放工具 Zoom Tool 用于放大或缩小预览图像,便于查看和编辑细节。 快捷键:Z 1、双击“缩放工具…

jvm 06 对象内存结构,指针压缩,调优

01 内存布局 mark word 32bit 4B 64bit 8B 类型指针 klass pointer 开启指针压缩 4B 关闭指针压缩 8B 数组长度 4B 没有这个区域 实例数据 bool 1B 1 true,0 false #define TRUE 1 byte 1B char 2B 1B int 4B float 4B long 8B double 8B 引用类型 开启指针压缩 4B …

部署前端项目

常见部署方式有:静态托管服务、服务器部署 1. 静态托管服务 使用平台部署代码,比如 GitHub。 | 创建一个仓库,仓库名一般是 yourGithubName.github.io。 | 将打包后的静态文件文件上传到仓库。 | 在“Settings”(选项&#xff0…

一文入门云上StarRocks | EMR Serverless StarRocks

一文入门云上StarRocks | EMR Serverless StarRocks 什么是EMR Serverless StarRocksEMR Serverless StarRocks 操作免费开通创建实例连接StarRocks实例临时查询新建连接元数据管理诊断与分析 写在最后 什么是EMR Serverless StarRocks 在使用一个云产品之前,我们首…

C语言 结构体和共用体——结构体类型与结构体变量

目录 问题的提出 数组的解决方法 我们希望的内存分配图 如何声明一个结构体类型? 如何定义一个结构体变量? 用typedef给数据类型定义一个别名 如何定义一个结构体变量? 结构体变量的初始化 问题的提出 数组的解决方法 我们希望的内存…