Spider 数据集上实现nlp2sql训练任务

news2025/2/11 2:52:12

NLP2SQL(自然语言处理到 SQL 查询的转换)是一个重要的自然语言处理(NLP)任务,其目标是将用户的自然语言问题转换为相应的 SQL 查询。这一任务在许多场景下具有广泛的应用,尤其是在与数据库交互的场景中,例如数据分析、业务智能和问答系统。

任务目标
  • 理解自然语言: 理解用户输入的自然语言问题,包括意图、实体和上下文。
  • 生成 SQL 查询: 将理解后的信息转换为正确的 SQL 查询,以从数据库中检索所需的数据。

例如

输入: 用户的自然语言问题,“获取 Gelderland 区的总人口。”

输出: 对应的 SQL 查询,SELECT population FROM districts WHERE name = 'Gelderland';

Spider 是一个难度最大数据集

耶鲁大学在2018年新提出的一个大规模的NL2SQL(Text-to-SQL)数据集。
该数据集包含了10,181条自然语言问句、分布在200个独立数据库中的5,693条SQL,内容覆盖了138个不同的领域。
涉及的SQL语法最全面,是目前难度最大的NL2SQL数据集。

下载查看spider数据集内容

Question 1: How many singers do we have ? ||| concert_singer
SQL: select count(*) from singer

Question 2: What is the total number of singers ? ||| concert_singer
SQL: select count(*) from singer

Question 3: Show name , country , age for all singers ordered by age from the oldest to the youngest . ||| concert_singer
SQL: select name , country , age from singer order by age desc

...

首先需要转换为Spider的标准格式(参考tables.jsontrain.json):

{
  "db_id": "concert_singer",
  "question": "Show name, country, age...",
  "query": "SELECT name, country, age FROM singer ORDER BY age DESC",
  "schema": {
    "table_names": ["singer"],
    "column_names": [
      [0, "name", "text"],
      [0, "country", "text"],
      [0, "age", "int"]
    ]
  }
}

拆分为table.json的原因可能涉及到数据组织和重用。每个数据库的结构(表、列、外键)在多个问题中都会被重复使用。如果每个问题都附带完整的schema信息,会导致数据冗余,增加存储和处理的开销。所以,将schema单独存储为table.json,可以让不同的数据条目引用同一个数据库模式,减少重复数据。拆分后的结构需要更高效的数据管理,例如在训练模型时,根据每个问题的db_id去table.json中查找对应的schema信息。这样做的好处是当多个问题属于同一个数据库时,不需要每次都重复加载schema提高了效率。

column_names 表示数据库表中每一列的详细信息。具体来说,column_names 是一个列表,其中每个元素都是一个包含三个部分的子列表:

  1. 表索引(0):表示该列属于哪个表。在这个例子中,所有列都属于第一个表(索引为 0)。
  2. 列名("name"、"country"、"age"):表示列的名称。
  3. 数据类型("text"、"int"):表示该列的数据类型,例如文本(text)或整数(int)。

实现下面逻辑转换原始数据

def extract_columns_from_sql(sql):
    # 使用正则表达式匹配 SELECT 语句中的列名
    match = re.search(r"SELECT\s+(.*?)\s+FROM", sql, re.IGNORECASE)
    if match:
        # 提取列名
        columns = match.group(1).split(",")
        # 构建 column_names 列表
        column_names = []
        for index, column in enumerate(columns):
            column = column.strip()  # 去除多余的空格
            data_type = "text"  # 默认数据类型为 text,可以根据需要修改
            # 添加到 column_names 列表,假设所有列类型为 text
            column_names.append([0, column, data_type])
        return column_names
    return []

