NLP_知识图谱_图谱问答实战

news2024/12/23 19:12:35

文章目录

  • 图谱问答
    • NER
    • ac自动机
    • 实体链接
    • 实体消歧
  • 多跳问答
    • neo4j_graph执行流程
    • 结构图![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/1577c1d9c9e342b3acbf79824aae980f.png)
      • company_data![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/20f567d877c743b49546e50caad92fba.png)
  • 代码与数据
    • 先启动neo4j图数据库
    • import_data
    • create_question_data
    • data_process
    • ac_automaton
    • torch_utils
    • text_cnn
    • train
    • main
  • 图谱问答实战小结


图谱问答

图谱问答有很多种情况,例如根据实体和关系查询尾实体,或者根据实体查询关系,甚至还会出现多跳的情况,不同的情况采用的方法略有不同,我们先来看最简单的情况,根据头实体和关系查询尾实体。

1、找到实体与关系,可以采用BIO的形式做NER,也可以直接使用分类的方法
2、实体链接,如果遇到相同名字的实体,需要做一个消歧

NER

目前的NER的方式很多,基本的结构都是encoder+crf层

ac自动机

1、构建前缀树
2、给前缀树加上fail指针
节点i的fail指针,如果在第一层,则指向root节点,其它情况指向其父节点的fail指针指向的节点的相同节点
在这里插入图片描述

有如下的几个模式串:she he say shr her
匹配串:yasherhs
在这里插入图片描述

实体链接

实体链接包括两个步骤:
Candidate Entity Generation、Entity Disambiguation

找到候选实体后,下一步就是实体消歧

实体消歧

实体消歧,这里我们使用的是匹配的方法:
1、使用孪生网络,计算相似度
2、对问题和候选集做embedding,计算余弦相似度

多跳问答

在这里插入图片描述

neo4j_graph执行流程

1、先执行import_data.py脚本,把company_data下面的数据导入到neo4j

2、执行gnn/saint.py脚本进行节点分类

3、company.csv文件是每个节点的属性

结构图在这里插入图片描述

company_data在这里插入图片描述

截图举几个例子(ps:数据为虚假,作为学习使用):
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

代码与数据

先启动neo4j图数据库

操作流程:WIN+R,cmd,neo4j.bat concole
在这里插入图片描述
在这里插入图片描述

import_data

import os
from py2neo import Node, Subgraph, Graph, Relationship, NodeMatcher
from tqdm import tqdm
import pandas as pd
import numpy as np

#graph = Graph("http://127.0.0.1:7474", auth=("neo4j", "qwer"))

#graph = Graph("http://127.0.0.1:7474", auth=("neo4j", "qwer"))

#uri = 'bolt://localhost:7687'
#graph = Graph(uri, auth=("neo4j", "password"), port= 7687, secure=True)

#uri = uri = 'http://localhost:7687'
#graph = Graph(uri, auth=("neo4j", "qwer"), port= 7687, secure=True, name= "StellarGraph")

import py2neo
default_host = os.environ.get("STELLARGRAPH_NEO4J_HOST")

# Create the Neo4j Graph database object; the arguments can be edited to specify location and authentication

graph = py2neo.Graph(host=default_host, port=7687, user='neo4j', password='qwer')

def import_company():
    df = pd.read_csv('company_data/公司.csv')
    eid = df['eid'].values
    name = df['companyname'].values

    nodes = []
    data = list(zip(eid, name))
    for eid, name in tqdm(data):
        profit = np.random.randint(100000, 100000000, 1)[0]
        node = Node('company', name=name, profit=int(profit), eid=eid)
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_person():
    df = pd.read_csv('company_data/人物.csv')
    pid = df['personcode'].values
    name = df['personname'].values

    nodes = []
    data = list(zip(pid, name))
    for eid, name in tqdm(data):
        age = np.random.randint(20, 70, 1)[0]
        node = Node('person', name=name, age=int(age), pid=str(eid))
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_industry():
    df = pd.read_csv('company_data/行业.csv')
    names = df['orgtype'].values

    nodes = []
    for name in tqdm(names):
        node = Node('industry', name=name)
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_assign():
    df = pd.read_csv('company_data/分红.csv')
    names = df['schemetype'].values

    nodes = []
    for name in tqdm(names):
        node = Node('assign', name=name)
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_violations():
    df = pd.read_csv('company_data/违规类型.csv')
    names = df['gooltype'].values

    nodes = []
    for name in tqdm(names):
        node = Node('violations', name=name)
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_bond():
    df = pd.read_csv('company_data/债券类型.csv')
    names = df['securitytype'].values

    nodes = []
    for name in tqdm(names):
        node = Node('bond', name=name)
        nodes.append(node)

    graph.create(Subgraph(nodes))


# def import_dishonesty():
#     node = Node('dishonesty', name='失信')
#     graph.create(node)


