[当人工智能遇上安全] 11.威胁情报实体识别 (2)基于BiGRU-CRF的中文实体识别万字详解

news2024/11/15 8:59:42

您或许知道,作者后续分享网络安全的文章会越来越少。但如果您想学习人工智能和安全结合的应用,您就有福利了,作者将重新打造一个《当人工智能遇上安全》系列博客,详细介绍人工智能与安全相关的论文、实践,并分享各种案例,涉及恶意代码检测、恶意请求识别、入侵检测、对抗样本等等。只想更好地帮助初学者,更加成体系的分享新知识。该系列文章会更加聚焦,更加学术,更加深入,也是作者的慢慢成长史。换专业确实挺难的,系统安全也是块硬骨头,但我也试试,看看自己未来四年究竟能将它学到什么程度,漫漫长征路,偏向虎山行。享受过程,一起加油~

前文讲解如何实现威胁情报实体识别,利用BiLSTM-CRF算法实现对ATT&CK相关的技战术实体进行提取,是安全知识图谱构建的重要支撑。这篇文章将以中文语料为主,介绍中文命名实体识别研究,并构建BiGRU-CRF模型实现。基础性文章,希望对您有帮助,如果存在错误或不足之处,还请海涵。且看且珍惜!

由于上一篇文章详细讲解ATT&CK威胁情报采集、预处理、BiLSTM-CRF实体识别内容,这篇文章不再详细介绍,本文将在上一篇文章基础上补充:

  • 中文命名实体识别如何实现,以字符为主
  • 以中文CSV文件为语料,介绍其处理过程,中文威胁情报类似
  • 构建BiGRU-CRF模型实现中文实体识别

版本信息:

  • keras-contrib V2.0.8
  • keras V2.3.1
  • tensorflow V2.2.0

常见框架如下图所示:

  • https://aclanthology.org/2021.acl-short.4/

在这里插入图片描述

在这里插入图片描述

文章目录

  • 一.ATT&CK数据采集
  • 二.数据预处理
  • 三.基于BiLSTM-CRF的实体识别
    • 1.安装keras-contrib
    • 2.安装Keras
    • 3.中文实体识别
  • 四.基于BiGRU-CRF的实体识别
  • 五.总结

作者作为网络安全的小白,分享一些自学基础教程给大家,主要是在线笔记,希望您们喜欢。同时,更希望您能与我一起操作和进步,后续将深入学习AI安全和系统安全知识并分享相关实验。总之,希望该系列文章对博友有所帮助,写文不易,大神们不喜勿喷,谢谢!如果文章对您有帮助,将是我创作的最大动力,点赞、评论、私聊均可,一起加油喔!

前文推荐:

  • [当人工智能遇上安全] 1.人工智能真的安全吗?浙大团队外滩大会分享AI对抗样本技术
  • [当人工智能遇上安全] 2.清华张超老师 - GreyOne: Discover Vulnerabilities with Data Flow Sensitive Fuzzing
  • [当人工智能遇上安全] 3.安全领域中的机器学习及机器学习恶意请求识别案例分享
  • [当人工智能遇上安全] 4.基于机器学习的恶意代码检测技术详解
  • [当人工智能遇上安全] 5.基于机器学习算法的主机恶意代码识别研究
  • [当人工智能遇上安全] 6.基于机器学习的入侵检测和攻击识别——以KDD CUP99数据集为例
  • [当人工智能遇上安全] 7.基于机器学习的安全数据集总结
  • [当人工智能遇上安全] 8.基于API序列和机器学习的恶意家族分类实例详解
  • [当人工智能遇上安全] 9.基于API序列和深度学习的恶意家族分类实例详解
  • [当人工智能遇上安全] 10.威胁情报实体识别之基于BiLSTM-CRF的实体识别万字详解
  • [当人工智能遇上安全] 11.威胁情报实体识别 (2)基于BiGRU-CRF的中文实体识别万字详解

作者的github资源:

  • https://github.com/eastmountyxz/When-AI-meet-Security
  • https://github.com/eastmountyxz/AI-Security-Paper

一.ATT&CK数据采集

了解威胁情报的同学,应该都熟悉Mitre的ATT&CK网站,前文已介绍如何采集该网站APT组织的攻击技战术数据。网址如下:

  • http://attack.mitre.org

在这里插入图片描述

第一步,通过ATT&CK网站源码分析定位APT组织名称,并进行系统采集。

在这里插入图片描述

安装BeautifulSoup扩展包,该部分代码如下所示:

在这里插入图片描述

01-get-aptentity.py

#encoding:utf-8
#By:Eastmount CSDN
import re
import requests
from lxml import etree
from bs4 import BeautifulSoup
import urllib.request

#-------------------------------------------------------------------------------------------
#获取APT组织名称及链接

#设置浏览器代理,它是一个字典
headers = {
    'User-Agent':'Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
        AppleWebKit/537.36 (KHTML, like Gecko) Chrome/80.0.3987.149 Safari/537.36'
}
url = 'https://attack.mitre.org/groups/'

#向服务器发出请求
r = requests.get(url = url, headers = headers).text

#解析DOM树结构
html_etree = etree.HTML(r)
names = html_etree.xpath('//*[@class="table table-bordered table-alternate mt-2"]/tbody/tr/td[2]/a/text()')
print (names)
print(len(names),names[0])
filename = []
for name in names:
    filename.append(name.strip())
print(filename)

#链接
urls = html_etree.xpath('//*[@class="table table-bordered table-alternate mt-2"]/tbody/tr/td[2]/a/@href')
print(urls)
print(len(urls), urls[0])
print("\n")

此时输出结果如下图所示,包括APT组织名称及对应的URL网址。