# 从 dev.sql 文件读取数据
def load_sql_data(file_path):
    data_list = []
    with open(file_path, 'r', encoding='utf-8') as f:  # 指定编码为 UTF-8
        lines = f.readlines()
        for i in range(0, len(lines), 3):  # 每三行一组
            question_line = lines[i].strip()
            sql_line = lines[i + 1].strip()

            if not question_line or not sql_line:
                continue

            # 提取问题和 SQL
            question = question_line.split(': ', 1)[1].strip()  # 获取问题内容
            sql = sql_line.split(': ', 1)[1].strip()  # 获取 SQL 查询

            # 提取表名
            db_id = question_line.split('|||')[-1].strip()  # 从问题行获取表名
            question = question.split('|||')[0].strip()

            target_sql = preprocess(question, db_id, sql)

            data_list.append({
                "input_text": f"Translate to SQL: {question} [SEP] Tables: {db_id}",
                "target_sql": json.dumps(target_sql)  # 将目标 SQL 转换为 JSON 格式字符串
            })
    return data_list

选择Tokenizer.from_pretrained("t5-base") 是用于加载 T5(Text-to-Text Transfer Transformer)模型的分词器。T5 是一个强大的自然语言处理模型,能够处理各种文本任务(如翻译、摘要、问答等),并且将所有任务视为文本到文本的转换。

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")

def preprocess(question, db_id, sql):
    # 提取列名
    column_names = extract_columns_from_sql(sql)

    # 构建目标格式
    target_sql = {
        "db_id": db_id,
        "question": question,
        "query": sql,
        "schema": {
            "table_names": [db_id],
            "column_names": column_names
        }
    }
    return target_sql# 

示例数据
question = "Show name, country, age for all singers ordered by age from the oldest to the youngest."
schema = "singer(name, country, age)"
sql = "SELECT name, country, age FROM singer ORDER BY age DESC"

input_text, target_sql = preprocess(question, schema, sql)
# input_text = "Translate to SQL: Show name... [SEP] Tables: singer(name, country, age)"
# target_sql = "select name, country, age from singer order by age desc"
print('input_text', input_text)
print('target_sql', target_sql)

所有nlp任务都涉及的需要token化,使用t5-base 做tokenize

def tokenize_function(examples):
    model_inputs = tokenizer(
        examples["input_text"],
        max_length=512,
        truncation=True,
        padding="max_length"
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            examples["target_sql"],
            max_length=512,
            truncation=True,
            padding="max_length"
        )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

使用 tokenizer.as_target_tokenizer() 上下文管理器,确保目标文本(即 SQL 查询)被正确处理。目标文本也经过编码,转换为 token IDs,并同样进行填充和截断。将目标文本的编码结果(token IDs)存储在 model_inputs["labels"] 中。这是模型在训练时需要的输出,用于计算损失。最终返回一个字典 model_inputs,它包含了模型的输入和对应的标签。这种结构使得模型在训练时可以直接使用。

最后组织下训练代码

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# 加载模型
model = T5ForConditionalGeneration.from_pretrained("t5-base")

# 训练参数
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=100,
    predict_with_generate=True,
    run_name="spider"
)

# 开始训练
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"] if 'train' in tokenized_datasets else tokenized_datasets,
    eval_dataset=tokenized_datasets["test"] if 'test' in tokenized_datasets else None,
    data_collator=DataCollatorForSeq2Seq(tokenizer)
)

trainer.train()

这里使用的是Seq2SeqTrainer, 它是 Hugging Face 的 transformers 库中用于序列到序列(Seq2Seq)任务的训练器。它为处理诸如翻译、文本生成和问答等任务提供了一个高层次的接口,简化了训练过程。以下是 Seq2SeqTrainer 的主要功能和特点:

  1. 简化训练流程Seq2SeqTrainer 封装了许多常见的训练步骤,如数据加载、模型训练、评估和预测,使得用户可以更专注于模型和数据,而不必处理繁琐的训练细节。

  2. 支持多种训练参数: 通过 Seq2SeqTrainingArguments 类,可以灵活配置训练参数,如学习率、批量大小、训练轮数、评估策略等。

  3. 自动处理填充和截断: 在处理输入和输出序列时,Seq2SeqTrainer 可以自动填充和截断序列,以确保它们适应模型的输入要求。

  4. 集成评估和监控: 支持在训练过程中进行模型评估,并可以根据评估指标(如损失)监控训练进度。用户可以设置评估频率和评估数据集

