开源项目解读(https://github.com/zjunlp/DeepKE)

news2025/4/3 2:22:53

1.DeepKE 是一个开源的知识图谱抽取与构建工具,支持cnSchema、低资源、长篇章、多模态的知识抽取工具,可以基于PyTorch实现命名实体识别关系抽取属性抽取功能。同时为初学者提供了文档,在线演示, 论文, 演示文稿和海报。

2.下载对应的demo代码

3.准备环境

conda create -n deepke-llm python=3.9
conda activate deepke-llm

cd example/llm
pip install -r requirements.txt

pip install ujson

 4.demo目录介绍

我们直接运行demo.py,就会出现三个选项,每个选项对应一个文件夹

NER(命名实体识别)- 选项1:
基础模型:bert-base-chinese
任务模型:需要从 DeepKE 下载预训练的 NER 模型
位置:neme_entity_recognition/checkpoints/
RE(关系抽取)- 选项2:
基础模型:bert-base-chinese(已有)
任务模型:需要从 DeepKE 下载预训练的 RE 模型
位置:relation_extraction/checkpoints/
AE(属性抽取)- 选项3:
基础模型:bert-base-chinese(已有)
任务模型:需要从 DeepKE 下载预训练的 AE 模型(lm_epoch1.pth)
位置:attributation_extraction/checkpoints/

5.我们先下载本地模型,我直接在本地下载模型

 git clone https://www.modelscope.cn/tiansz/bert-base-chinese.git

修改选项2和选项3中对应的模型的路径为本地路径

 关系抽取的

属性抽取的

 

6.然后去官网下载预训练模型

我发现属性抽取没有提供预训练模型

但是其余两个有,下载地址如下https://drive.google.com/drive/folders/1wb_QIZduKDwrHeri0s5byibsSQrrJTEv

(https://github.com/zjunlp/DeepKE/blob/main/README_CNSCHEMA_CN.md)

7.将下载好的re和ner对应的文件放到对应的位置

1)re

修改relation_extraction中的demo.py的路径和tokenizer,完整代码如下

import os
import numpy as np
import torch
import random
import pickle
from tqdm import tqdm
import ujson as json
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModel, AutoTokenizer
import time
from .process import *

def to_official(preds, features):
    rel2id = json.load(open(f'relation_extraction/data/rel2id.json', 'r'))
    rel2info = json.load(open(f'relation_extraction/data/rel_info.json', 'r'))
    entity = json.load(open(f'relation_extraction/data/output.json', 'r'))
    id2rel = {value: key for key, value in rel2id.items()}

    h_idx, t_idx, title = [], [], []

    for f in features:
        hts = f["hts"]
        h_idx += [ht[0] for ht in hts]
        t_idx += [ht[1] for ht in hts]
        title += [f["title"] for ht in hts]

    res = []

    for i in range(preds.shape[0]):
        pred = preds[i]
        pred = np.nonzero(pred)[0].tolist()
        for p in pred:
            if p != 0:
                h_entity, t_entity = '', ''
                for en in entity[0]['vertexSet'][h_idx[i]]:
                    if len(en['name']) > len(h_entity):
                        h_entity = en['name']
                for en in entity[0]['vertexSet'][t_idx[i]]:
                    if len(en['name']) > len(t_entity):
                        t_entity = en['name']
                res.append(
                    {
                        'h': h_entity,
                        't': t_entity,
                        'r': rel2info[id2rel[p]],
                    }
                )
    return res

class ReadDataset:
    def __init__(self, tokenizer, max_seq_Length: int = 1024,
             transformers: str = 'bert') -> None:
        self.transformers = transformers
        self.tokenizer = tokenizer
        self.max_seq_Length = max_seq_Length

    def read(self, file_in: str):
        save_file = file_in.split('.json')[0] + '_' + self.transformers + '.pkl'
        return read_docred(self.transformers, file_in, save_file, self.tokenizer, self.max_seq_Length)

def read_docred(transfermers, file_in, save_file, tokenizer, max_seq_length=1024):
        max_len = 0
        up512_num = 0
        i_line = 0
        pos_samples = 0
        neg_samples = 0
        features = []
        docred_rel2id = json.load(open(f'relation_extraction/data/rel2id.json', 'r'))
        if file_in == "":
            return None
        with open(file_in, "r") as fh:
            data = json.load(fh)
        if transfermers == 'albert':
            entity_type = ["-", "ORG", "-",  "LOC", "-",  "TIME", "-",  "PER", "-", "MISC", "-", "NUM"]

        for sample in data:
            sents = []
            sent_map = []

            entities = sample['vertexSet']
            entity_start, entity_end = [], []
            mention_types = []
            for entity in entities:
                for mention in entity:
                    sent_id = mention["sent_id"]
                    pos = mention["pos"]
                    entity_start.append((sent_id, pos[0]))
                    entity_end.append((sent_id, pos[1] - 1))
                    mention_types.append(mention['type'])

            for i_s, sent in enumerate(sample['sents']):
                new_map = {}
                for i_t, token in enumerate(sent):
                    tokens_wordpiece = tokenizer.tokenize(token)
                    if (i_s, i_t) in entity_start:
                        t = entity_start.index((i_s, i_t))
                        if transfermers == 'albert':
                            mention_type = mention_types[t]
                            special_token_i = entity_type.index(mention_type)
                            special_token = ['[unused' + str(special_token_i) + ']']
                        else:
                            special_token = ['*']
                        tokens_wordpiece = special_token + tokens_wordpiece

                    if (i_s, i_t) in entity_end:
                        t = entity_end.index((i_s, i_t))
                        if transfermers == 'albert':
                            mention_type = mention_types[t]
                            special_token_i = entity_type.index(mention_type) + 50
                            special_token = ['[unused' + str(special_token_i) + ']']
                        else:
                            special_token = ['*']
                        tokens_wordpiece = tokens_wordpiece + special_token

                    new_map[i_t] = len(sents)
                    sents.extend(tokens_wordpiece)
                new_map[i_t + 1] = len(sents)
                sent_map.append(new_map)

            if len(sents)>max_len:
                max_len=len(sents)
            if len(sents)>512:
                up512_num += 1

            train_triple = {}
            if "labels" in sample:
                for label in sample['labels']:
                    evidence = label['evidence']
                    r = int(docred_rel2id[label['r']])
                    if (label['h'], label['t']) not in train_triple:
                        train_triple[(label['h'], label['t'])] = [
                            {'relation': r, 'evidence': evidence}]
                    else:
                        train_triple[(label['h'], label['t'])].append(
                            {'relation': r, 'evidence': evidence})

            entity_pos = []
            for e in entities:
                entity_pos.append([])
                mention_num = len(e)
                for m in e:
                    start = sent_map[m["sent_id"]][m["pos"][0]]
                    end = sent_map[m["sent_id"]][m["pos"][1]]
                    entity_pos[-1].append((start, end,))

            relations, hts = [], []
            # Get positive samples from dataset
            for h, t in train_triple.keys():
                relation = [0] * len(docred_rel2id)
                for mention in train_triple[h, t]:
                    relation[mention["relation"]] = 1
                    evidence = mention["evidence"]
                relations.append(relation)
                hts.append([h, t])
                pos_samples += 1

            # Get negative samples from dataset
            for h in range(len(entities)):
                for t in range(len(entities)):
                    if h != t and [h, t] not in hts:
                        relation = [1] + [0] * (len(docred_rel2id) - 1)
                        relations.append(relation)
                        hts.append([h, t])
                        neg_samples += 1

            assert len(relations) == len(entities) * (len(entities) - 1)

            if len(hts)==0:
                print(len(sent))
            sents = sents[:max_seq_length - 2]
            input_ids = tokenizer.convert_tokens_to_ids(sents)
            input_ids = tokenizer.build_inputs_with_special_tokens(input_ids)

            i_line += 1
            feature = {'input_ids': input_ids,
                       'entity_pos': entity_pos,
                       'labels': relations,
                       'hts': hts,
                       'title': sample['title'],
                       }
            features.append(feature)

        with open(file=save_file, mode='wb') as fw:
            pickle.dump(features, fw)

        return features

def collate_fn(batch):
    max_len = max([len(f["input_ids"]) for f in batch])
    input_ids = [f["input_ids"] + [0] * (max_len - len(f["input_ids"])) for f in batch]
    input_mask = [[1.0] * len(f["input_ids"]) + [0.0] * (max_len - len(f["input_ids"])) for f in batch]
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    input_mask = torch.tensor(input_mask, dtype=torch.float)
    entity_pos = [f["entity_pos"] for f in batch]

    labels = [f["labels"] for f in batch]
    hts = [f["hts"] for f in batch]
    output = (input_ids, input_mask, labels, entity_pos, hts )
    return output

def report(args, model, features):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)
    preds = []
    for batch in dataloader:
        model.eval()

        inputs = {'input_ids': batch[0].to(device),
                  'attention_mask': batch[1].to(device),
                  'entity_pos': batch[3],
                  'hts': batch[4],
                  }

        with torch.no_grad():
            pred = model(**inputs)
            pred = pred.cpu().numpy()
            pred[np.isnan(pred)] = 0
            preds.append(pred)

    preds = np.concatenate(preds, axis=0).astype(np.float32)
    preds = to_official(preds, features)
    return preds