在这里插入图片描述

第二步,访问APT组织对应的URL,采集详细信息(正文描述)。

在这里插入图片描述

第三步,采集对应的技战术TTPs信息,其源码定位如下图所示。

在这里插入图片描述

第四步,编写代码完成威胁情报数据采集。01-spider-mitre.py 完整代码如下:

#encoding:utf-8
#By:Eastmount CSDN
import re
import requests
from lxml import etree
from bs4 import BeautifulSoup
import urllib.request

#-------------------------------------------------------------------------------------------
#获取APT组织名称及链接

#设置浏览器代理,它是一个字典
headers = {
    'User-Agent':'Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
        AppleWebKit/537.36 (KHTML, like Gecko) Chrome/80.0.3987.149 Safari/537.36'
}
url = 'https://attack.mitre.org/groups/'

#向服务器发出请求
r = requests.get(url = url, headers = headers).text
#解析DOM树结构
html_etree = etree.HTML(r)
names = html_etree.xpath('//*[@class="table table-bordered table-alternate mt-2"]/tbody/tr/td[2]/a/text()')
print (names)
print(len(names),names[0])
#链接
urls = html_etree.xpath('//*[@class="table table-bordered table-alternate mt-2"]/tbody/tr/td[2]/a/@href')
print(urls)
print(len(urls), urls[0])
print("\n")

#-------------------------------------------------------------------------------------------
#获取详细信息
k = 0
while k<len(names):
    filename = str(names[k]).strip() + ".txt"
    url = "https://attack.mitre.org" + urls[k]
    print(url)

    #获取正文信息
    page = urllib.request.Request(url, headers=headers)
    page = urllib.request.urlopen(page)
    contents = page.read()
    soup = BeautifulSoup(contents, "html.parser")

    #获取正文摘要信息
    content = ""
    for tag in soup.find_all(attrs={"class":"description-body"}):
        #contents = tag.find("p").get_text()
        contents = tag.find_all("p")
        for con in contents:
            content += con.get_text().strip() + "###\n"  #标记句子结束(第二部分分句用)
    #print(content)

    #获取表格中的技术信息
    for tag in soup.find_all(attrs={"class":"table techniques-used table-bordered mt-2"}):
        contents = tag.find("tbody").find_all("tr")
        for con in contents:
            value = con.find("p").get_text()           #存在4列或5列 故获取p值
            #print(value)
            content += value.strip() + "###\n"         #标记句子结束(第二部分分句用)

    #删除内容中的参考文献括号 [n]
    result = re.sub(u"\\[.*?]", "", content)
    print(result)

    #文件写入
    filename = "Mitre//" + filename
    print(filename)
    f = open(filename, "w", encoding="utf-8")
    f.write(result)
    f.close()    
    k += 1

输出结果如下图所示,共整理100个组织信息。

在这里插入图片描述

在这里插入图片描述

每个文件显示内容如下图所示:

在这里插入图片描述

数据标注采用暴力的方式进行,即定义不同类型的实体名称并利用BIO的方式进行标注。通过ATT&CK技战术方式进行标注,后续可以结合人工校正,同时可以定义更多类型的实体。

  • BIO标注
实体名称实体数量示例
APT攻击组织128APT32、Lazarus Group
攻击漏洞56CVE-2009-0927
区域位置72America、Europe
攻击行业34companies、finance
攻击手法65C&C、RAT、DDoS
利用软件487-Zip、Microsoft
操作系统10Linux、Windows

更多标注和预处理请查看上一篇文章。

  • [当人工智能遇上安全] 10.威胁情报实体识别之基于BiLSTM-CRF的实体识别万字详解

常见的数据标注工具:

  • 图像标注:labelme,LabelImg,Labelbox,RectLabel,CVAT,VIA
  • 半自动ocr标注:PPOCRLabel
  • NLP标注工具:labelstudio

温馨提示:
由于网站的布局会不断变化和优化,因此读者需要掌握数据采集及语法树定位的基本方法,以不变应万变。此外,读者可以尝试采集所有锻炼甚至是URL跳转链接内容,请读者自行尝试和拓展!


二.数据预处理

假设存在已经采集和标注好的中文数据集,通常采用按字(Char)分隔,如下图所示,古籍为数据集,当然中文威胁情报也类似。

在这里插入图片描述

数据集划分为训练集和测试集。

在这里插入图片描述

接下来,我们需要读取CSV数据集,并构建汉字词典。关键函数:

  • read_csv(filename):读取语料CSV文件
  • count_vocab(words,labels):统计不重复词典
  • build_vocab():构造词典

完整代码如下:

#encoding:utf-8
# By: Eastmount WuShuai 2024-02-05
import re
import os
import csv
import sys

train_data_path = "data/train.csv"
test_data_path = "data/test.csv"
char_vocab_path = "char_vocabs.txt"    #字典文件
special_words = ['<PAD>', '<UNK>']     #特殊词表示
final_words = []                       #统计词典(不重复出现)
final_labels = []                      #统计标记(不重复出现)

#语料文件读取函数
def read_csv(filename):
    words = []
    labels = []
    with open(filename,encoding='utf-8') as csvfile:
        reader = csv.reader(csvfile)
        for row in reader:
            if len(row)>0: #存在空行报错越界
                word,label = row[0],row[1]
                words.append(word)
                labels.append(label)
    return words,labels

#统计不重复词典
def count_vocab(words,labels):
    fp = open(char_vocab_path, 'a') #注意a为叠加(文件只能运行一次)
    k = 0
    while k<len(words):
        word = words[k]
        label = labels[k]
        if word not in final_words:
            final_words.append(word)
            fp.writelines(word + "\n")
        if label not in final_labels:
            final_labels.append(label)
        k += 1
    fp.close()
   
