顶会论文复现:PROVING TEST SET CONTAMINATION IN BLACK BOX LANGUAGE MODELS

news2024/11/24 8:56:59

文章目录

  • 1 资料
  • 2 我的总结
  • 3 复现源码
    • 首先你需要有gpt的api接口
    • 安装:
    • 数据集
    • 执行指令
    • 源码
  • 4 结果

1 资料

我复现的源码:https://github.com/Whiffe/test_set_contamination

官网源码:https://github.com/tatsu-lab/test_set_contamination

论文:https://openreview.net/forum?id=KS8mIvetg2

论文翻译:ICLR-2024.Oren.PROVING TEST SET CONTAMINATION IN BLACK BOX LANGUAGE MODELS

b站复现视频:https://www.bilibili.com/video/BV14d1CYWE26/

2 我的总结

这篇论文的测试数据污染的方法也是很扯淡的,论文结尾也说了,作者自己的方法得先证明数据集内的题目之间的顺序打乱是否有影响,这不就是扯淡么,训练期间,有个策略就是要每次输入训练时,打乱顺序,训练的时候都打乱了,作者测试期间打乱测的出来个屁呀。这也能发顶会,太离谱了。

还有检测时用的logprobs这个值,这个值的低和高不代表污染程度,整个论文让我感到匪夷所思。

我依旧对作者的源码进行了重构,简单化地跑了以下,不然按照作者的本地化部署,巨量的遍历循环数据集,那得跑到啥时候,即使能等它跑完,对大模型的消耗也是巨大的,我用的api来跑,那可是按量收费的。

3 复现源码

首先你需要有gpt的api接口

# 设置API key和API的基础URL,用于调用 OpenAI 接口
API_KEY = ""  # 替换为你的 API key
BASE_URL = ""  # 替换为API的基本URL

安装:

pip install fire
pip install openai

数据集

同一个数据集被打乱
在这里插入图片描述

执行指令

python main.py --dataset_path "benchmarks/boolq/dev2.jsonl" --log_file_path "results_with_qwen_logprobs.json"

源码

'''
python main.py --dataset_path "benchmarks/boolq/dev2.jsonl" --log_file_path "results_with_qwen_logprobs.json"

此代码用于加载数据集、通过 OpenAI API 计算 token 序列的 logprobs,并通过统计检验比较原始顺序和打乱顺序的 logprobs。
'''

# 导入所需的库和模块
import os
import math
import random 
import numpy as np  # 数组和数值计算库
from scipy.stats import t as tdist  # 导入t分布用于统计检验
from multiprocessing import Process, Queue  # 用于并行处理
from tqdm import tqdm  # 用于显示进度条
import json  # 处理 JSON 数据
import fire  # 用于命令行接口
from openai import OpenAI  # 导入OpenAI库用于调用GPT模型

# 设置API key和API的基础URL,用于调用 OpenAI 接口
API_KEY = ""  # 替换为你的 API key
BASE_URL = ""  # 替换为API的基本URL

# 创建OpenAI客户端,用于后续调用API
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

# 定义两个 lambda 函数:用于展平嵌套列表和打乱列表
flatten = lambda l : [x for s in l for x in s]  # 展平嵌套列表
shuffle = lambda l : random.sample(l, k=len(l))  # 打乱列表

def load_dataset(dataset_path):
    # 加载数据集函数
    if dataset_path.endswith(".json"):
        # 如果是JSON文件,读取内容
        print("loading from json...")
        with open(dataset_path, "r") as f:
            data = f.read()
            examples = json.loads(data)  # 将JSON格式数据解析为Python对象
            return examples

    # 如果不是JSON,逐行读取文件
    with open(dataset_path, "r") as f:
        lines = f.readlines()  # 读取所有行
    return lines  # 返回行列表