def import_relation():
    df = pd.read_csv('company_data/公司-人物.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    pid = df['pid'].values
    post = df['post'].values
    relations = []
    data = list(zip(eid, pid, post))
    for e, p, po in tqdm(data):
        company = matcher.match('company', eid=e).first()
        person = matcher.match('person', pid=str(p)).first()
        if company is not None and person is not None:
            relations.append(Relationship(company, po, person))

    graph.create(Subgraph(relationships=relations))
    print('import company-person relation succeeded')

    df = pd.read_csv('company_data/公司-行业.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    name = df['industry'].values
    relations = []
    data = list(zip(eid, name))
    for e, n in tqdm(data):
        company = matcher.match('company', eid=e).first()
        industry = matcher.match('industry', name=str(n)).first()
        if company is not None and industry is not None:
            relations.append(Relationship(company, '行业类型', industry))

    graph.create(Subgraph(relationships=relations))
    print('import company-industry relation succeeded')

    df = pd.read_csv('company_data/公司-分红.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    name = df['assign'].values
    relations = []
    data = list(zip(eid, name))
    for e, n in tqdm(data):
        company = matcher.match('company', eid=e).first()
        assign = matcher.match('assign', name=str(n)).first()
        if company is not None and assign is not None:
            relations.append(Relationship(company, '分红方式', assign))

    graph.create(Subgraph(relationships=relations))
    print('import company-assign relation succeeded')

    df = pd.read_csv('company_data/公司-违规.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    name = df['violations'].values
    relations = []
    data = list(zip(eid, name))
    for e, n in tqdm(data):
        company = matcher.match('company', eid=e).first()
        violations = matcher.match('violations', name=str(n)).first()
        if company is not None and violations is not None:
            relations.append(Relationship(company, '违规类型', violations))

    graph.create(Subgraph(relationships=relations))
    print('import company-violations relation succeeded')

    df = pd.read_csv('company_data/公司-债券.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    name = df['bond'].values
    relations = []
    data = list(zip(eid, name))
    for e, n in tqdm(data):
        company = matcher.match('company', eid=e).first()
        bond = matcher.match('bond', name=str(n)).first()
        if company is not None and bond is not None:
            relations.append(Relationship(company, '债券类型', bond))

    graph.create(Subgraph(relationships=relations))
    print('import company-bond relation succeeded')

    # df = pd.read_csv('company_data/公司-失信.csv')
    # matcher = NodeMatcher(graph)
    # eid = df['eid'].values
    # rel = df['dishonesty'].values
    # relations = []
    # data = list(zip(eid, rel))
    # for e, r in tqdm(data):
    #     company = matcher.match('company', eid=e).first()
    #     dishonesty = matcher.match('dishonesty', name='失信').first()
    #     if company is not None and dishonesty is not None:
    #         if pd.notna(r):
    #             if int(r) == 0:
    #                 relations.append(Relationship(company, '无', dishonesty))
    #             elif int(r) == 1:
    #                 relations.append(Relationship(company, '有', dishonesty))
    #
    # graph.create(Subgraph(relationships=relations))
    # print('import company-dishonesty relation succeeded')


def import_company_relation():
    df = pd.read_csv('company_data/公司-供应商.csv')
    matcher = NodeMatcher(graph)
    eid1 = df['eid1'].values
    eid2 = df['eid2'].values
    relations = []
    data = list(zip(eid1, eid2))
    for e1, e2 in tqdm(data):
        if pd.notna(e1) and pd.notna(e2) and e1 != e2:
            company1 = matcher.match('company', eid=e1).first()
            company2 = matcher.match('company', eid=e2).first()

            if company1 is not None and company2 is not None:
                relations.append(Relationship(company1, '供应商', company2))

    graph.create(Subgraph(relationships=relations))
    print('import company-supplier relation succeeded')

    df = pd.read_csv('company_data/公司-担保.csv')
    matcher = NodeMatcher(graph)
    eid1 = df['eid1'].values
    eid2 = df['eid2'].values
    relations = []
    data = list(zip(eid1, eid2))
    for e1, e2 in tqdm(data):
        if pd.notna(e1) and pd.notna(e2) and e1 != e2:
            company1 = matcher.match('company', eid=e1).first()
            company2 = matcher.match('company', eid=e2).first()

            if company1 is not None and company2 is not None:
                relations.append(Relationship(company1, '担保', company2))

    graph.create(Subgraph(relationships=relations))
    print('import company-guarantee relation succeeded')

    df = pd.read_csv('company_data/公司-客户.csv')
    matcher = NodeMatcher(graph)
    eid1 = df['eid1'].values
    eid2 = df['eid2'].values
    relations = []
    data = list(zip(eid1, eid2))
    for e1, e2 in tqdm(data):
        if pd.notna(e1) and pd.notna(e2):
            company1 = matcher.match('company', eid=e1).first()
            company2 = matcher.match('company', eid=e2).first()

            if company1 is not None and company2 is not None:
                relations.append(Relationship(company1, '客户', company2))

    graph.create(Subgraph(relationships=relations))
    print('import company-customer relation succeeded')


def delete_relation():
    cypher = 'match ()-[r]-() delete r'
    graph.run(cypher)


def delete_node():
    cypher = 'match (n) delete n'
    graph.run(cypher)


def import_data():
    import_company()
    import_company_relation()

    import_person()
    import_industry()
    import_assign()
    import_violations()
    import_bond()
    # import_dishonesty()

    import_relation()


def delete_data():
    delete_relation()
    delete_node()
    print('delete data succeeded')


if __name__ == '__main__':
    profit = np.random.randint(100000, 100000000, 10).tolist()

    delete_data()
    import_data()

在这里插入图片描述

create_question_data

from py2neo import Graph
import numpy as np
import pandas as pd

graph = Graph("http://localhost:7474", auth=("neo4j", "qwer"))

# import os
# import py2neo
# default_host = os.environ.get("STELLARGRAPH_NEO4J_HOST")
# graph = py2neo.Graph(host=default_host, port=7687, user='neo4j', password='qwer')

def create_attribute_question():
    company = graph.run('MATCH (n:company) RETURN n.name as name').to_ndarray()
    person = graph.run('MATCH (n:person) RETURN n.name as name').to_ndarray()

    questions = []

    for c in company:
        c = c[0].strip()
        question = f"{c}的收益"
        questions.append(question)
        question = f"{c}的收入"
        questions.append(question)

    for p in person:
        p = p[0].strip()
        question = f"{p}的年龄是几岁"
        questions.append(question)
        question = f"{p}多大"
        questions.append(question)
        question = f"{p}几岁"
        questions.append(question)

    return questions


def create_entity_question():
    questions = []

    for _ in range(250):
        for op in ['大于', '等于', '小于', '是', '有']:
            profit = np.random.randint(10000, 10000000, 1)[0]
            question = f"收益{op}{profit}的公司有哪些"
            questions.append(question)
            profit = np.random.randint(10000, 10000000, 1)[0]
            question = f"哪些公司收益{op}{profit}"
            questions.append(question)

    for _ in range(250):
        for op in ['大于', '等于', '小于', '是', '有']:
            profit = np.random.randint(20, 60, 1)[0]
            question = f"年龄{op}{profit}的人有哪些"
            questions.append(question)
            profit = np.random.randint(20, 60, 1)[0]
            question = f"哪些人年龄{op}{profit}"
            questions.append(question)

    return questions


def create_relation_question():
    relation = graph.run('MATCH (n)-[r]->(m) RETURN n.name as name, type(r) as r').to_ndarray()

    questions = []

    for r in relation:
        if str(r[1]) in ['董事', '监事']:
            question = f"{r[0]}{r[1]}是谁"
            questions.append(question)
        else:
            question = f"{r[0]}{r[1]}"
            questions.append(question)
            question = f"{r[0]}{r[1]}是啥"
            questions.append(question)
            question = f"{r[0]}{r[1]}什么"
            questions.append(question)

    return questions


q1 = create_entity_question()
q2 = create_attribute_question()
q3 = create_relation_question()

df = pd.DataFrame()
df['question'] = q1 + q2 + q3
df['label'] = [0] * len(q1) + [1] * len(q2) + [2] * len(q3)

df.to_csv('question_classification.csv', encoding='utf_8_sig', index=False)

data_process

import pandas as pd
import jieba
from collections import defaultdict
import numpy as np
import os

__file__ = 'kbqa'
path = os.path.dirname(__file__)


def tokenize(text, use_jieba=True):
    if use_jieba:
        res = list(jieba.cut(text, cut_all=False))
    else:
        res = list(text)
    return res


# 构建词典
def build_vocab(del_word_frequency=0):
    data = pd.read_csv('question_classification.csv')
    segment = data['question'].apply(tokenize)

    word_frequency = defaultdict(int)
    for row in segment:
        for i in row:
            word_frequency[i] += 1

    word_sort = sorted(word_frequency.items(), key=lambda x: x[1], reverse=True)  # 根据词频降序排序

    f = open('vocab.txt', 'w', encoding='utf-8')
    f.write('[PAD]' + "\n" + '[UNK]' + "\n")
    for d in word_sort:
        if d[1] > del_word_frequency:
            f.write(d[0] + "\n")
    f.close()


# 划分训练集和测试集
def split_data(df, split=0.7):
    df = df.sample(frac=1)
    length = len(df)
    train_data = df[0:length - 2000]
    eval_data = df[length - 2000:]

    return train_data, eval_data


vocab = {}
if os.path.exists(path + '/vocab.txt'):
    with open(path + '/vocab.txt', encoding='utf-8')as file:
        for line in file.readlines():
            vocab[line.strip()] = len(vocab)


# 把数据转换成index
def seq2index(seq):
    seg = tokenize(seq)
    seg_index = []
    for s in seg:
        seg_index.append(vocab.get(s, 1))
    return seg_index


# 统一长度
def padding_seq(X, max_len=10):
    return np.array([
        np.concatenate([x, [0] * (max_len - len(x))]) if len(x) < max_len else x[:max_len] for x in X
    ])


if __name__ == '__main__':
    build_vocab(5)

在这里插入图片描述

ac_automaton

import ahocorasick
from py2neo import Graph

graph = Graph("http://localhost:7474", auth=("neo4j", "qwer"))
company = graph.run('MATCH (n:company) RETURN n.name as name').to_ndarray()
relation = graph.run('MATCH ()-[r]-() RETURN distinct type(r)').to_ndarray()

ac_company = ahocorasick.Automaton()
ac_relation = ahocorasick.Automaton()

for key in enumerate(company):
    ac_company.add_word(key[1][0], key[1][0])
for key in enumerate(relation):
    ac_relation.add_word(key[1][0], key[1][0])

ac_company.make_automaton()
ac_relation.make_automaton()

# haystack = '浙江东阳东欣房地产开发有限公司的客户的供应商'
haystack = '衡水中南锦衡房地产有限公司的债券类型'
# haystack = '临沂金丰公社农业服务有限公司的分红方式'
print('question:', haystack)

subject = ''
predicate = []

for end_index, original_value in ac_company.iter(haystack):
    start_index = end_index - len(original_value) + 1
    print('公司实体:', (start_index, end_index, original_value))
    assert haystack[start_index:start_index + len(original_value)] == original_value
    subject = original_value

for end_index, original_value in ac_relation.iter(haystack):
    start_index = end_index - len(original_value) + 1
    print('关系:', (start_index, end_index, original_value))
    assert haystack[start_index:start_index + len(original_value)] == original_value
    predicate.append(original_value)

for p in predicate:
    cypher = f'''match (s:company)-[p:`{p}`]-(o) where s.name='{subject}' return o.name'''
    print(cypher)
    res = graph.run(cypher).to_ndarray()
    # print(res)
    subject = res[0][0]
print('answer:', res[0][0])

在这里插入图片描述

torch_utils

import torch
import time
import numpy as np
import six


class TrainHandler:

    def __init__(self,
                 train_loader,
                 valid_loader,
                 model,
                 criterion,
                 optimizer,
                 model_path,
                 batch_size=32,
                 epochs=5,
                 scheduler=None,
                 gpu_num=0):
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.model_path = model_path
        self.batch_size = batch_size
        self.epochs = epochs
        self.scheduler = scheduler
        if torch.cuda.is_available():
            self.device = torch.device(f'cuda:{gpu_num}')
            print('Training device is gpu:{gpu_num}')
        else:
            self.device = torch.device('cpu')
            print('Training device is cpu')
        self.model = model.to(self.device)

    def _train_func(self):
        train_loss = 0
        train_correct = 0
        for i, (x, y) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            x, y = x.to(self.device).long(), y.to(self.device)
            output = self.model(x)
            loss = self.criterion(output, y)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            train_correct += (output.argmax(1) == y).sum().item()

        if self.scheduler is not None:
            self.scheduler.step()

        return train_loss / len(self.train_loader), train_correct / len(self.train_loader.dataset)

    def _test_func(self):
        valid_loss = 0
        valid_correct = 0
        for x, y in self.valid_loader:
            x, y = x.to(self.device).long(), y.to(self.device)
            with torch.no_grad():
                output = self.model(x)
                loss = self.criterion(output, y)
                valid_loss += loss.item()
                valid_correct += (output.argmax(1) == y).sum().item()

        return valid_loss / len(self.valid_loader), valid_correct / len(self.valid_loader.dataset)

    def train(self):
        min_valid_loss = float('inf')

        for epoch in range(self.epochs):
            start_time = time.time()
            train_loss, train_acc = self._train_func()
            valid_loss, valid_acc = self._test_func()

            if min_valid_loss > valid_loss:
                min_valid_loss = valid_loss
                torch.save(self.model, self.model_path)
                print(f'\tSave model done valid loss: {valid_loss:.4f}')

            secs = int(time.time() - start_time)
            mins = secs / 60
            secs = secs % 60

            print('Epoch: %d' % (epoch + 1), " | time in %d minutes, %d seconds" % (mins, secs))
            print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
            print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')


def torch_text_process():
    from torchtext import data

    def tokenizer(text):
        import jieba
        return list(jieba.cut(text))

    TEXT = data.Field(sequential=True, tokenize=tokenizer, lower=True, fix_length=20)
    LABEL = data.Field(sequential=False, use_vocab=False)
    all_dataset = data.TabularDataset.splits(path='',
                                             train='LCQMC.csv',
                                             format='csv',
                                             fields=[('sentence1', TEXT), ('sentence2', TEXT), ('label', LABEL)])[0]
    TEXT.build_vocab(all_dataset)
    train, valid = all_dataset.split(0.1)
    (train_iter, valid_iter) = data.BucketIterator.splits(datasets=(train, valid),
                                                          batch_sizes=(64, 128),
                                                          sort_key=lambda x: len(x.sentence1))
    return train_iter, valid_iter


def pad_sequences(sequences, maxlen=None, dtype='int32',
                  padding='post', truncating='pre', value=0.):
    """Pads sequences to the same length.

    This function transforms a list of
    `num_samples` sequences (lists of integers)
    into a 2D Numpy array of shape `(num_samples, num_timesteps)`.
    `num_timesteps` is either the `maxlen` argument if provided,
    or the length of the longest sequence otherwise.

    Sequences that are shorter than `num_timesteps`
    are padded with `value` at the end.

    Sequences longer than `num_timesteps` are truncated
    so that they fit the desired length.
    The position where padding or truncation happens is determined by
    the arguments `padding` and `truncating`, respectively.

    Pre-padding is the default.

    # Arguments
        sequences: List of lists, where each element is a sequence.
        maxlen: Int, maximum length of all sequences.
        dtype: Type of the output sequences.
            To pad sequences with variable length strings, you can use `object`.
        padding: String, 'pre' or 'post':
            pad either before or after each sequence.
        truncating: String, 'pre' or 'post':
            remove values from sequences larger than
            `maxlen`, either at the beginning or at the end of the sequences.
        value: Float or String, padding value.

    # Returns
        x: Numpy array with shape `(len(sequences), maxlen)`

    # Raises
        ValueError: In case of invalid values for `truncating` or `padding`,
            or in case of invalid shape for a `sequences` entry.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    num_samples = len(sequences)

    lengths = []
    for x in sequences:
        try:
            lengths.append(len(x))
        except TypeError:
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(x))

    if maxlen is None:
        maxlen = np.max(lengths)

    # take the sample shape from the first non empty sequence
    # checking for consistency in the main loop below.
    sample_shape = tuple()
    for s in sequences:
        if len(s) > 0:
            sample_shape = np.asarray(s).shape[1:]
            break

    is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.unicode_)
    if isinstance(value, six.string_types) and dtype != object and not is_dtype_str:
        raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n"
                         "You should set `dtype=object` for variable length strings."
                         .format(dtype, type(value)))

    x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
    for idx, s in enumerate(sequences):
        if not len(s):
            continue  # empty list/array was found
        if truncating == 'pre':
            trunc = s[-maxlen:]
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError('Truncating type "%s" '
                             'not understood' % truncating)

        # check `trunc` has expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError('Shape of sample %s of sequence at position %s '
                             'is different from expected shape %s' %
                             (trunc.shape[1:], idx, sample_shape))

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return x


if __name__ == '__main__':
    torch_text_process()

text_cnn

import torch
from torch import nn


class TextCNN(nn.Module):
    def __init__(self, vocab_len, embedding_size, n_class):
        super().__init__()

        self.embedding = nn.Embedding(vocab_len, embedding_size)

        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[3, embedding_size])
        self.cnn2 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[4, embedding_size])
        self.cnn3 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[5, embedding_size])

        self.max_pool1 = nn.MaxPool1d(kernel_size=8)
        self.max_pool2 = nn.MaxPool1d(kernel_size=7)
        self.max_pool3 = nn.MaxPool1d(kernel_size=6)

        self.drop_out = nn.Dropout(0.2)
        self.full_connect = nn.Linear(300, n_class)

    def forward(self, x):
        embedding = self.embedding(x)
        embedding = embedding.unsqueeze(1)

        cnn1_out = self.cnn1(embedding).squeeze(-1)
        cnn2_out = self.cnn2(embedding).squeeze(-1)
        cnn3_out = self.cnn3(embedding).squeeze(-1)

        out1 = self.max_pool1(cnn1_out)
        out2 = self.max_pool2(cnn2_out)
        out3 = self.max_pool3(cnn3_out)

        out = torch.cat([out1, out2, out3], dim=1).squeeze(-1)

        out = self.drop_out(out)
        out = self.full_connect(out)
        # out = torch.softmax(out, dim=-1).squeeze(dim=-1)
        return out

train

import torch
from torch.utils.data import TensorDataset, DataLoader
from kbqa.torch_utils import TrainHandler
from kbqa.data_process import *
from kbqa.text_cnn import TextCNN

# df = pd.read_csv('question_classification.csv')
# print(df['label'].value_counts())


def load_data(batch_size=32):
    df = pd.read_csv('kbqa/question_classification.csv')
    train_df, eval_df = split_data(df)
    train_x = df['question']
    train_y = df['label']
    valid_x = eval_df['question']
    valid_y = eval_df['label']

    train_x = padding_seq(train_x.apply(seq2index))
    train_y = np.array(train_y)
    valid_x = padding_seq(valid_x.apply(seq2index))
    valid_y = np.array(valid_y)

    train_data_set = TensorDataset(torch.from_numpy(train_x),
                                   torch.from_numpy(train_y))
    valid_data_set = TensorDataset(torch.from_numpy(valid_x),
                                   torch.from_numpy(valid_y))
    train_data_loader = DataLoader(dataset=train_data_set, batch_size=batch_size, shuffle=True)
    valid_data_loader = DataLoader(dataset=valid_data_set, batch_size=batch_size, shuffle=True)

    return train_data_loader, valid_data_loader


train_loader, valid_loader = load_data(batch_size=64)

model = TextCNN(1289, 256, 3)# 原model = TextCNN(1141, 256, 3),1289根据vocat.txt行数
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)
model_path = 'text_cnn.p'
handler = TrainHandler(train_loader,valid_loader,model,
                       criterion,
                       optimizer,
                       model_path,
                       batch_size=32,
                       epochs=5,
                       scheduler=None,
                       gpu_num=0)
handler.train()

在这里插入图片描述

main

import torch
from kbqa.data_process import *
import ahocorasick
from py2neo import Graph
import re
import traceback


model = torch.load('kbqa/text_cnn.p', map_location=torch.device('cpu'))
model.eval()

graph = Graph("http://localhost:7474", auth=("neo4j", "qwer"))
company = graph.run('MATCH (n:company) RETURN n.name as name').to_ndarray()
person = graph.run('MATCH (n:person) RETURN n.name as name').to_ndarray()
relation = graph.run('MATCH ()-[r]-() RETURN distinct type(r)').to_ndarray()

ac_company = ahocorasick.Automaton()
ac_person = ahocorasick.Automaton()
ac_relation = ahocorasick.Automaton()

for key in enumerate(company):
    ac_company.add_word(key[1][0], key[1][0])
for key in enumerate(person):
    ac_person.add_word(key[1][0], key[1][0])
for key in enumerate(relation):
    ac_relation.add_word(key[1][0], key[1][0])
ac_relation.add_word('年龄', '年龄')
ac_relation.add_word('年纪', '年纪')
ac_relation.add_word('收入', '收入')
ac_relation.add_word('收益', '收益')

ac_company.make_automaton()
ac_person.make_automaton()
ac_relation.make_automaton()


def classification_predict(s):
    s = seq2index(s)
    s = torch.from_numpy(padding_seq([s])).long() #.cuda().long()
    out = model(s)
    out = out.cpu().data.numpy()
    print(out)
    return out.argmax(1)[0]


def entity_link(text):
    subject = []
    subject_type = None
    for end_index, original_value in ac_company.iter(text):
        start_index = end_index - len(original_value) + 1
        print('实体:', (start_index, end_index, original_value))
        assert text[start_index:start_index + len(original_value)] == original_value
        subject.append(original_value)
        subject_type = 'company'
    for end_index, original_value in ac_person.iter(text):
        start_index = end_index - len(original_value) + 1
        print('实体:', (start_index, end_index, original_value))
        assert text[start_index:start_index + len(original_value)] == original_value
        subject.append(original_value)
        subject_type = 'person'

    return subject[0], subject_type


def get_op(text):
    pattern = re.compile(r'\d+')
    num = pattern.findall(text)
    op = None
    if '大于' in text:
        op = '>'
    elif '小于' in text:
        op = '<'
    elif '等于' in text or '是' in text:
        op = '='
    return op, float(num[0])


def kbqa(text):
    print('*' * 100)
    cls = classification_predict(text)
    print('question type:', cls)
    res = ''

    if cls == 0:
        op, num = get_op(text)
        subject_type = ''
        attribute = ''
        for w in ['年龄', '年纪']:
            if w in text:
                subject_type = 'person'
                attribute = 'age'
                break
        for w in ['收入', '收益']:
            if w in text:
                subject_type = 'company'
                attribute = 'profit'
                break
        cypher = f'match (n:{subject_type}) where n.{attribute}{op}{num} return n.name'
        print(cypher)
        res = graph.run(cypher).to_ndarray()
    elif cls == 1:
        # 查询属性
        subject, subject_type = entity_link(text)
        predicate = ''
        for w in ['年龄', '年纪']:
            if w in text and subject_type == 'person':
                predicate = 'age'
                break
        for w in ['收入', '收益']:
            if w in text and subject_type == 'company':
                predicate = 'profit'
                break
        cypher = f'''match (n:{subject_type}) where n.name='{subject}' return n.{predicate}'''
        print(cypher)
        res = graph.run(cypher).to_ndarray()
    elif cls == 2:
        subject = ''
        for end_index, original_value in ac_company.iter(text):
            start_index = end_index - len(original_value) + 1
            print('公司实体:', (start_index, end_index, original_value))
            assert text[start_index:start_index + len(original_value)] == original_value
            subject = original_value
        predicate = []
        for end_index, original_value in ac_relation.iter(text):
            start_index = end_index - len(original_value) + 1
            print('关系:', (start_index, end_index, original_value))
            assert text[start_index:start_index + len(original_value)] == original_value
            predicate.append(original_value)
        for i, p in enumerate(predicate):
            cypher = f'''match (s:company)-[p:`{p}`]->(o) where s.name='{subject}' return o.name'''
            print(cypher)
            res = graph.run(cypher).to_ndarray()
            subject = res[0][0]
            if i == len(predicate) - 1:
                break
            new_index = text.index(p) + len(p)
            new_question = subject + str(text[new_index:])
            print('new question:', new_question)
            res = kbqa(new_question)
            break
    return res


if __name__ == '__main__':

    while 1:
        try:
            text = input('text:')
            res = kbqa(text)
            print(res)
        except:
            print(traceback.format_exc())

在这里插入图片描述

图谱问答实战小结

模型的整体结构:ac自动机+找实体+多跳问答

ps:这里实体没有多个,用不到实体消歧,这里我们使用的是匹配的方法


学习的参考资料:
七月在线NLP高级班

代码参考:
https://github.com/terrifyzhao/neo4j_graph

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1595167.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深度学习图像处理基础工具——opencv 实战信用卡数字识别

任务 信用卡数字识别 穿插之前学的知识点 形态学操作 模板匹配 等 总体流程与方法 1.有一个模板 2 用轮廓检测把模板中数字拿出来 外接矩形&#xff08;模板和输入图像的大小要一致 &#xff09;3 一系列预处理操作 问题的解决思路 1.分析准备&#xff1a;准备模板&#…

libbpf-bootstrap库的代码结构介绍(用户层接口介绍),编译链接语句详细介绍,.skel.h文件介绍+示例,bpf程序的后续处理+文件关系总结

目录 libbpf-bootstrap 代码结构介绍 用户层函数 编译 查看 生成内核层的.o文件 第一模块 第二模块 第三模块 第四模块 第五模块 生成辅助文件(.skel.h) 介绍 示例 生成代码层的.o文件 第一模块 第二模块 第三模块 链接出可执行文件 后续总结 libbpf-bootst…

舒欣上门预约系统源码-按摩预约/家政预约全行业适用-小程序/h5/app

上门预约或者到店预约均可&#xff0c;家政&#xff0c;按摩&#xff0c;等等上门类行业均可适用。&#xff08;后台的技师及前台技师这两个字是可以更改的&#xff0c;例如改成家政老师&#xff0c;保洁&#xff0c;等等&#xff09; 视频教程是演示搭建的小程序端&#xff0c…

stm32报错问题集锦

PS&#xff1a;本文负责记录本人日常遇到的报错问题&#xff0c;以及问题描述、原因以及解决办法等&#xff0c;解决办法百分百亲测有效。本篇会不定期更新&#xff0c;更新频率就看遇到的问题多不多了 更换工程芯片型号 问题描述 例程最开始用的芯片型号是STM32F103VE&#…

Abstract Factory抽象工厂模式详解

模式定义 提供一个创建一系列相关或互相依赖对象的接口&#xff0c;而无需指定它们具体的类。 代码示例 public class AbstractFactoryTest {public static void main(String[] args) {IDatabaseUtils iDatabaseUtils new OracleDataBaseUtils();IConnection connection …

微信登录功能-保姆级教学

目录 一、使用组件 二、登录功能 2.1 步骤 2.2 首先找到网页权限 复制demo 代码 这里我们需要修改两个参数 三、前端代码 3.1 api 里weiXinApi.ts 3.2 api里的 index.ts 3.3 pinia.ts 3.4 My.vue 四、后端代码 4.1 WeiXinController 4.2 Access_Token.Java 4.3 We…

SHARE 203S PRO:倾斜摄影相机在地灾救援中的应用

在地质灾害的紧急关头&#xff0c;救援队伍面临的首要任务是迅速而准确地掌握灾区的地理信息。这时&#xff0c;倾斜摄影相机成为了救援测绘的利器。SHARE 203S PRO&#xff0c;这款由深圳赛尔智控科技有限公司研发的五镜头倾斜摄影相机&#xff0c;以其卓越的性能和功能&#…

OpenHarmony实战开发-FaultLoggerd组件。

简介 Faultloggerd部件是OpenHarmony中C/C运行时崩溃临时日志的生成及管理模块。面向基于 Rust 开发的部件&#xff0c;Faultloggerd 提供了Rust Panic故障日志生成能力。系统开发者可以在预设的路径下找到故障日志&#xff0c;定位相关问题。 架构 Native InnerKits 接口Sig…

新手教程 | 2024年最新Vmware17安装教程及许可证(详细图文)

目录 前言&#xff1a; 一、VMware Workstation 17 Pro 简介 二、下载安装&#xff08;以Windows为例&#xff09; 三、许可证 四、检查是否安装成功 前言&#xff1a; 重新装电脑后&#xff0c;安装虚拟机 一、VMware Workstation 17 Pro 简介 VMware Workstation 17 …

问题整理【2024-04-08】

一、关于MYSQL死锁问题 1.1 源由 一次上线过程中&#xff0c;遇到了MySQL死锁的问题…… 1.2 分析 1.2.1 前置知识 ​ 首先要知道&#xff0c;MySQL是一个多线程的数据库管理系统&#xff0c;查询是通常并发执行的&#xff0c;可以同时处理多个查询请求&#xff0c;并且My…

文献阅读:LESS: Selecting Influential Data for Targeted Instruction Tuning

文献阅读&#xff1a;LESS: Selecting Influential Data for Targeted Instruction Tuning 1. 文章简介2. 方法介绍 1. Overview2. 原理说明 1. SGD上的定义2. Adam上的定义 3. 具体实现 1. Overview1. LoRA使用2. 数据选择3. LESS-T 3. 实验考察 & 结论 1. 实验设计2. 主…

SPI 设备驱动编写流程:创建SPI节点以及SPI设备节点(在设备树文件中)

一. 简介 SPI 驱动框架和 I2C 很类似&#xff0c;都分为主机控制器驱动和设备驱动。 SPI主机控制器的驱动一般是芯片半导体厂商写好了&#xff0c;我们要编写的是SPI设备驱动代码。 本文开始来学习SPI设备驱动的编写流程&#xff08;前提是支持设备树的情况&#xff09;。 二…

Google最新论文: 复杂的 Prompt 如何更好的调试?

本文介绍了Sequence Salience&#xff0c;这是一个专为调试复杂的大模型提示而设计的系统。该系统利用广泛使用的显著性方法&#xff0c;支持文本分类和单标记预测&#xff0c;并将其扩展到可处理长文本的调试系统。现有的工具往往不足以处理长文本或复杂提示的调试需求。尽管存…

bugku-web-文件上传

提示他的名字是margin&#xff0c;给他一个图片文件&#xff0c;不要php文件 上传一句话木马的图片 抓包&#xff0c;后缀改为php 提示无效文件&#xff0c;即后台还会检测一次后缀 测试后台系统 为linux系统 开始绕过 截断绕过 上传成功&#xff0c;但是会变为jpg 开始分析…

黑马苍穹外卖--再来一单(stream流转换、赋值与收集映射)

1.首先明确一下业务规则: 业务规则&#xff1a; 再来一单就是将原订单中的商品重新加入到购物车中 2.产品页面原型和开发接口文档 3.业务层逻辑代码开发 3.1 查询方向 我们要明确的是: 再来一单就是将原订单中的商品重新加入到购物车中------直接把商品加入到购物车&#…

Jmeter三个常用组件

Jmeter三个常用组件 一、线程组二、 HTTP请求三、查看结果树 线程组&#xff1a;jmeter是基于线程来运行的&#xff0c;线程组主要用来管理线程的数量&#xff0c;线程的执行策略。 HTTP请求&#xff1a;HTTP请求是jmeter接口测试的核心部分&#xff0c;主要使用HTTP取样器来发…

微信小程序实现预约生成二维码

业务需求&#xff1a;点击预约按钮即可生成二维码凭码入校参观~ 一.创建页面 如下是博主自己写的wxml&#xff1a; <swiper indicator-dots indicator-color"white" indicator-active-color"blue" autoplay interval"2000" circular > &…

pyqt和opencv结合01:读取图像、显示

在这里插入图片描述 1 、opencv读取图像用于pyqt显示 # image cv2.imread(file_path)image cv2.cvtColor(image, cv2.COLOR_BGR2RGB)# 将图像转换为 Qt 可接受的格式height, width, channel image.shapebytes_per_line 3 * widthq_image QImage(image.data, width, hei…

2440栈的实现类型、b系列指令、汇编掉用c、c调用汇编、切换工作模式、初始化异常向量表、中断处理、

我要成为嵌入式高手之4月11日51ARM第六天&#xff01;&#xff01; ———————————————————————————— b指令 标签&#xff1a;表示这条指令的名称&#xff0c;可跳转至标签 b指令&#xff1a;相当于goto&#xff0c;可随意跳转 如&#xff1a;fini…

Finetuning vs. Prompting:大语言模型两种使用方式

目录 前言1. 对于大型语言模型的两种不同期待2. Finetune(专才)3. Prompt(通才)3.1 In-context Learning3.2 Instruction-tuning3.3 Chain of Thought(COT) Prompting3.4 用机器来找Prompt 总结参考 前言 这里和大家分享下关于大语言模型的两种使用方式&#xff0c;一种是 Fine…