#读取数据并构造原文字典(第一列)
def build_vocab():
    words,labels = read_csv(train_data_path)
    print(len(words),len(labels),words[:8],labels[:8])
    count_vocab(words,labels)
    print(len(final_words),len(final_labels))

    #测试集
    words,labels = read_csv(test_data_path)
    print(len(words),len(labels))
    count_vocab(words,labels)
    print(len(final_words),len(final_labels))
    print(final_labels)

    #labels生成字典
    label_dict = {}
    k = 0
    for value in final_labels:
        label_dict[value] = k
        k += 1
    print(label_dict)
    return label_dict
    
if __name__ == '__main__':
    build_vocab()

输出结果如下,包括训练集数量,并输出前8行文字及标注,以及不重复的汉字个数,以及实体类别14个。

['晉', '樂', '王', '鮒', '曰', ':', '', '小'] 
['S-LOC', 'B-PER', 'I-PER', 'E-PER', 'O', 'O', '', 'O']
xxx 14

输出类别如下。

['S-LOC', 'B-PER', 'I-PER', 'E-PER', 'O', '', 'B-LOC', 
 'E-LOC', 'S-PER', 'S-TIM', 'B-TIM', 'E-TIM', 'I-TIM', 'I-LOC']

接着实体类别进行编码处理,输出结果如下:

{'S-LOC': 0, 'B-PER': 1, 'I-PER': 2, 'E-PER': 3, 'O': 4, '': 5, 'B-LOC': 6, 
 'E-LOC': 7, 'S-PER': 8, 'S-TIM': 9, 'B-TIM': 10, 'E-TIM': 11, 'I-TIM': 12, 'I-LOC': 13}

需要注意:在实体识别中,我们可以通过调用该函数获取识别的实体类别,关键代码如下。然而,由于真实分析中“O”通常建议编码为0,因此建议重新定义字典编码,更方便我们撰写代码,尤其是中文本遇到换句处理时,上述编码会乱序。

#原计划
from get_data import build_vocab #调取第一阶段函数
label2idx = build_vocab()

#实际情况
label2idx = {'O': 0,
             'S-LOC': 1, 'B-LOC': 2,  'I-LOC': 3,  'E-LOC': 4,
             'S-PER': 5, 'B-PER': 6,  'I-PER': 7,  'E-PER': 8,
             'S-TIM': 9, 'B-TIM': 10, 'E-TIM': 11, 'I-TIM': 12
             }
....
sent_ids = [vocab2idx[char] if char in vocab2idx else vocab2idx['<UNK>'] for char in sent_]
tag_ids = [label2idx[label] if label in label2idx else 0 for label in tag_]

最终生成词典char_vocabs.txt。

在这里插入图片描述


三.基于BiLSTM-CRF的实体识别

1.安装keras-contrib

CRF模型作者安装的是 keras-contrib

第一步,如果读者直接使用“pip install keras-contrib”可能会报错,远程下载也报错。

  • pip install git+https://www.github.com/keras-team/keras-contrib.git

甚至会报错 ModuleNotFoundError: No module named ‘keras_contrib’。

在这里插入图片描述

第二步,作者从github中下载该资源,并在本地安装。

  • https://github.com/keras-team/keras-contrib
  • keras-contrib 版本:2.0.8
git clone https://www.github.com/keras-team/keras-contrib.git
cd keras-contrib
python setup.py install

安装成功如下图所示:

在这里插入图片描述

读者可以从我的资源中下载代码和扩展包。

  • https://github.com/eastmountyxz/When-AI-meet-Security

2.安装Keras

同样需要安装keras和TensorFlow扩展包。

在这里插入图片描述

如果TensorFlow下载太慢,可以设置清华大学镜像,实际安装2.2版本。

pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
pip install tensorflow==2.2

在这里插入图片描述

在这里插入图片描述


3.中文实体识别

第一步,数据预处理,包括BIO标记及词典转换。

#encoding:utf-8
# By: Eastmount WuShuai 2024-02-05
# 参考:https://github.com/huanghao128/zh-nlp-demo
import re
import os
import csv
import sys
from get_data import build_vocab #调取第一阶段函数

#------------------------------------------------------------------------
#第一步 数据预处理
#------------------------------------------------------------------------
train_data_path = "data/train.csv"
test_data_path = "data/test.csv"
val_data_path = "data/val.csv"
char_vocab_path = "char_vocabs.txt"   #字典文件(防止多次写入仅读首次生成文件)
special_words = ['<PAD>', '<UNK>']     #特殊词表示
final_words = []                       #统计词典(不重复出现)
final_labels = []                      #统计标记(不重复出现)

#BIO标记的标签 字母O初始标记为0
#label2idx = build_vocab()
label2idx = {'O': 0,
             'S-LOC': 1, 'B-LOC': 2,  'I-LOC': 3,  'E-LOC': 4,
             'S-PER': 5, 'B-PER': 6,  'I-PER': 7,  'E-PER': 8,
             'S-TIM': 9, 'B-TIM': 10, 'E-TIM': 11, 'I-TIM': 12
             }
print(label2idx)

#索引和BIO标签对应
idx2label = {idx: label for label, idx in label2idx.items()}
print(idx2label)

#读取字符词典文件
with open(char_vocab_path, "r") as fo:
    char_vocabs = [line.strip() for line in fo]
char_vocabs = special_words + char_vocabs
print(char_vocabs)

#字符和索引编号对应
idx2vocab = {idx: char for idx, char in enumerate(char_vocabs)}
vocab2idx = {char: idx for idx, char in idx2vocab.items()}
print(idx2vocab)
print(vocab2idx)