def compute_logprob_of_token_sequence(tokens, context_len=2048, device=0):
    """
    调用 OpenAI API 计算一系列 token 的对数概率 (logprobs)。
    """
    # 将token列表合并成一个输入字符串
    input_text = " ".join(tokens)

    try:
        # 使用 GPT 模型调用API并请求返回logprobs
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": input_text}],
            model="gpt-3.5-turbo",  # 使用GPT-3.5模型
            logprobs=True  # 请求返回 logprobs
        )
        
        # 从响应中提取 logprobs
        logprobs = [token_logprob.logprob for token_logprob in response.choices[0].logprobs.content]
        
        # 计算并返回所有 token 的 logprobs 的和
        total_logprob = sum(logprobs)

        return total_logprob  # 返回 logprobs 总和

    except Exception as e:
        # 如果发生错误,打印错误信息
        print(f"An error occurred: {e}")
        return None  # 返回 None 以表示失败

def worker(context_len, device, main_queue, worker_queue):
    # 工作进程,用于处理多个并行任务
    while True:
        # 从 worker_queue 获取 token 列表、shard ID 和是否是 canonical(原始顺序)
        tokens, shard_id, is_canonical = worker_queue.get()

        if tokens == None:  # 如果收到 None,表示退出
            break

        # 计算 token 序列的 logprobs
        logprob = compute_logprob_of_token_sequence(tokens, context_len, device=device)

        # 将结果放入主进程的队列
        main_queue.put((logprob, shard_id, is_canonical))