class Config(object):
    unet_in_dim=3
    unet_out_dim=256
    max_height=42
    down_dim=256
    channel_type='context-based'
    unet_out_dim=256
    test_batch_size=2

cfg = Config()

def color(text, color="\033[1;34m"): 
    return color+text+"\033[0m"

def doc_re():
    sentence = input(f"Enter the {color('sentence')}: ")
    input_file = 'relation_extraction/input.txt'
    with open(input_file , 'w') as f:
        f.write(sentence)
    txt2json(input_file, 'relation_extraction/data/output.json')
    device = torch.device("cpu")

    bert_path = '/mnt/workspace/DeepKE-demo/bert-base-chinese'
    config = AutoConfig.from_pretrained(bert_path, num_labels=97)
    tokenizer = AutoTokenizer.from_pretrained(bert_path)
    
    Dataset = ReadDataset(tokenizer, 1024, transformers='bert')
    test_file = 'relation_extraction/data/output.json'
    test_features = Dataset.read(test_file)
    
    model = AutoModel.from_pretrained(bert_path, from_tf=False, config=config)
    config.cls_token_id = tokenizer.cls_token_id
    config.sep_token_id = tokenizer.sep_token_id
    config.transformer_type = 'bert'
    
    seed = 111
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    model = DocREModel(config, cfg, model, num_labels=4)

    checkpoint_path = 'relation_extraction/checkpoints/re_bert.pth'
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"预训练模型文件不存在:{checkpoint_path},请确保已下载模型文件并放置在正确位置。")
    
    # 加载预训练权重
    # model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

    # 加载预训练权重并处理键名不匹配
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('bert.'):
            new_k = 'bert_model.' + k[5:]  # 将 'bert.' 替换为 'bert_model.'
            new_state_dict[new_k] = v
        else:
            new_state_dict[k] = v
    
    # 加载可以加载的权重
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in new_state_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=False)

    model.to(device)
    pred = report(cfg, model, test_features)
    
    with open(input_file.split('.txt')[0]+'.json', "w") as fh:
        json.dump(pred, fh)
    print()
    print(f"The {color('triplets')} are as follow:")
    print()
    for i in pred:
        print(i)
    print()