输出结果如下所示:

{'O': 0, 'S-LOC': 1, 'B-LOC': 2, 'I-LOC': 3, 'E-LOC': 4, 'S-PER': 5, 'B-PER': 6, 
 'I-PER': 7, 'E-PER': 8, 'S-TIM': 9, 'B-TIM': 10, 'E-TIM': 11, 'I-TIM': 12}
{0: 'O', 1: 'S-LOC', 2: 'B-LOC', 3: 'I-LOC', 4: 'E-LOC', 5: 'S-PER', 6: 'B-PER', 
 7: 'I-PER', 8: 'E-PER', 9: 'S-TIM', 10: 'B-TIM', 11: 'E-TIM', 12: 'I-TIM'}

['<PAD>', '<UNK>', '晉', '樂', '王', '鮒', '曰', ':', '', '小', '旻', ...]
{0: '<PAD>', 1: '<UNK>', 2: '晉', 3: '樂', 4: '王', 5: '鮒', 6: '曰', 7: ':', 8: '', 9: '小', 10: '旻', ... ]
{'<PAD>': 0, '<UNK>': 1, '晉': 2, '樂': 3, '王': 4, '鮒': 5, '曰': 6, ':': 7, '': 8, '小': 9, '旻': 10, ... ]

第二步,读取CSV数据,并获取汉字、标记对应的下标,以下标存储。

#------------------------------------------------------------------------
#第二步 数据读取
#------------------------------------------------------------------------
def read_corpus(corpus_path, vocab2idx, label2idx):
    datas, labels = [], []
    with open(corpus_path, encoding='utf-8') as csvfile:
        reader = csv.reader(csvfile)
        sent_, tag_ = [], []
        for row in reader:
            word,label = row[0],row[1]
            if word!="" and label!="":   #断句
                sent_.append(word)
                tag_.append(label)
                """
                print(sent_) #['晉', '樂', '王', '鮒', '曰', ':']
                print(tag_)  #['S-LOC', 'B-PER', 'I-PER', 'E-PER', 'O', 'O']
                """
            else:                        #vocab2idx[0] => <PAD>
                sent_ids = [vocab2idx[char] if char in vocab2idx else vocab2idx['<UNK>'] for char in sent_]
                tag_ids = [label2idx[label] if label in label2idx else 0 for label in tag_]
                """
                print(sent_ids,tag_ids)
                for idx,idy in zip(sent_ids,tag_ids):
                    print(idx2vocab[idx],idx2label[idy])
                #[2, 3, 4, 5, 6, 7] [1, 6, 7, 8, 0, 0]
                #晉 S-LOC 樂 B-PER 王 I-PER 鮒 E-PER 曰 O : O
                """
                datas.append(sent_ids) #按句插入列表
                labels.append(tag_ids)
                sent_, tag_ = [], []
    return datas, labels

#原始数据
train_datas_, train_labels_ = read_corpus(train_data_path, vocab2idx, label2idx)
test_datas_, test_labels_ = read_corpus(test_data_path, vocab2idx, label2idx)

#输出测试结果 (第五句语料)
print(len(train_datas_),len(train_labels_),len(test_datas_),len(test_labels_))
print(train_datas_[5])
print([idx2vocab[idx] for idx in train_datas_[5]])
print(train_labels_[5])
print([idx2label[idx] for idx in train_labels_[5]])

输出结果如下,获取汉字和BIO标记的下标。

[2, 3, 4, 5, 6, 7] [1, 6, 7, 8, 0, 0]
晉 S-LOC 樂 B-PER 王 I-PER 鮒 E-PER 曰 O : O

其中,第5行数据示例如下:

[46, 47, 48, 47, 49, 50, 51, 52, 53, 54, 55, 56]
['齊', '、', '衛', '、', '陳', '大', '夫', '其', '不', '免', '乎', '!']
[1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0]
['S-LOC', 'O', 'S-LOC', 'O', 'S-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O']

对应语料如下:

在这里插入图片描述


第三步,数据填充和one-hot编码。

#------------------------------------------------------------------------
#第三步 数据填充 one-hot编码
#------------------------------------------------------------------------
import keras
from keras.preprocessing import sequence

MAX_LEN = 100
VOCAB_SIZE = len(vocab2idx)
CLASS_NUMS = len(label2idx)

#padding data
print('padding sequences')
train_datas = sequence.pad_sequences(train_datas_, maxlen=MAX_LEN)
train_labels = sequence.pad_sequences(train_labels_, maxlen=MAX_LEN)
test_datas = sequence.pad_sequences(test_datas_, maxlen=MAX_LEN)
test_labels = sequence.pad_sequences(test_labels_, maxlen=MAX_LEN)
print('x_train shape:', train_datas.shape)
print('x_test shape:', test_datas.shape)

#encoder one-hot
train_labels = keras.utils.to_categorical(train_labels, CLASS_NUMS)
test_labels = keras.utils.to_categorical(test_labels, CLASS_NUMS)
print('trainlabels shape:', train_labels.shape)
print('testlabels shape:', test_labels.shape)

输出结果如下所示:

padding sequences
x_train shape: (xxx, 100)
x_test shape: (xxx, 100)
trainlabels shape: (xxx, 100, 13)
testlabels shape: (xxx, 100, 13)

编码示例如下:

[   0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0 2163  410  294
  980   18]

第四步,构建BiLSTM+CRF模型。

#------------------------------------------------------------------------
#第四步 构建BiLSTM+CRF模型
# pip install git+https://www.github.com/keras-team/keras-contrib.git
# 安装过程详见文件夹截图
# ModuleNotFoundError: No module named ‘keras_contrib’
#------------------------------------------------------------------------
import numpy as np
from keras.models import Sequential
from keras.models import Model
from keras.layers import Masking, Embedding, Bidirectional, LSTM, \
     Dense, Input, TimeDistributed, Activation