def main(dataset_path, context_len=2048, num_shards=5, permutations_per_shard=25,
         random_seed=0, log_file_path=None, max_examples=5000):

    # 设置随机种子,保证可重复性
    random.seed(random_seed)
    np.random.seed(random_seed)

    # 加载数据集
    examples = load_dataset(dataset_path)
    examples = examples[:max_examples]  # 限制加载的示例数量
    num_examples = len(examples)  # 获取数据集大小
    print(f"Loaded {num_examples} examples from {dataset_path}")

    # 对示例进行简单的基于空格的分词
    tokenized_examples = [ex.split() for ex in examples]

    # 使用多进程处理请求(在本例中仅使用一个工作进程)
    processes = []
    main_queue = Queue()  # 主进程队列,用于收集工作进程的结果
    worker_queues = [Queue() for _ in range(1)]  # 工作进程队列

    # 启动工作进程
    p = Process(target=worker, args=(context_len, 0, main_queue, worker_queues[0]))
    processes.append(p)
    p.start()

    # 计算每个分片的大小(将数据集分为多个分片)
    shard_counts = [(x + 1 if i < num_examples % num_shards else x) 
                    for i, x in enumerate([num_examples // num_shards] * num_shards)]
    shard_counts = np.asarray(shard_counts)

    # 生成每个分片的索引
    shard_example_indices = [0] + np.cumsum(shard_counts).tolist()
    for i, (start, end) in enumerate(zip(shard_example_indices, shard_example_indices[1:])):
        shard = tokenized_examples[start:end]

        # 将原始顺序的logprobs请求提交到worker队列
        worker_queues[0].put((
            flatten(shard),  # 展平后的token列表
            i,               # 分片ID
            True))           # 标识这是canonical(原始顺序)

        # 将打乱顺序的logprobs请求提交到worker队列
        for j in range(permutations_per_shard):
            worker_queues[0].put((
                flatten(shuffle(shard)),  # 打乱后的token列表
                i,                        # 分片ID
                False))                   # 标识这是打乱顺序

    # 等待所有请求完成,并显示进度条
    total_work = num_shards * (1 + permutations_per_shard)
    pbar = tqdm(total=total_work)

    canonical_logprobs = [None for _ in range(num_shards)]  # 存储每个分片的 canonical logprobs
    shuffled_logprobs  = [[] for _ in range(num_shards)]    # 存储每个分片的打乱顺序 logprobs

    # 处理worker进程返回的结果
    completed = 0
    while completed < total_work:
        logprob, shard_id, is_canonical = main_queue.get()

        if is_canonical:
            canonical_logprobs[shard_id] = logprob  # 存储原始顺序的logprobs
        else:
            shuffled_logprobs[shard_id].append(logprob)  # 存储打乱顺序的logprobs

        pbar.update(1)  # 更新进度条
        completed += 1

    # 终止工作进程
    worker_queues[0].put((None, None, None))  # 向worker发送退出信号

    for p in processes:
        p.join()  # 等待所有worker进程结束

    # 计算 p-value(p值,用于统计显著性检验)
    canonical_logprobs = np.asarray(canonical_logprobs)  # 转换为numpy数组
    shuffled_logprobs  = np.asarray(shuffled_logprobs)

    # 进行 t 检验,计算 canonical 和 shuffled 之间的差异
    diffs = canonical_logprobs - shuffled_logprobs.mean(axis=1)
    z = np.mean(diffs) / np.std(diffs) * np.sqrt(len(diffs))
    pval = 1 - tdist.cdf(z, df=len(diffs)-1)  # 计算 p 值
    print(f"{pval=}")

    # 将结果写入日志文件(如果指定了log_file_path)
    if log_file_path is not None:
        print(f"Writing logprobs to: {log_file_path}")
        with open(f"{log_file_path}", 'w') as f:
            f.write(json.dumps({
                'pval': pval,
                'permutations_per_shard': permutations_per_shard,
                'num_shards': num_shards,
                'canonical_logprobs': canonical_logprobs.tolist(),
                'shuffled_logprobs': shuffled_logprobs.tolist(),
            }))

if __name__ == '__main__':
  # 使用Fire库,将命令行参数解析并传递给main函数
  fire.Fire(main)

4 结果

在这里插入图片描述
在这里插入图片描述

{“pval”: 0.1696232809691942, “permutations_per_shard”: 25, “num_shards”: 5, “canonical_logprobs”: [-3.3727318899495287, -26.976947896884596, -8.280387231770165, -11.1112375389544, -15.317197114102502], “shuffled_logprobs”: [[-30.909075783811705, -1.9435003494767502, -6.068986559756483, -26.58900908523132, -16.12269960305306, -4.46379730144066, -2.1558121800502787, -2.7554792693991, -31.682284527334303, -2.6273379016797502, -29.87468835264795, -29.607210920316206, -21.57213257741471, -11.95329938606544, -9.972366131973049, -1.09951527892729, -24.01362313224146, -12.456106343552321, -13.67304127957505, -8.3861853631837, -1.19935666177955, -1.3937543557773802, -1.8002136455626179, -18.009020852617073, -17.578153150829802], [-1.048417377427077, -2.3941244789539704, -8.412327189846044, -24.544644694362702, -5.74892321528065, -1.055017764263738, -9.581590557731854, -7.1433768327109, -2.737799236512142, -23.983025399790144, -9.26424030310054, -14.993957307794997, -2.9504655498746724, -12.080805583617936, -30.364195487487766, -1.9539559864239302, -9.9784173216152, -5.28962663901724, -15.477334895809188, -12.511526170812552, -2.6651197975443917, -7.0888789550949, -3.2381118496705326, -9.586995443264478, -19.668003974102017], [-4.233909482393801, -7.066849438078104, -5.8159291705155995, -10.790016564647766, -5.962899019632103, -5.1748830459693185, -2.6900913199189995, -14.64220487293797, -12.072412084641194, -7.0405692357728995, -3.757379365485161, -4.3277333949891, -18.239703727872094, -1.2460728438048796, -1.5030126277381202, -3.4466863958886114, -4.680143685284249, -4.651795277712018, -20.354000748485447, -1.4471048784444498, -10.138775905959701, -18.178129422154928, -8.530598762226427, -7.489915270562131, -3.1585280028892795], [-11.69612743268735, -7.769248518398181, -6.86862614461919, -6.518956643516701, -3.803943939116793, -13.014972477420848, -8.689137998628949, -15.809698635575, -7.99394916393605, -10.31342520238305, -15.928287922934501, -4.502634127164701, -8.7768807739485, -4.220711983509, -28.029167855395826, -2.81686953962755, -9.31479084741635, -11.158157243828, -6.721924684535599, -10.066673909472277, -4.500597344717, -21.480636945940205, -9.300124195747498, -11.9573958015312, -14.577081120902347], [-6.629690439094301, -10022.289585636432, -7.999280733182201, -11.308060008804496, -5.2692347206537, -7.0436708680337015, -4.26530791126701, -4.130906993623899, -5.9630648153422, -7.950204238607298, -7.942466538914552, -12.449199286857947, -25.78205265161044, -17.262547473382632, -5.530209510927001, -8.1570511425078, -5.8230390775959995, -5.532957394563099, -16.9681575425189, -5.454541042652198, -4.4566292699186, -2.14531132561838, -38.43645063084328, -33.65827228043184, -2.714607539457402]]}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2195534.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

tts(text to speech)使用 pyttsx3 实现文本转语音 - python 实现

文本转语音&#xff08;Text-to-Speech&#xff0c;TTS&#xff09;技术是一种将文本信息转换为口语输出的技术。它涉及多个学科&#xff0c;包括声学、语言学、数学信号处理技术和多媒体技术等。TTS技术能够将计算机中的文本信息转换为自然流畅的语音输出&#xff0c;广泛应用…

OJ在线评测系统 后端微服务架构 注册中心 Nacos入门到启动

注册中心 服务架构中的注册中心是一个关键组件&#xff0c;用于管理和协助微服务之间的通信。注册中心的主要职责是服务的注册和发现&#xff0c;确保各个微服务能够相互找到并进行调用。 主要功能&#xff1a; 服务注册&#xff1a;微服务在启动时&#xff0c;将自身信息&am…

OpenHarmony(鸿蒙南向开发)——标准系统方案之瑞芯微RK3566移植案例(下)

往期知识点记录&#xff1a; 鸿蒙&#xff08;HarmonyOS&#xff09;应用层开发&#xff08;北向&#xff09;知识点汇总 鸿蒙&#xff08;OpenHarmony&#xff09;南向开发保姆级知识点汇总~ 持续更新中…… 概述 OpenHarmony Camera驱动模型结构 HDI Implementation&#x…

【ubuntu】ubuntu20.04安装显卡驱动

1.安装 点击右下角Apply Changes。 等安装好之后&#xff0c;重启。 现在的nvidia驱动已经很好安装了&#xff0c;比早期时安装出现黑屏等情况好了很多。 2.验证 nvidia-smi

Mybatis plus快速使用

文章目录 Mybatis plus快速使用1.ORM2.mybatis plus介绍3.mybatis plus使用1.添加依赖2.配置信息3.启动类加入 MapperScan&#xff08;“填入mapper包的位置”&#xff09;4.创建user接口&#xff0c;在mapper中加入UserMapper接口5.mybatis-plus crud注解启动springboot项目ma…

基于图像的3D动物重建与生成

一、背景与目标 3D-Fauna 是一款用于基于图像和视频进行四足动物3D重建与生成的开源方案。自然界展示了复杂的相似性与多样性,该方法通过学习来自网上图片的四足动物的3D形态,能够从单张图片生成可动画化的带有纹理的3D网格模型。其最终目标是通过大量扩展现有的解决方案,实…

Ajax面试题:(第一天)

目录 1.说一下网络模型 2.在浏览器地址栏键入URL&#xff0c;按下回车之后会经历以下流程&#xff1a; 3.什么是三次握手和四次挥手&#xff1f; 4.http协议和https协议的区别 1.说一下网络模型 注&#xff1a;各层含义按自己理解即可 2.在浏览器地址栏键入URL&#xff0c;…

mybatis自定义类型处理器

mybatis自定义类型处理器 其实使用MySQL或Oracle数据库很少会遇到自定义类型处理器的情况&#xff0c;之前是因为项目中使用了PGSQL才接触到这块的&#xff0c;这里简单做一下记录 要创建一个自定义的类型处理器&#xff0c;就需要继承BaseTypeHandler类或者实现TypeHandler接…

数据结构 ——— 相交链表(链表的共节点)

题目要求 两个单链表的头节点 headA 和 headB &#xff0c;请找出并返回两个单链表相交的起始节点&#xff0c;如果两个链表不存在相交节点&#xff0c;则返回 NULL 手搓两个相交简易链表 代码演示&#xff1a; struct ListNode* a1 (struct ListNode*)malloc(sizeof(struc…

Git 分支提交同步到主干的详细教程——(包含命令行和idea操作两种方式)

文章目录 Git 分支提交同步到主干的详细教程一、Git 命令行操作1. 确保分支上的代码已提交2. 切换到主干分支3. 拉取最新的主干分支代码4. 合并分支到主干方式一&#xff1a;使用 merge 进行合并方式二&#xff1a;使用 rebase 进行合并 5. 推送合并后的代码到远程主干分支命令…

github 搭建个人导航网

最近搭建了个 个人的导航网&#xff0c;具体内容见下图&#xff0c;欢迎大家访问吖&#xff0c;点击访问 具体实现是使用 vue3 编写&#xff0c;白嫖 github 的 page 部署 首先在 github上创建一个仓库&#xff1a;name.github.io # name是你 github 的名字 然后在本地创建一…

Linux安装部署MySQL8.0加遇着问题解决

1.首先我先给个URL下载MySQL官方网站https://downloads.mysql.com/archives/community/ 2.选择Linux的红帽系统 3.接着选择红帽系统的7版本,x86 4.接着选择MySQL版本,此时我选择8.4.0,下载rpm bundle这个,下载下面这个就好 5.Windows文件上传到Linux系统 rz上传文件命令,找到…

【D3.js in Action 3 精译_030】3.5 给 D3 条形图加注图表标签(下):Krisztina Szűcs 人物专访 + 3.6 本章小结

当前内容所在位置&#xff08;可进入专栏查看其他译好的章节内容&#xff09; 第一部分 D3.js 基础知识 第一章 D3.js 简介&#xff08;已完结&#xff09; 1.1 何为 D3.js&#xff1f;1.2 D3 生态系统——入门须知1.3 数据可视化最佳实践&#xff08;上&#xff09;1.3 数据可…

【redis-05】redis保证和mysql数据一致性

redis系列整体栏目 内容链接地址【一】redis基本数据类型和使用场景https://zhenghuisheng.blog.csdn.net/article/details/142406325【二】redis的持久化机制和原理https://zhenghuisheng.blog.csdn.net/article/details/142441756【三】redis缓存穿透、缓存击穿、缓存雪崩htt…

Qt+大恒相机回调图片刷新使用方式

一、前言 上篇文章介绍了如何调用大恒SDK获得回调图片&#xff0c;这篇介绍如何使用这些图片并刷新到界面上。考虑到相机的帧率很高&#xff0c;比如200fps是很高的回调频率。那么我们的刷新频率是做不到这么快&#xff0c;也没必要这么快。一般刷新在60帧左右就够了。 二、思路…

springboot kafka多数据源,通过配置动态加载发送者和消费者

前言 最近做项目&#xff0c;需要支持kafka多数据源&#xff0c;实际上我们也可以通过代码固定写死多套kafka集群逻辑&#xff0c;但是如果需要不修改代码扩展呢&#xff0c;因为kafka本身不处理额外逻辑&#xff0c;只是起到削峰&#xff0c;和数据的传递&#xff0c;那么就需…

Unity_Obfuscator Pro代码混淆工具_学习日志

Unity_Obfuscator Pro代码混淆工具_学习日志 切勿将密码或 API 密钥存储在您附带的应用程序内。 混淆后的热更新暂时没有想到怎么办 Obfuscator 文档 https://docs.guardingpearsoftware.com/manual/Obfuscator/Description.html商店链接Obfuscator Pro&#xff08;大约$70&a…

169.254.0.0/16是什么地址?

169.254.0.0/16是一个链路本地地址&#xff0c;也称为连结本地位址&#xff0c;主要用于局域网内的主机相互通信。‌ 这种地址仅供在网段或广播域中的主机相互通信使用&#xff0c;不需要外部互联网服务‌。 169.254.0.0/16地址段定义在RFC 3927中&#xff0c;当DHCP服务器无法…

永洪BI:企业数字化转型的得力助手

在当今快速变化的商业环境中&#xff0c;数据已成为企业决策的重要依据。随着大数据、云计算和人工智能技术的发展&#xff0c;企业对数据分析的需求日益增长。永洪BI&#xff08;Business Intelligence&#xff09;作为国内领先的商业智能解决方案提供商&#xff0c;以其强大的…

在mac中通过ip连接打印机并实现双面打印

首先需要找到电脑自带的打印。添加打印机。 填写好打印机的ip地址&#xff0c;然后添加。 填写好ip地址后&#xff0c;直接添加就行 添加完打印机后其实就可以打印了。但是有些功能可能实现不了&#xff0c;比如说双面打印。为了实现双面打印的功能&#xff0c;需要再进行设置…