if __name__ == "__main__":
    doc_re()

同时修改/mnt/workspace/DeepKE-demo/relation_extraction/process/model.py

def encode(self, input_ids, attention_mask,entity_pos):
        config = self.config
        if config.transformer_type == "albert":
            start_tokens = [config.cls_token_id]
            end_tokens = [config.sep_token_id]
        elif config.transformer_type == "bert":
            start_tokens = [config.cls_token_id]
            end_tokens = [config.sep_token_id]
        elif config.transformer_type == "roberta":
            start_tokens = [config.cls_token_id]
            end_tokens = [config.sep_token_id, config.sep_token_id]
        sequence_output, attention = process_long_input(self.bert_model, input_ids, attention_mask, start_tokens, end_tokens)
        return sequence_output, attention

测试句子有格式要求:{[0][PER]欧阳菲菲}演唱的{[1][SONG]没有你的夜晚},出自专辑{[2][ALBUM]拥抱}

最后结果

2)ner

将下载好的checkpoint_bert.zip移动到ner文件夹下并解压缩,然后运行,记得重命名为checkpointints

运行报错,标签老是对不上,重新训练

/mnt/workspace/DeepKE/example/ner/standard路径下

下载数据集

wget 120.27.214.45/Data/ner/standard/data.tar.gz