from keras_contrib.layers import CRF
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_viterbi_accuracy
from keras import backend as K
from keras.models import load_model
from sklearn import metrics

EPOCHS = 2
EMBED_DIM = 128
HIDDEN_SIZE = 64
MAX_LEN = 100
VOCAB_SIZE = len(vocab2idx)
CLASS_NUMS = len(label2idx)
K.clear_session()
print(VOCAB_SIZE, CLASS_NUMS) #3319 13

#模型构建 BiLSTM-CRF
inputs = Input(shape=(MAX_LEN,), dtype='int32')
x = Masking(mask_value=0)(inputs)
x = Embedding(VOCAB_SIZE, EMBED_DIM, mask_zero=False)(x) #修改掩码False
x = Bidirectional(LSTM(HIDDEN_SIZE, return_sequences=True))(x)
x = TimeDistributed(Dense(CLASS_NUMS))(x)
outputs = CRF(CLASS_NUMS)(x)
model = Model(inputs=inputs, outputs=outputs)
model.summary()

输出结果如下图所示,显示该模型的结构。

在这里插入图片描述


第五步,模型训练和测试。flag标记变量分别设置为“train”和“test”。

flag = "train"
if flag=="train":
    #模型训练
    model.compile(loss=crf_loss, optimizer='adam', metrics=[crf_viterbi_accuracy])
    model.fit(train_datas, train_labels, epochs=EPOCHS, verbose=1, validation_split=0.1)
    score = model.evaluate(test_datas, test_labels, batch_size=256)
    print(model.metrics_names)
    print(score)
    model.save("bilstm_ner_model.h5")
elif flag=="test":
    #训练模型
    char_vocab_path = "char_vocabs_.txt"      #字典文件
    model_path = "bilstm_ner_model.h5"        #模型文件
    ner_labels = label2idx
    special_words = ['<PAD>', '<UNK>']
    MAX_LEN = 100
    
    #预测结果
    model = load_model(model_path, custom_objects={'CRF': CRF}, compile=False)    
    y_pred = model.predict(test_datas)
    y_labels = np.argmax(y_pred, axis=2)         #取最大值
    z_labels = np.argmax(test_labels, axis=2)    #真实值
    word_labels = test_datas                     #真实值
    
    k = 0
    final_y = []       #预测结果对应的标签
    final_z = []       #真实结果对应的标签
    final_word = []    #对应的特征单词
    while k<len(y_labels):
        y = y_labels[k]
        for idx in y:
            final_y.append(idx2label[idx])
        #print("预测结果:", [idx2label[idx] for idx in y])
        
        z = z_labels[k]
        for idx in z:    
            final_z.append(idx2label[idx])
        #print("真实结果:", [idx2label[idx] for idx in z])
        
        word = word_labels[k]
        for idx in word:
            final_word.append(idx2vocab[idx])
        k += 1
    print("最终结果大小:", len(final_y),len(final_z))
    
    n = 0
    numError = 0
    numRight = 0
    while n<len(final_y):
        if final_y[n]!=final_z[n] and final_z[n]!='O':
            numError += 1
        if final_y[n]==final_z[n] and final_z[n]!='O':
            numRight += 1
        n += 1
    print("预测错误数量:", numError)
    print("预测正确数量:", numRight)
    print("Acc:", numRight*1.0/(numError+numRight))
    print("预测单词:", [idx2vocab[idx] for idx in test_datas_[5]])
    print("真实结果:", [idx2label[idx] for idx in test_labels_[5]])
    print("预测结果:", [idx2label[idx] for idx in y_labels[5]][-len(test_datas_[5]):])

训练结果如下所示:

Epoch 1/2
    32/8439 [..............................] - ETA: 6:51 - loss: 2.5549 - crf_viterbi_accuracy: 3.1250e-04
    64/8439 [..............................] - ETA: 3:45 - loss: 2.5242 - crf_viterbi_accuracy: 0.1142
    8439/8439 [==============================] - 118s 14ms/step - loss: 0.1833 - crf_viterbi_accuracy: 0.9591 - val_loss: 0.0688 - val_crf_viterbi_accuracy: 0.9820
Epoch 2/10
    32/8439 [..............................] - ETA: 19s - loss: 0.0644 - crf_viterbi_accuracy: 0.9825
    64/8439 [..............................] - ETA: 42s - loss: 0.0592 - crf_viterbi_accuracy: 0.9845
	...
['loss', 'crf_viterbi_accuracy']
[0.043232945389307574, 0.9868513941764832]

最终测试结果如下所示,由于作者数据集仅放了少量数据,且未进行调参比较,真实数据更多且效果会更好。

预测错误数量: 2183
预测正确数量: 2209
Acc: 0.5029599271402551

预测单词: ['冬', ',', '楚', '公', '子', '罷', '如', '晉', '聘', ',', '且', '涖', '盟', '。']
真实结果: ['O', 'O', 'B-PER', 'I-PER', 'I-PER', 'E-PER', 'O', 'S-LOC', 'O', 'O', 'O', 'O', 'O', 'O']
预测结果: ['O', 'O', 'B-PER', 'E-PER', 'E-PER', 'E-PER', 'O', 'S-LOC', 'O', 'O', 'O', 'O', 'O', 'O']

四.基于BiGRU-CRF的实体识别

接下来构建BiGRU-CRF代码,以完整代码为例,并将预测结果存储在CSV文件上。

#encoding:utf-8
# By: Eastmount WuShuai 2024-02-05
import re
import os
import csv
import sys
from get_data import build_vocab #调取第一阶段函数