开始训练,进行100次epoch

训练监控在 Weights & Biases ,Seq2SeqTrainer 能够向 Weights & Biases (wandb) 传输训练监控数据,主要是因为它内置了与 wandb 的集成。以下是一些关键点,解释了这一过程:

  1. 自动集成:当你使用 Seq2SeqTrainer 时,它会自动检测 wandb 的安装并在初始化时配置相关设置。这意味着你无需手动设置 wandb。

  2. 回调功能Trainer 类提供了回调功能,可以在训练过程中记录各种指标(如损失、准确率等)。这些指标会被自动发送到 wandb。

  3. 配置管理training_args 中的参数可以指定 wandb 的项目名称、运行名称等,从而更好地组织和管理实验。

  4. 训练循环:在每个训练和评估周期结束时,Trainer 会调用相应的回调函数,将重要的训练信息(如损失、学习率等)记录到 wandb。

  5. 可视化:通过 wandb,你可以实时监控训练过程,包括损失曲线、模型性能等,帮助你更好地理解模型的训练动态。

多次试验还可以比较训练性能

训练结束, 损失收敛到0.05410315271151268

{'eval_loss': 0.008576861582696438, 'eval_runtime': 1.3883, 'eval_samples_per_second': 74.912, 'eval_steps_per_second': 5.042, 'epoch': 100.0}
{'train_runtime': 2914.0548, 'train_samples_per_second': 31.914, 'train_steps_per_second': 2.025, 'train_loss': 0.05410315271151268, 'epoch': 100.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5900/5900 [48:31<00:00,  2.03it/s]
wandb:
wandb: 🚀 View run spider at: https://wandb.ai/chenruithinking-4th-paradigm/huggingface/runs/dkccvpp4
wandb: Find logs at: wandb/run-20250207_112702-dkccvpp4/logs

测试下预测能力

import os
from transformers import T5Tokenizer, T5ForConditionalGeneration

# 设置 NCCL 环境变量
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

# 加载分词器
tokenizer = T5Tokenizer.from_pretrained("t5-base")


model = T5ForConditionalGeneration.from_pretrained("./results/t5-sql-model")
tokenizer.save_pretrained("./results/t5-sql-model")

def generate_sql(question, db_id):
    input_text = f"Translate to SQL: {question} [SEP] Tables: {db_id}"
    input_ids = tokenizer.encode(input_text, return_tensors="pt")  # 使▒~T▒ PyTorch ▒~Z~D▒| ▒~G~O▒| ▒▒~O
    output = model.generate(
        input_ids,
        max_length=512,
        num_beams=5,  # 或者尝试其他解码策略
        early_stopping=True
    )

    print('output', output)
    generated_sql = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_sql

question = "How many singers do we have ?"
db_id = "concert_singer"
evaluation_output = generate_sql(question, db_id)
print("evaluation_output:", evaluation_output)

输出结果

evaluation_output: "db_id": "concert_singer", "question": "How many singers do we have ?", "query": "select count(*) from singer", "schema": "table_names": ["concert_singer"], "column_names": [[0, "count(*)", "text"]]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2296091.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【DeepSeek】DeepSeek概述 | 本地部署deepseek

目录 1 -> 概述 1.1 -> 技术特点 1.2 -> 模型发布 1.3 -> 应用领域 1.4 -> 优势与影响 2 -> 本地部署 2.1 -> 安装ollama 2.2 -> 部署deepseek-r1模型 1 -> 概述 DeepSeek是由中国的深度求索公司开发的一系列人工智能模型&#xff0c;以其…

ASP.NET Core 使用 WebClient 从 URL 下载

本文使用 ASP .NET Core 3.1&#xff0c;但它在.NET 5、 .NET 6和.NET 8上也同样适用。如果使用较旧的.NET Framework&#xff0c;请参阅本文&#xff0c;不过&#xff0c;变化不大。 如果想要从 URL 下载任何数据类型&#xff0c;请参阅本文&#xff1a;HttpClient 使用WebC…

【CubeMX-HAL库】STM32F407—无刷电机学习笔记

目录 简介&#xff1a; 学习资料&#xff1a; 跳转目录&#xff1a; 一、工程创建 二、板载LED 三、用户按键 四、蜂鸣器 1.完整IO控制代码 五、TFT彩屏驱动 六、ADC多通道 1.通道确认 2.CubeMX配置 ①开启对应的ADC通道 ②选择规则组通道 ③开启DMA ④开启ADC…

vue3 点击图标从相册选择二维码图片,并使用jsqr解析二维码(含crypto-js加密解密过程)

vue3 点击图标从相册选择二维码图片&#xff0c;并使用jsqr解析二维码&#xff08;含crypto-js加密解密过程&#xff09; 1.安装 jsqr 和 crypto-js npm install -d jsqr npm install crypto-js2.在util目录下新建encryptionHelper.js文件&#xff0c;写加密解密方法。 // e…

kafka 3.5.0 raft协议安装

前言 最近做项目&#xff0c;需要使用kafka进行通信&#xff0c;且只能使用kafka&#xff0c;笔者没有测试集群&#xff0c;就自己搭建了kafka集群&#xff0c;实际上笔者在很早之前就搭建了&#xff0c;因为当时还是zookeeper&#xff08;简称ZK&#xff09;注册元数据&#…

前后端服务配置

1、安装虚拟机&#xff08;VirtualBox或者vmware&#xff09;&#xff0c;在虚拟机上配置centos(选择你需要的Linux版本)&#xff0c;配置如nginx服务器等 1.1 VMware 下载路径Sign In注册下载 1.2 VirtualBox 下载路径https://www.virtualbox.org/wiki/Downloads 2、配置服…

在阿里云ECS上一键部署DeepSeek-R1

DeepSeek-R1 是一款开源模型&#xff0c;也提供了 API(接口)调用方式。据 DeepSeek介绍&#xff0c;DeepSeek-R1 后训练阶段大规模使用了强化学习技术&#xff0c;在只有极少标注数据的情况下提升了模型推理能力&#xff0c;该模型性能对标 OpenAl o1 正式版。DeepSeek-R1 推出…

git SourceTree 使用

Source Tree 使用原理 文件的状态 创建仓库和提交 验证 再克隆的时候发发现一个问题&#xff0c;就是有一个 这个验证&#xff0c;起始很简单 就是 gitee 的账号和密码&#xff0c;但是要搞清楚的是账号不是名称&#xff0c;我之前一直再使用名称登录老是出问题 这个很简单的…

游戏引擎学习第94天

仓库:https://gitee.com/mrxiao_com/2d_game_2 回顾上周的渲染器工作 完成一款游戏的开发&#xff0c;完全不依赖任何库和引擎&#xff0c;这样我们能够全面掌握游戏的开发过程&#xff0c;确保没有任何细节被隐藏。我们将深入探索每一个环节&#xff0c;犹如拿着手电筒翻看床…

win32汇编环境,结构体的使用示例二

;运行效果 ;win32汇编环境,结构体的使用示例二 ;举例说明结构体的定义&#xff0c;如何访问其中的成员&#xff0c;使用assume指令指向某个结构体&#xff0c;计算结构数组所需的偏移量得到某个成员值等 ;直接抄进RadAsm可编译运行。重要部分加备注。 ;下面为asm文件 ;>>…

DeepSeek从入门到精通教程PDF清华大学出版

DeepSeek爆火以来&#xff0c;各种应用方式层出不穷&#xff0c;对于很多人来说&#xff0c;还是特别模糊&#xff0c;有种雾里看花水中望月的感觉。 最近&#xff0c;清华大学新闻与传播学院新媒体研究中心&#xff0c;推出了一篇DeepSeek的使用教程&#xff0c;从最基础的是…

【PDF提取内容】如何批量提取PDF里面的文字内容,把内容到处表格或者批量给PDF文件改名,基于C++的实现方案和步骤

以下分别介绍基于 C 批量提取 PDF 里文字内容并导出到表格&#xff0c;以及批量给 PDF 文件改名的实现方案、步骤和应用场景。 批量提取 PDF 文字内容并导出到表格 应用场景 文档数据整理&#xff1a;在处理大量学术论文、报告等 PDF 文档时&#xff0c;需要提取其中的关键信…

SSA-TCN麻雀算法优化时间卷积神经网络时间序列预测未来Matlab实现

SSA-TCN麻雀算法优化时间卷积神经网络时间序列预测未来Matlab实现 目录 SSA-TCN麻雀算法优化时间卷积神经网络时间序列预测未来Matlab实现预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab实现SSA-TCN麻雀算法优化时间卷积神经网络时间序列预测未来&#xff08;优…

大模型推理——MLA实现方案

1.整体流程 先上一张图来整体理解下MLA的计算过程 2.实现代码 import math import torch import torch.nn as nn# rms归一化 class RMSNorm(nn.Module):""""""def __init__(self, hidden_size, eps1e-6):super().__init__()self.weight nn.Pa…

大数据项目2:基于hadoop的电影推荐和分析系统设计和实现

前言 大数据项目源码资料说明&#xff1a; 大数据项目资料来自我多年工作中的开发积累与沉淀。 我分享的每个项目都有完整代码、数据、文档、效果图、部署文档及讲解视频。 可用于毕设、课设、学习、工作或者二次开发等&#xff0c;极大提升效率&#xff01; 1、项目目标 本…

Windows逆向工程入门之汇编环境搭建

公开视频 -> 链接点击跳转公开课程博客首页 -> ​​​链接点击跳转博客主页 Visual Studio逆向工程配置 基础环境搭建 Visual Studio 官方下载地址安装配置选项(后期可随时通过VS调整) 使用C的桌面开发 拓展可选选项 MASM汇编框架 配置MASM汇编项目 创建新项目 选择空…

gc buffer busy acquire导致的重大数据库性能故障

&#x1f4e2;&#x1f4e2;&#x1f4e2;&#x1f4e3;&#x1f4e3;&#x1f4e3; 作者&#xff1a;IT邦德 中国DBA联盟(ACDU)成员&#xff0c;10余年DBA工作经验 Oracle、PostgreSQL ACE CSDN博客专家及B站知名UP主&#xff0c;全网粉丝10万 擅长主流Oracle、MySQL、PG、高斯…

Formily 如何进行表单验证

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

安宝特方案 | AR眼镜:远程医疗的“时空折叠者”,如何为生命争夺每一分钟?

行业痛点&#xff1a;当“千里求医”遇上“资源鸿沟” 20世纪50年代&#xff0c;远程会诊的诞生曾让医疗界为之一振——患者不必跨越山河&#xff0c;专家无需舟车劳顿&#xff0c;一根电话线、一张传真纸便能架起问诊的桥梁。然而&#xff0c;传统远程医疗的局限也日益凸显&a…

使用git commit时‘“node“‘ 不是内部或外部命令,也不是可运行的程序

第一种&#xff1a; 使用git commit -m "xxx"时会报错&#xff0c;我看网上的方法是在命令行后面添加--no-verify&#xff1a;git commit -m "主题更新" --no-verify&#xff0c;但是不可能每次都添加。 最后解决办法是&#xff1a;使用git config --lis…