tar -xzvf data.tar.gz

然后修改配置,改为自己的路径名

/mnt/workspace/DeepKE/example/ner/standard/conf/hydra/model/bert.yaml

安装环境依赖(重新建一个conda环境吧,训练不等同于推理)



conda create -n deepke python=3.8

conda activate deepke

 pip install pip==24.0
在DeepKE源码根目录下(git clone https://github.com/zjunlp/DeepKE.git)
pip install --use-pep517 seqeval
pip install -r requirements.txt

python setup.py install

python setup.py develop
pip install safetensors

/mnt/workspace/DeepKE/example/ner/standard路径下

运行python run_bert.py 

如果用gpu训练的话,需要

pip uninstall torch torchvision torchaudio -y

pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113

24g显存,使用率是70%,训练了两个小时左右

but,效果并不好

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2326182.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

「MethodArgumentTypeMismatchException:前端传递 ‘undefined‘ 导致 Integer 类型转换失败」

遇到的问题: Failed to convert value of type java.lang.String to required type java.lang.Integer; nested exception is java.lang.NumberFormatException: For input string: "undefined" 原因分析: 大致意思就是我传递的参数到后端没…

LabVIEW故障诊断数据处理方法

在LabVIEW故障诊断系统中,数据处理直接决定诊断的准确性和效率。工业现场常面临噪声干扰、数据量大、实时性要求高等挑战,需针对性地选择处理方法。本文结合电机故障诊断、轴承损伤检测等典型案例,详解数据预处理、特征提取、模式识别三大核心…

基于 SpringBoot 的火车订票管理系统

收藏关注不迷路!! 🌟文末获取源码数据库🌟 感兴趣的可以先收藏起来,还有大家在毕设选题(免费咨询指导选题),项目以及论文编写等相关问题都可以给我留言咨询,希望帮助更多…

Python的概论

免责声明 如有异议请在评论区友好交流,或者私信 内容纯属个人见解,仅供学习参考 如若从事非法行业请勿食用 如有雷同纯属巧合 版权问题请直接联系本人进行删改 前言 提示:: 提示:以下是本篇文章正文内容&#xff0c…

构建大语言模型应用:句子转换器(Sentence Transformers)(第三部分)

本系列文章目录 简介数据准备句子转换器(本文)向量数据库搜索与检索大语言模型开源检索增强生成评估大语言模型服务高级检索增强生成 RAG 在之前的博客中,我们学习了为RAG(检索增强生成,Retrieval Augmented Generati…

怎样提升大语言模型(LLM)回答准确率

怎样提升大语言模型(LLM)回答准确率 目录 怎样提升大语言模型(LLM)回答准确率激励与规范类知识关联类情感与语境类逆向思维类:为什么不,反面案例群体智慧类明确指令类示例引导类思维引导类约束限制类反馈交互类:对话激励与规范类 给予奖励暗示:在提示词中暗示模型如果回…

【进阶】vscode 中使用 cmake 编译调试 C++ 工程

基于 MSYS2 的 MinGW-w64 GCC 工具链与 CMake 构建系统,结合VSCode及其扩展插件( ms-vscode.cmake-tools),可实现高效的全流程C开发调试。既可通过 VSCode 可视化界面(命令面板、状态栏按钮)便捷完成配置、…

流影---开源网络流量分析平台(三)(管理引擎部署)

目录 前沿 功能介绍 部署过程 前沿 在上一篇文章中,最后因为虚拟机的资源而没看到最后的效果,而是查看了日志,虽然效果是有了,但后来我等了很久,还是那个转圈的画面,所以我猜测可能是少了什么东西&#…

QT Quick(C++)跨平台应用程序项目实战教程 5 — 界面设计

目录 1.版面设计 2. 自定义按钮 2.1 自定义工具栏按钮 2.2 自定义图标按钮 3. 顶部工具栏 4. 主体 5. 底部工具栏 6. 主文件 7. 最终效果 上一章内容讲解了QML基本使用方法。本章内容继续延续“音乐播放器”项目主线,完成程序的界面设计任务。 1.版面设计…

【微服务架构】SpringCloud Alibaba(三):负载均衡 LoadBalance

文章目录 SpringCloud Alibaba1、核心组件2、优势3、应用场景 一、Loadbalance介绍二、Ribbon和Loadbalance 对比三、整合LoadBlance1、升级版本2、移除ribbon依赖,增加loadBalance依赖 四、自定定义负载均衡器五、重试机制六、源码分析1、猜测源码的实现2、初始化过…

06-02-自考数据结构(20331)- 查找技术-动态查找知识点

自考数据结构动态查找算法主要讲二叉树和平衡二叉树,但是感觉到了,就又续接了一部分,所以这篇备考的小伙伴着重看前两种就可以了。 知识拓扑 知识点介绍 二叉排序树(BST) 定义 二叉排序树(Binary Search Tree)又称二叉查找树,它或者是一棵空树,或者是具有下列性质的二…

Upload-labs 靶场搭建 及一句话木马的原理与运用

1、phpstudy及upload-labs下载 (1)下载phpstudy小皮面板 首先需要软件phpstudy 下载地址 phpStudy下载-phpStudy最新版下载V8.1.1.3 -阔思亮 (2)然后到github网址下载源码压缩包 网址 https://github.com/c0ny1/upload-labs 再…

爬虫的第三天——爬动态网页

一、基本概念 动态网页是指网页内容可以根据用户的操作或者预设条件而实时发生变化的网页。 特点: 用户交互:动态网页能够根据用户的请求而生成不同的内容。内容动态生成:数据来自数据库、API或用户输入。客户端动态渲染:浏览器…

力扣HOT100之矩阵:48. 旋转图像

这道题本来想用剥洋葱的办法的,一直写不对,放弃了。。。直接去看题解,用剥洋葱其实也可以做,就是要从外层处理到内层,每一个边界上的元素为matrix[0].size() - 1个,这样一来,四条边界上的元素个…

uniapp微信小程序获取用户手机号uniCloud云开发版

开发微信小程序,很多时候需要获取用户的手机号,这样方便平台更好的为用户服务,但是微信小程序不允许开发者直接获取用户的手机号,需要用户手动授权才能获取手机号,且需要配合后端进行解密才能获得完整的手机号&#xf…

31天Python入门——第18天:面向对象三大特性·封装继承多态

你好,我是安然无虞。 文章目录 面向对象三大特性1. 封装2. 继承3. 多态4. 抽象基类5. 补充练习 面向对象三大特性 面向对象编程(Object-Oriented Programming, 简称OOP)有三大特性, 分别是封装、继承和多态.这些特性是面向对象编程的基础, …

第十六届蓝桥杯模拟二(串口通信)

由硬件框图可以知道我们要配置LED 和按键 一.LED 先配置LED的八个引脚为GPIO_OutPut,锁存器PD2也是,然后都设置为起始高电平,生成代码时还要去解决引脚冲突问题 二.按键 按键配置,由原理图按键所对引脚要GPIO_Input 生成代码,在文件夹中添加code文件夹,code中添加fun.…

UE5学习笔记 FPS游戏制作32 主菜单,暂停游戏,显示鼠标指针

文章目录 一主菜单搭建UI显示主菜单时,暂停游戏,显示鼠标绑定按钮 二 打开主菜单 一主菜单 搭建UI 添加一个MainUi的控件 添加一个返回游戏的按钮和一个退出游戏的按钮 修改一下样式,放中间 显示主菜单时,暂停游戏&#xff0…

LLM - 开源强化学习框架 OpenR1 的环境配置与训练参数 教程

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/146838740 免责声明:本文来源于个人知识与公开资料,仅用于学术交流,欢迎讨论,不支持转载。 OpenR1 是一个开源的强化学习框架,复现 DeepSeek-R1 的训练流程,为研…

蓝桥杯备赛之枚举

用循环等方式依次去枚举所有的数字组合,一一验证是否符合题目的要求 题目链接 0好数 - 蓝桥云课 题目解析 好数的概念: 数的奇数位位奇数,偶数位为偶数,就是一个好数 求输入n里面有多少个好数 题目原理 1> 遍历每个数 2> 每次遍历判断是不是好数 把这…