#------------------------------------------------------------------------
#第一步 数据预处理
#------------------------------------------------------------------------
train_data_path = "data/train.csv"
test_data_path = "data/test.csv"
val_data_path = "data/val.csv"
char_vocab_path = "char_vocabs.txt"    #字典文件(防止多次写入仅读首次生成文件)
special_words = ['<PAD>', '<UNK>']     #特殊词表示
final_words = []                       #统计词典(不重复出现)
final_labels = []                      #统计标记(不重复出现)

#BIO标记的标签 字母O初始标记为0
#label2idx = build_vocab()
label2idx = {'O': 0,
             'S-LOC': 1, 'B-LOC': 2,  'I-LOC': 3,  'E-LOC': 4,
             'S-PER': 5, 'B-PER': 6,  'I-PER': 7,  'E-PER': 8,
             'S-TIM': 9, 'B-TIM': 10, 'E-TIM': 11, 'I-TIM': 12
             }

#索引和BIO标签对应
idx2label = {idx: label for label, idx in label2idx.items()}

#读取字符词典文件
with open(char_vocab_path, "r") as fo:
    char_vocabs = [line.strip() for line in fo]
char_vocabs = special_words + char_vocabs

#字符和索引编号对应
idx2vocab = {idx: char for idx, char in enumerate(char_vocabs)}
vocab2idx = {char: idx for idx, char in idx2vocab.items()}

#------------------------------------------------------------------------
#第二步 数据读取
#------------------------------------------------------------------------
def read_corpus(corpus_path, vocab2idx, label2idx):
    datas, labels = [], []
    with open(corpus_path, encoding='utf-8') as csvfile:
        reader = csv.reader(csvfile)
        sent_, tag_ = [], []
        for row in reader:
            word,label = row[0],row[1]
            if word!="" and label!="":   #断句
                sent_.append(word)
                tag_.append(label)
            else:                        #vocab2idx[0] => <PAD>
                sent_ids = [vocab2idx[char] if char in vocab2idx else vocab2idx['<UNK>'] for char in sent_]
                tag_ids = [label2idx[label] if label in label2idx else 0 for label in tag_]
                datas.append(sent_ids)   #按句插入列表
                labels.append(tag_ids)
                sent_, tag_ = [], []
    return datas, labels

#原始数据
train_datas_, train_labels_ = read_corpus(train_data_path, vocab2idx, label2idx)
test_datas_, test_labels_ = read_corpus(test_data_path, vocab2idx, label2idx)

#------------------------------------------------------------------------
#第三步 数据填充 one-hot编码
#------------------------------------------------------------------------
import keras
from keras.preprocessing import sequence

MAX_LEN = 100
VOCAB_SIZE = len(vocab2idx)
CLASS_NUMS = len(label2idx)

#padding data
print('padding sequences')
train_datas = sequence.pad_sequences(train_datas_, maxlen=MAX_LEN)
train_labels = sequence.pad_sequences(train_labels_, maxlen=MAX_LEN)
test_datas = sequence.pad_sequences(test_datas_, maxlen=MAX_LEN)
test_labels = sequence.pad_sequences(test_labels_, maxlen=MAX_LEN)

#encoder one-hot
train_labels = keras.utils.to_categorical(train_labels, CLASS_NUMS)
test_labels = keras.utils.to_categorical(test_labels, CLASS_NUMS)

#------------------------------------------------------------------------
#第四步 构建BiGRU+CRF模型
#------------------------------------------------------------------------
import numpy as np
from keras.models import Sequential
from keras.models import Model
from keras.layers import Masking, Embedding, Bidirectional, LSTM, GRU, \
     Dense, Input, TimeDistributed, Activation
from keras_contrib.layers import CRF
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_viterbi_accuracy
from keras import backend as K
from keras.models import load_model
from sklearn import metrics

EPOCHS = 2
EMBED_DIM = 128
HIDDEN_SIZE = 64
MAX_LEN = 100
VOCAB_SIZE = len(vocab2idx)
CLASS_NUMS = len(label2idx)
K.clear_session()
print(VOCAB_SIZE, CLASS_NUMS)

#模型构建 BiGRU-CRF
inputs = Input(shape=(MAX_LEN,), dtype='int32')
x = Masking(mask_value=0)(inputs)
x = Embedding(VOCAB_SIZE, EMBED_DIM, mask_zero=False)(x) #修改掩码False
x = Bidirectional(GRU(HIDDEN_SIZE, return_sequences=True))(x)
x = TimeDistributed(Dense(CLASS_NUMS))(x)
outputs = CRF(CLASS_NUMS)(x)
model = Model(inputs=inputs, outputs=outputs)
model.summary()

flag = "test"
if flag=="train":
    #模型训练
    model.compile(loss=crf_loss, optimizer='adam', metrics=[crf_viterbi_accuracy])
    model.fit(train_datas, train_labels, epochs=EPOCHS, verbose=1, validation_split=0.1)
    score = model.evaluate(test_datas, test_labels, batch_size=256)
    print(model.metrics_names)
    print(score)
    model.save("bigru_ner_model.h5")
elif flag=="test":
    #训练模型
    char_vocab_path = "char_vocabs_.txt"      #字典文件
    model_path = "bigru_ner_model.h5"         #模型文件
    ner_labels = label2idx
    special_words = ['<PAD>', '<UNK>']
    MAX_LEN = 100
    
    #预测结果
    model = load_model(model_path, custom_objects={'CRF': CRF}, compile=False)    
    y_pred = model.predict(test_datas)
    y_labels = np.argmax(y_pred, axis=2)         #取最大值
    z_labels = np.argmax(test_labels, axis=2)    #真实值
    word_labels = test_datas                     #真实值
    
    k = 0
    final_y = []       #预测结果对应的标签
    final_z = []       #真实结果对应的标签
    final_word = []    #对应的特征单词
    while k<len(y_labels):
        y = y_labels[k]
        for idx in y:
            final_y.append(idx2label[idx])
        z = z_labels[k]
        for idx in z:    
            final_z.append(idx2label[idx])
        word = word_labels[k]
        for idx in word:
            final_word.append(idx2vocab[idx])
        k += 1
    
    n = 0
    numError = 0
    numRight = 0
    while n<len(final_y):
        if final_y[n]!=final_z[n] and final_z[n]!='O':
            numError += 1
        if final_y[n]==final_z[n] and final_z[n]!='O':
            numRight += 1
        n += 1
    print("预测错误数量:", numError)
    print("预测正确数量:", numRight)
    print("Acc:", numRight*1.0/(numError+numRight))
    print("预测单词:", [idx2vocab[idx] for idx in test_datas_[5]])
    print("真实结果:", [idx2label[idx] for idx in test_labels_[5]])
    print("预测结果:", [idx2label[idx] for idx in y_labels[5]][-len(test_datas_[5]):])
    
    #文件存储
    fw = open("Final_BiGRU_CRF_Result.csv", "w", encoding="utf8", newline='')
    fwrite = csv.writer(fw)
    fwrite.writerow(['pre_label','real_label', 'word'])
    n = 0
    while n<len(final_y):
        fwrite.writerow([final_y[n],final_z[n],final_word[n]])
        n += 1
    fw.close()

输出结果如下所示:

['loss', 'crf_viterbi_accuracy']
[0.03543611364953834, 0.9894005656242371]

在这里插入图片描述

生成文件如下图所示:

在这里插入图片描述


五.总结

写到这里这篇文章就结束,希望对您有所帮助,后续将结合经典的Bert进行分享。忙碌的2024,真的很忙,项目本子论文毕业工作,等忙完后好好写几篇安全博客,感谢支持和陪伴,尤其是家人的鼓励和支持, 继续加油!

  • 一.ATT&CK数据采集
  • 二.数据预处理
  • 三.基于BiLSTM-CRF的实体识别
    1.安装keras-contrib
    2.安装Keras
    3.中文实体识别
  • 四.基于BiGRU-CRF的实体识别
  • 五.总结

人生路是一个个十字路口,一次次博弈,一次次纠结和得失组成。得失得失,有得有失,不同的选择,不一样的精彩。虽然累和忙,但看到小珞珞还是挺满足的,感谢家人的陪伴。望小珞能开心健康成长,爱你们喔,继续干活,加油!

在这里插入图片描述

(By:Eastmount 2024-02-07 夜于贵阳 http://blog.csdn.net/eastmount/ )


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1439799.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring 的奇幻起源:从 IoC 容器到 Bean 的魔法世界 ✨

目录 什么是 Spring&#xff1f;为什么它如此流行&#xff1f; IoC 容器&#xff1a;从“依赖倒置”到“控制反转” Bean&#xff1a;IoC 容器中的基本组件 Spring 中的配置方式&#xff1a;XML、注解和 JavaConfig Bean 的作用域和生命周期管理 Bean 的属性装配和自动装配…

[大厂实践] Netflix容器平台内核panic可观察性实践

在某些情况下&#xff0c;K8S节点和Pod会因为出错自动消失&#xff0c;很难追溯原因&#xff0c;其中一种情况就是发生了内核panic。本文介绍了Netflix容器平台针对内核panic所做的可观测性增强&#xff0c;使得发生内核panic的时候&#xff0c;能够导出信息&#xff0c;帮助排…

Java ieda 抽风报错导致无法正常启动项目

Java ieda 抽风报错导致无法正常启动项目 问题描述&#xff1a;新建模块运行时出现下面报错&#xff0c;不能正常启动程序。 Error:Module 你的项目名 production: java.lang.ClassCastException: class org.jetbrains.jps.builders.java.dependencyView.TypeRepr$PrimitiveT…

适用于 Windows 11/10/8.1/8/7 的最佳 SD 卡恢复软件

丢失了 SD 卡中的一些重要照片或文档&#xff0c;并且不知道如何恢复&#xff1f;好吧&#xff0c;别担心&#xff01;&#xff01;以下是一些适用于 Windows 的最佳 SD 卡恢复工具&#xff0c;可增加您检索意外删除、丢失或丢失数据的机会。 什么是 SD 卡恢复软件&#xff1f;…

华为配置访客接入WLAN网络示例(MAC优先的Portal认证)

配置访客接入WLAN网络示例&#xff08;MAC优先的Portal认证&#xff09; 组网图形 图1 配置WLAN MAC优先的Portal认证示例组网图 业务需求组网需求数据规划配置思路配置注意事项操作步骤配置文件 业务需求 某企业为了提高WLAN网络的安全性&#xff0c;采用MAC优先的外置Portal认…

JVM-运行时数据区程序计数器

运行时数据区 Java虚拟机在运行Java程序过程中管理的内存区域&#xff0c;称之为运行时数据区。《Java虚拟机规范》中规定了每一部分的作用。 程序计数器的定义 程序计数器&#xff08;Program Counter Register&#xff09;也叫PC寄存器&#xff0c;每个线程会通过程序计数器…

从Unity到Three.js(安装启动)

发现在3D数字孪生或模拟仿真方向&#xff0c;越来越多的公司倾向使用Web端程序&#xff0c;目前一直都是使用的Unity进行的Web程序开发&#xff0c;但是存在不少问题&#xff0c;比如内存释放、shader差异化、UI控件不支持复制或输入中文等。虽然大多数问题都可以找到解决方案&…

什么是制动电阻器?工作及其应用

电梯、风力涡轮机、起重机、升降机和电力机车的速度控制是非常必要的。因此&#xff0c;制动电阻器是这些应用不可或缺的一部分&#xff0c;因为它们是电动机驱动器中最常用的高功率电阻器&#xff0c;用于控制其速度&#xff0c;在运输、海事和建筑等行业中。 电动火车主要比柴…

【蓝桥杯冲冲冲】Invasion of the Milkweed G

【蓝桥杯冲冲冲】Invasion of the Milkweed G 蓝桥杯备赛 | 洛谷做题打卡day30 文章目录 蓝桥杯备赛 | 洛谷做题打卡day30[USACO09OCT] Invasion of the Milkweed G题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 题解代码我的一些话 [USACO09OCT] Invasion of the Mi…

NIS服务器搭建(管理账户密码验证)

理解&#xff1a;新进100台服务器&#xff0c;通过nis服务器设置各个服务器的用户和密码&#xff0c;而不是分别到100台机器前设置用户名密码&#xff0c;服务器可以统一管理用户名密码&#xff0c;更新等操作 第一&#xff1a;服务器端设置 1.域名设置&#xff1a;dongfang …

(每日持续更新)信息系统项目管理(第四版)(高级项目管理)考试重点整理第10章 项目进度管理(三)

博主2023年11月通过了信息系统项目管理的考试&#xff0c;考试过程中发现考试的内容全部是教材中的内容&#xff0c;非常符合我学习的思路&#xff0c;因此博主想通过该平台把自己学习过程中的经验和教材博主认为重要的知识点分享给大家&#xff0c;希望更多的人能够通过考试&a…

格子表单GRID-FORM | 文档网站搭建(VitePress)与部署(Github Pages)

格子表单/GRID-FORM已在Github 开源&#xff0c;如能帮到您麻烦给个星&#x1f91d; GRID-FORM 系列文章 基于 VUE3 可视化低代码表单设计器嵌套表单与自定义脚本交互文档网站搭建&#xff08;VitePress&#xff09;与部署&#xff08;Github Pages&#xff09; 效果预览 格…

【芯片设计- RTL 数字逻辑设计入门 11.1 -- 状态机实现 移位运算与乘法 1】

文章目录 移位运算与乘法状态机简介SystemVerilog中的测试平台VCS 波形仿真 阻塞赋值和非阻塞赋值有限状态机&#xff08;FSM&#xff09;与无限状态机的区别 本篇文章接着上篇文章【芯片设计- RTL 数字逻辑设计入门 11 – 移位运算与乘法】 继续介绍&#xff0c;这里使用状态机…

LeetCode 0993. 二叉树的堂兄弟节点:深度优先搜索(BFS)

【LetMeFly】993.二叉树的堂兄弟节点&#xff1a;深度优先搜索(BFS) 力扣题目链接&#xff1a;https://leetcode.cn/problems/cousins-in-binary-tree/ 在二叉树中&#xff0c;根节点位于深度 0 处&#xff0c;每个深度为 k 的节点的子节点位于深度 k1 处。 如果二叉树的两个…

使用SM4国密加密算法对Spring Boot项目数据库连接信息以及yaml文件配置属性进行加密配置(读取时自动解密)

一、前言 在业务系统开发过程中,我们必不可少的会使用数据库,在应用开发过程中,数据库连接信息往往都是以明文的方式配置到yaml配置文件中的,这样有密码泄露的风险,那么有没有什么方式可以避免呢?方案当然是有的,就是对数据库密码配置的时候进行加密,然后读取的时候再…

25、数据结构/二叉树相关练习20240207

一、二叉树相关练习 请编程实现二叉树的操作 1.二叉树的创建 2.二叉树的先序遍历 3.二叉树的中序遍历 4.二叉树的后序遍历 5.二叉树各个节点度的个数 6.二叉树的深度 代码&#xff1a; #include<stdlib.h> #include<string.h> #include<stdio.h> ty…

生物——文献笔记

生物——文献笔记 文章目录 前言藻类群体遗传学研究和进展&#xff08;综述&#xff09;海洋动物群体遗传学的研究进展1. 影响群体基因频率的因素2. 根据自然群体的繁殖体系&#xff0c;海洋动物群体遗传类型可分为以下几类3. 海洋动物群体遗传研究中常用的遗传标记4. 研究展望…

UML 2.5图形库

UML 2.5图形库 drawio是一款强大的图表绘制软件&#xff0c;支持在线云端版本以及windows, macOS, linux安装版。 如果想在线直接使用&#xff0c;则直接输入网址drawon.cn或者使用drawon(桌案), drawon.cn内部完整的集成了drawio的所有功能&#xff0c;并实现了云端存储&#…

幻兽帕鲁服务器全自动部署教程,小白也能轻松上手

幻兽帕鲁太火了&#xff0c;官方palworld服务器不稳定&#xff1f;不如自建服务器&#xff0c;基于腾讯云幻兽帕鲁服务器成本32元全自动部署幻兽帕鲁服务器&#xff0c;超简单有手就行&#xff0c;全程自动化一键部署10秒钟即可搞定&#xff0c;无需玩家手动部署幻兽帕鲁游戏程…

深入探究 HTTP 简化:httplib 库介绍

✏️心若有所向往&#xff0c;何惧道阻且长 文章目录 简介特性主要类介绍httplib::Server类httplib::Client类httplib::Request类httplib::Response类 示例服务器客户端 总结 简介 在当今的软件开发中&#xff0c;与网络通信相关的任务变得日益普遍。HTTP&#xff08;Hypertext…