Huggingface 超详细介绍

news2024/11/27 4:30:21

Hugging face 起初是一家总部位于纽约的聊天机器人初创服务商,他们本来打算创业做聊天机器人,然后在github上开源了一个Transformers库,虽然聊天机器人业务没搞起来,但是他们的这个库在机器学习社区迅速大火起来。目前已经共享了超100,000个预训练模型,10,000个数据集,变成了机器学习界的github。

其之所以能够获得如此巨大的成功,一方面是让我们这些甲方企业的小白,尤其是入门者也能快速用得上科研大牛们训练出的超牛模型。另一方面是,这种特别开放的文化和态度,以及利他利己的精神特别吸引人。huggingface上面很多业界大牛也在使用和提交新模型,这样我们就是站在大牛们的肩膀上工作,而不是从头开始,当然我们也没有大牛那么多的计算资源和数据集。

在国内huggingface也是应用非常广泛,一些开源框架本质上就是调用transfomer上的模型进行微调(当然也有很多大牛在默默提供模型和数据集)。很多nlp工程师招聘的条目上也明摆着要求熟悉huggingface transformer库的使用。简单介绍了他们多么牛逼之后,我们看看huggingface怎么玩吧。因为他既提供了数据集,又提供了模型让你随便调用下载,因此入门非常简单。你甚至不需要知道什么是GPT,BERT就可以用他的模型了(当然看看我写的BERT简介还是十分有必要的)。下面初步介绍下huggingface里面都有什么,以及怎么调用BERT模型做个简单的任务。

1.我们能从huggingface获得什么

内容

huggingface的官方网站:http://www.huggingface.co. 在这里主要有以下大家需要的资源。

Datasets:数据集,以及数据集的下载地址
Models:各个预训练模型
course:免费的nlp课程,可惜都是英文的
docs:文档
下图是一张来自于官网的Transformers发展谱系图,短短3,4年就发展出了庞大家族。如果你不是学术界的代表,你无需详细搞懂他们的原理,就可以直接使用这些科研界最先进的模型,下一章我们来介绍如何简单的拿这些模型进行nlp任务。
在这里插入图片描述

一个有趣的现象

在NLP领域,在hugging face上面数据集和预训练模型的数量以英语为最为众多,远超其他国家的总和(见下图)。就预训练模型来说,排名第二的是汉语。就数据集来说,汉语远远少于英语,也少于法,德,西班牙等语言,甚至少于阿拉伯语和波兰语。这严重跟我想象中的AI超级大国及其不匹配。我想一方面因为数据集的积累都需要很多年,中文常用的(PKU,MSRA)数据集都是十几年前留下的,而我们AI和经济的崛起也不过是最近十年的事情。另一方面,数据集都是大价钱整理出来的,而且可以不断的利用他产生新的模型,这样的大杀器怎可随意公布。发布预训练模型可以带来论文,数据集可啥也带不来,基本上中日韩等的数据集明显偏少。
在这里插入图片描述

2.小试牛刀

接下来的内容参考了下面的内容,初步带你入门huggingface,简单了解如何调用BERT模型:

dxzmpk写的教程 https://www.cnblogs.com/dongxiong/p/12763923.html
huggingface官方教程:https://huggingface.co/docs/transformers/model_doc/bert
有必要阅读的论文:
https://arxiv.org/abs/1706.03762

https://arxiv.org/abs/1706.03762

如果看英文论文太费劲,可以参考我写的学习笔记:

attention与sef-attention介绍

transformer模型结构介绍

Bert简单介绍

安装

transformers库github地址在:https://github.com/huggingface/transformers

安装方法,在命令行执行(conda的话在anaconda propmt):

pip install transformers # 安装最新的版本
pip install transformers == 4.0 # 安装指定版本
#如果你是conda的话
conda install -c huggingface transformers  # 4.0以后的版本才会有

测试下安装是否成功

from transformers import pipeline  # 引入一个pipeline试试看,如果不报错说明安装成功、
# 因为NLP通常是多个任务顺序而成,所以通常使用pipeline,流水线工作

模型的组成

一般transformer模型有三个部分组成:1.tokennizer,2.Model,3.Post processing。如下图所示,图中第二层和第三层是每个部件的输入/输出以及具体的案例。我们可以看到三个部分的具体作用:Tokenizer就是把输入的文本做切分,然后变成向量,Model负责根据输入的变量提取语义信息,输出logits;最后Post Processing根据模型输出的语义信息,执行具体的nlp任务,比如情感分析,文本自动打标签等;可见Model是其中的核心部分,Model又可以分为三种模型,针对不同的NLP任务,需要选取不同的模型类型:Encoder模型(如Bert,常用于句子分类、命名实体识别(以及更普遍的单词分类)和抽取式问答。),Decoder模型(如GPT,GPT2,常用于文本生成),以及sequence2sequence模型(如BART,常用于摘要,翻译,生成性问答等)
在这里插入图片描述

说了很多理论的内容,我们可以在huggingface的官网,随便找一个预训练模型具体看看包含哪些文件。在这里我举了一个中文的例子”Bert-base-Chinese“(中文还有其他很优秀的预训练模型,比如哈工大和科大讯飞提供的:roberta-wwm-ext,百度提供的:ernie)。这个模型据说是根据中文维基百科内容训练的,因此语义内容可能不是足够丰富,毕竟其他大佬们提供的数据更多。
在这里插入图片描述

readme一般是模型的介绍,包括使用方法都会放到里面,不介绍了。其他最重要的组成部分,大概分为三类:

1. config

控制模型的名称、最终输出的样式、隐藏层宽度和深度、激活函数的类别等。这些参数我补齐了说明,对于初学者来说,大家一般不需要调整。这些参数都可以通过configuration类更改。

{
  "architectures": [
    "BertForMaskedLM"                           # 模型的名称
  ],
  "attention_probs_dropout_prob": 0.1,          # 注意力机制的 dropout,默认为0.1
  "directionality": "bidi",                     # 文字编码方向采用bidi算法
  "hidden_act": "gelu",                         # 编码器内激活函数,默认"gelu",还可为"relu"、"swish"或 "gelu_new"
  "hidden_dropout_prob": 0.1,                   # 词嵌入层或编码器的 dropout,默认为0.1
  "hidden_size": 768,                           # 编码器内隐藏层神经元数量,默认768
  "initializer_range": 0.02,                    # 神经元权重的标准差,默认为0.02
  "intermediate_size": 3072,                    # 编码器内全连接层的输入维度,默认3072
  "layer_norm_eps": 1e-12,                      # layer normalization 的 epsilon 值,默认为 1e-12
  "max_position_embeddings": 512,               # 模型使用的最大序列长度,默认为512
  "model_type": "bert",                         # 模型类型是bert
  "num_attention_heads": 12,                    # 编码器内注意力头数,默认12
  "num_hidden_layers": 12,                      # 编码器内隐藏层层数,默认12
  "pad_token_id": 0,                            # pad_token_id 未找到相关解释
  "pooler_fc_size": 768,                        # 下面应该是pooler层的参数,本质是个全连接层,作为分类器解决序列级的NLP任务
  "pooler_num_attention_heads": 12,             # pooler层注意力头,默认12
  "pooler_num_fc_layers": 3,                    # pooler 连接层数,默认3
  "pooler_size_per_head": 128,                  # 每个注意力头的size
  "pooler_type": "first_token_transform",       # pooler层类型,网上介绍很少
  "type_vocab_size": 2,                         # 词汇表类别,默认为2
  "vocab_size": 21128                           # 词汇数,bert默认30522,这是因为bert以中文字为单位进入输入
}

2. tokenizer(包含三个文件)

这些文件是tokenizer类生成的,或者处理的,只是处理文本,不涉及任何向量操作。

vocab.txt是词典文件(打开就是单个字符,我这里用的是bert-base-chinsese,可以看到里面都是保留符号和单个汉字索引,字符)

tokenizer.json和config是分词的配置文件,根据vocab信息和你的设置更新,里面把vocab都按顺序做了索引,将来可以根据编码生成one-hot向量,然后跟embeding训练的矩阵相乘,就可以得到该字符的向量。下图是tokenizer.json内容。
在这里插入图片描述

模型文件一般是tensor flow(上图中的h5文件)和py-torch(上图中的bin文件)的都有,因为作者只是单纯的在学习torch,所以以后的文章都只介绍torch。

3. BERT模型的使用

介绍完了模型库都有哪些内容,下面我们可以导入模型试一试怎么使用啦。

3.1 导入模型

利用官方的hub导入模型;下面导入了一个BertModel;在官方的教程中推进使用pipeline导入模型的方法;

import torch
from transformers import BertModel, BertTokenizer, BertConfig
# 首先要import进来
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
config = BertConfig.from_pretrained('bert-base-chinese')
config.update({'output_hidden_states':True}) # 这里直接更改模型配置
model = BertModel.from_pretrained("bert-base-chinese",config=config)

利用pipeline的方式也是一样的可以导入模型哈,方式如下:

from transformers import AutoModel
checkpoint = "bert-base-chinese"
model = AutoModel.from_pretrained(checkpoint)

因为huggingface官网在国外,自动下载可能比较费劲,笔者在公司下载速度还是非常快的。

默认下载地址在这里:

1)使用 Windows 模型保存的路径在 C:\Users[用户名].cache\torch\transformers
目录下,根据模型的不同下载的东西也不相同 2)使用 Linux 模型保存的路径在 ~/.cache/torch/transformers/
目录下

如果自动下载总是中断的话,可以考虑用国内的源,或者手工下载之后指定位置。(huggingface官网,选择models菜单,然后搜索自己想要的模型,然后把里面的文件下载下来,其中体积较大的有tf的有torch的,根据自己需要下载)。

import transformers
MODEL_PATH = r"D:\\test\\bert-base-chinese"
# 导入模型
tokenizer = transformers.BertTokenizer.from_pretrained(r"D:\\test\\bert-base-chinese\\bert-base-chinese-vocab.txt") 
# 导入配置文件
model_config = transformers.BertConfig.from_pretrained(MODEL_PATH)
# 修改配置
model_config.output_hidden_states = True
model_config.output_attentions = True
# 通过配置和路径导入模型
model = transformers.BertModel.from_pretrained(MODEL_PATH,config = model_config)
3.2 使用模型

上一步我们已经把模型加载进来了,在这里,尝试一下这个模型怎么样,看看能不能把相关的语义带入进来。我们之前文章介绍了bert的两个任务(MLM和NSP),这一节,我们一起测试这两个任务的效果。首先我们逐步来看看BERT每个部分的输出都是什么,我们可以看看哪些好玩的东西。

tokenizer
上面代码可以看到他实例化了BertTokenizer类,它是基于WordPiece方法的,先看看他有哪些参数:

( vocab_file,do_lower_case = True,do_basic_tokenize = True,never_split
= None,unk_token = ‘[UNK]’,sep_token = ‘[SEP]’,pad_token = ‘[PAD]’,cls_token = ‘[CLS]’,mask_token =
‘[MASK]’,tokenize_chinese_chars = True,strip_accents = None,**kwargs )

vocab_file:这里是放置词典的地址,do_lower_case,是否都变成小写,默认是True哦,do_basic_tokenize,做wordpiece之前是否要做basic tokenize;下面的都是一些关键字的确认。还有就是是否分开中文字符,因为bert是面向英文的所有有这些设置,一般不用改,当然我们这里的案例也只是读取了预训练模型。

我们来个小案例看看,分出来的字符是什么样子的。示例如下,可以看出BERT对中文是字符级别的分词,对待英文是到sub-word级别的:

# 上文的示例代码已经实例话了,这里不重复了;
print(tokenizer.encode("生活的真谛是美和爱"))  # 对于单个句子编码
print(tokenizer.encode_plus("生活的真谛是美和爱","说的太好了")) # 对于一组句子编码
# 输出结果如下:
[101, 4495, 3833, 4638, 4696, 6465, 3221, 5401, 1469, 4263, 102]
{'input_ids': [101, 4495, 3833, 4638, 4696, 6465, 3221, 5401, 1469, 4263, 102, 6432, 4638, 1922, 1962, 749, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

# 也可以直接这样用
sentences = ['网络安全开发分为三个层级',
             '车辆系统层级网络安全开发',
             '车辆功能层级网络安全开发',
             '车辆零部件层级网络安全开发',
             '测试团队根据车辆网络安全目标制定测试技术要求及测试计划',
             '测试团队在网络安全团队的支持下,完成确认测试并编制测试报告',
             '在车辆确认结果的基础上,基于合理的理由,确认在设计和开发阶段识别出的所有风险均已被接受',]
test1 = tokenizer(sentences)

print(test1)  # 对列表encoder
print(tokenizer("网络安全开发分为三个层级"))  # 对单个句子encoder

我们来看一下这个输出:对于单个句子是上面那种,他只输出句子tok之后的id,我们注意到已经加好[CLS],[SEP]等标识符了;(查询tokenizer可知,101是[CLS],102是[SEP])除了input_ids之外,还自动编码了token_type_ids,attention_mask

当然除了这种直接调用模型之外,还可以利用pipeline方法来

model

model实例化了BertModel类,除了初始的 Bert、GPT 等基本模型,针对不同的下游任务,定义了 BertForQuestionAnswering,BertForMultiChoice,BertForNextSentencePrediction 以及 BertForSequenceClassification 等下游任务模型。模型导出时将生成 config.json 和 pytorch_model.bin 参数文件,这两个文件前面已将介绍了,一个是配置文件一个是torch训练后save的文件。那下面我们来看看这个怎么使用吧。因为中文是字符级的tok,所以做MLM任务不是很理想,所以下面我用英文的base模型示例一个MLM任务;

from transformers import pipeline
# 运行该段代码要保障你的电脑能够上网,会自动下载预训练模型,大概420M
unmasker = pipeline("fill-mask",model = "bert-base-uncased")  # 这里引入了一个任务叫fill-mask,该任务使用了base的bert模型
unmasker("The goal of life is [MASK].", top_k=5) # 输出mask的指,对应排名最前面的5个,也可以设置其他数字
# 输出结果如下,似乎都不怎么有效哈。
[{'score': 0.10933303833007812,
  'token': 2166,
  'token_str': 'life',
  'sequence': 'the goal of life is life.'},
 {'score': 0.03941883146762848,
  'token': 7691,
  'token_str': 'survival',
  'sequence': 'the goal of life is survival.'},
 {'score': 0.032930608838796616,
  'token': 2293,
  'token_str': 'love',
  'sequence': 'the goal of life is love.'},
 {'score': 0.030096106231212616,
  'token': 4071,
  'token_str': 'freedom',
  'sequence': 'the goal of life is freedom.'},
 {'score': 0.024967126548290253,
  'token': 17839,
  'token_str': 'simplicity',
  'sequence': 'the goal of life is simplicity.'}]

后处理

后处理通常要根据你选择的模型来确定,一般模型的输出是logits,其包含我们需要的语义信息,然后后处理是经过一个激活函数输出我们可以使用的向量,比如softmax层做二分类,会输出对应两个标签的概率值,然后就可以轻松转化为我们需要的信息啦。

3.3 再来试一试我们情感分析的任务;

BERT论文中介绍了自己在推理,问答等多个任务中的提升,在这里我们只介绍一个简单的情感分析任务。

数据集
在我的第一篇笔记里,基于双向的LSTM搭建了一个情感分析的示例,当时使用的是IMDB电影评论(一共有5万条,正负面评论各25000条)。代码已经对数据集进行了封装和整理,在这里就不重复介绍了。

下游任务训练
情感分析任务在huggingface中称之为:Text Classification。根据huggingface任务方面的定义,在这里我们延展介绍一下那些常见的任务是属于文本分类的:

NLI(Natural Language Infenrence),或称之为或Recognizing Textual Entailment(RTE)蕴含文本识别。针对这类问题有一系列的数据集,以及基于这些数据集训练出来的模型,常见的有:
QNLI,QNLI是从另一个权威的QA数据集The Stanford Question Answering Dataset(斯坦福问答数据集, SQuAD 1.0)转换而来的。SQuAD 1.0是由问题-段落对组成的问答数据集,其中段落来自Wiki,段落中的一个句子包含问题的答案。通过将问题和上下文(即维基百科段落)中的每一句话进行组合,并过滤掉词汇重叠比较低的句子对就得到了QNLI中的句子对。本质是一个判断蕴含还是不蕴含的二分类问题;

MNLI,多类型自然语言推理数据库,是一个自然语言推断任务,数据集是通过众包方式对句子对进行文本蕴含标注的集合。给定前提语句和假设语句,任务是预测前提语句是否包含假设(entailment)、与假设矛盾(contradiction)或者两者都不(中立,neutral)。本质是一个三分类问题:判断是有前提,无前提,还是中立的;

等等还有其他的数据集,包括一些对抗性的;

情感分析(Sentiment Analysis):本质是一个二分类的问题,给定一个文本判断是正面的(POS),还是负面的(NEG)
Quora Question Pairs:给出两个问题,判断这两个问题的含义是否一致;属于一个二分类的问题;他的数据集是quroa问题队,也被收录在GLUE内。
语法校核-Grammatical Correctness:评估一个句子的语法可接受性,二分类任务,结果是可接受或者不可接受;常用的数据集是: Corpus of Linguistic Acceptability (CoLA)
说了那么多题外话,我们回过来来看我们本次的任务,他是Text classification中的情感分析任务,是一个二分类的任务,给出一段话,从标签{“POS”,‘NEG’}中选择一个最合适的。

沿用之前的数据处理代码,我们这里只更改模型;

好了废话不多说了,上代码,大家去看详细的代码注释吧,由于设置多个epoch和较大的batchsize,我的电脑完全带动不起来,大家放到gpu计算,记得to device到GPU上。拷贝下来直接就能用。

# _*_ coding:utf-8 _*_
# 利用深度学习做情感分析,基于Imdb 的50000个电影评论数据进行;

import torch
from torch.utils.data import DataLoader,Dataset
import os
import re
from random import sample
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import BertModel, BertTokenizer
from tqdm import tqdm
 
# 路径需要根据情况修改,要看你把数据下载到哪里了
# 数据下载地址在斯坦福官网,网上搜索就有
data_base_path = r"./imdb_test/aclImdb"

# 这个里面是存储你训练出来的模型的,现在是空的
model_path = r"./imdb_test/aclImdb/mode"
        
#1. 准备dataset,这里写了一个数据读取的类,并把数据按照不同的需要进行了分类;
class ImdbDataset(Dataset):
    def __init__(self,mode,testNumber=10000,validNumber=5000):

        # 在这里我做了设置,把数据集分成三种形式,可以选择 “train”默认返回全量50000个数据,“test”默认随机返回10000个数据,
        # 如果是选择“valid”模式,随机返回相应数据
        super(ImdbDataset,self).__init__()

        # 读取所有的训练文件夹名称
        text_path =  [os.path.join(data_base_path,i)  for i in ["test/neg","test/pos"]]
        text_path.extend([os.path.join(data_base_path,i)  for i in ["train/neg","train/pos"]])

        if mode=="train":
            self.total_file_path_list = []
            # 获取训练的全量数据,因为50000个好像也不算大,就没设置返回量,后续做sentence的时候再做处理
            for i in text_path:
                self.total_file_path_list.extend([os.path.join(i,j) for j in os.listdir(i)])
        if mode=="test":
            self.total_file_path_list = []
            # 获取测试数据集,默认10000个数据
            for i in text_path:
                self.total_file_path_list.extend([os.path.join(i,j) for j in os.listdir(i)])
            self.total_file_path_list=sample(self.total_file_path_list,testNumber)
       
        if mode=="valid":
            self.total_file_path_list = []
            # 获取验证数据集,默认5000个数据集
            for i in text_path:
                self.total_file_path_list.extend([os.path.join(i,j) for j in os.listdir(i)])
            self.total_file_path_list=sample(self.total_file_path_list,validNumber)
   
    def tokenize(self,text):
    
        # 具体要过滤掉哪些字符要看你的文本质量如何
       
        # 这里定义了一个过滤器,主要是去掉一些没用的无意义字符,标点符号,html字符啥的
        fileters = ['!','"','#','$','%','&','\(','\)','\*','\+',',','-','\.','/',':',';','<','=','>','\?','@'
            ,'\[','\\','\]','^','_','`','\{','\|','\}','~','\t','\n','\x97','\x96','”','“',]
        # sub方法是替换
        text = re.sub("<.*?>"," ",text,flags=re.S)	# 去掉<...>中间的内容,主要是文本内容中存在<br/>等内容
        text = re.sub("|".join(fileters)," ",text,flags=re.S)	# 替换掉特殊字符,'|'是把所有要匹配的特殊字符连在一起
        return text	# 返回文本

    def __getitem__(self, idx):
        cur_path = self.total_file_path_list[idx]
		# 返回path最后的文件名。如果path以/或\结尾,那么就会返回空值。即os.path.split(path)的第二个元素。
        # cur_filename返回的是如:“0_3.txt”的文件名
        cur_filename = os.path.basename(cur_path)
        # 标题的形式是:3_4.txt	前面的3是索引,后面的4是分类
        # 如果是小于等于5分的,是负面评论,labei给值维1,否则就是1
        labels = []
        sentences = []
        if int(cur_filename.split("_")[-1].split(".")[0]) <= 5 :
            label = 0
        else:
            label = 1
        # temp.append([label])
        labels.append(label)
        text = self.tokenize(open(cur_path,encoding='UTF-8').read().strip()) #处理文本中的奇怪符号
        sentences.append(text)
        # 可见我们这里返回了一个list,这个list的第一个值是标签0或者1,第二个值是这句话;
        return sentences,labels
 
    def __len__(self):
        return len(self.total_file_path_list)
    
# 2. 这里开始利用huggingface搭建网络模型
# 这个类继承再nn.module,后续再详细介绍这个模块
# 
class BertClassificationModel(nn.Module):
    def __init__(self,hidden_size=768):
        super(BertClassificationModel, self).__init__()
        # 这里用了一个简化版本的bert
        model_name = 'distilbert-base-uncased'

        # 读取分词器
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
        
        # 读取预训练模型
        self.bert = BertModel.from_pretrained(pretrained_model_name_or_path=model_name)

        for p in self.bert.parameters(): # 冻结bert参数
                p.requires_grad = False
        self.fc = nn.Linear(hidden_size,2)

    def forward(self, batch_sentences):   # [batch_size,1]
        sentences_tokenizer = self.tokenizer(batch_sentences,
                                             truncation=True,
                                             padding=True,
                                             max_length=512,
                                             add_special_tokens=True)
        input_ids=torch.tensor(sentences_tokenizer['input_ids']) # 变量
        attention_mask=torch.tensor(sentences_tokenizer['attention_mask']) # 变量
        bert_out=self.bert(input_ids=input_ids,attention_mask=attention_mask) # 模型

        last_hidden_state =bert_out[0] # [batch_size, sequence_length, hidden_size] # 变量
        bert_cls_hidden_state=last_hidden_state[:,0,:] # 变量
        fc_out=self.fc(bert_cls_hidden_state) # 模型
        return fc_out

# 3. 程序入口,模型也搞完啦,我们可以开始训练,并验证模型的可用性

def main():

    testNumber = 10000    # 多少个数据参与训练模型
    validNumber = 100   # 多少个数据参与验证
    batchsize = 250  # 定义每次放多少个数据参加训练
    
    trainDatas = ImdbDataset(mode="test",testNumber=testNumber) # 加载训练集,全量加载,考虑到我的破机器,先加载个100试试吧
    validDatas = ImdbDataset(mode="valid",validNumber=validNumber) # 加载训练集

    train_loader = torch.utils.data.DataLoader(trainDatas, batch_size=batchsize, shuffle=False)#遍历train_dataloader 每次返回batch_size条数据

    val_loader = torch.utils.data.DataLoader(validDatas, batch_size=batchsize, shuffle=False)

    # 这里搭建训练循环,输出训练结果

    epoch_num = 1  # 设置循环多少次训练,可根据模型计算情况做调整,如果模型陷入了局部最优,那么循环多少次也没啥用

    print('training...(约1 hour(CPU))')
    
    # 初始化模型
    model=BertClassificationModel()
  
    optimizer = optim.AdamW(model.parameters(), lr=5e-5) # 首先定义优化器,这里用的AdamW,lr是学习率,因为bert用的就是这个

    # 这里是定义损失函数,交叉熵损失函数比较常用解决分类问题
    # 依据你解决什么问题,选择什么样的损失函数
    criterion = nn.CrossEntropyLoss()
    
    print("模型数据已经加载完成,现在开始模型训练。")
    for epoch in range(epoch_num):
        for i, (data,labels) in enumerate(train_loader, 0):

            output = model(data[0])
            optimizer.zero_grad()  # 梯度清0
            loss = criterion(output, labels[0])  # 计算误差
            loss.backward()  # 反向传播
            optimizer.step()  # 更新参数

            # 打印一下每一次数据扔进去学习的进展
            print('batch:%d loss:%.5f' % (i, loss.item()))

        # 打印一下每个epoch的深度学习的进展i
        print('epoch:%d loss:%.5f' % (epoch, loss.item()))
    
    #下面开始测试模型是不是好用哈
    print('testing...(约2000秒(CPU))')

    # 这里载入验证模型,他把数据放进去拿输出和输入比较,然后除以总数计算准确率
    # 鉴于这个模型非常简单,就只用了准确率这一个参数,没有考虑混淆矩阵这些
    num = 0
    model.eval()  # 不启用 BatchNormalization 和 Dropout,保证BN和dropout不发生变化,主要是在测试场景下使用;
    for j, (data,labels) in enumerate(val_loader, 0):

        output = model(data[0])
        # print(output)
        out = output.argmax(dim=1)
        # print(out)
        # print(labels[0])
        num += (out == labels[0]).sum().item()
        # total += len(labels)
    print('Accuracy:', num / validNumber)

if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1254815.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

windows的bat文件(学习笔记)

简介 通过windows的cmd执行的批处理&#xff0c;扩展名可以是.bat或.cmd&#xff08;类似linux的shell脚本&#xff09; 所有语句符号不区分大小写 帮助提示信息&#xff1a;命令 /? 1 基本语法 (1) 注释&#xff1a;rem 注释文本不执行 (2) 关闭盘符输出&#xff1a;e…

【软件测试】“我“做了一年的功能点点点测试,感觉在浪费时间...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 发现人们对测试非…

4/5G语音实现方案

今天又是学习充实的一天&#xff0c;今天我们来学习4G和5G语音实现方案的介绍&#xff0c;VOLITE通信流程是怎么样子的&#xff0c;和之前学的TCP协议有什么联系吗&#xff1f;今天我们换个角度来学习通信的流程~ 目录 2G/3G的电话和上网 4G语音实现方案 4G语音的三种方式 …

“BMP转PNG一键转换,批量处理图片,迈入高效图片管理新时代“

你是否曾经为了转换图片格式而烦恼&#xff1f;是否曾经因为一张一张地手动转换而感到无奈&#xff1f;现在&#xff0c;我们的全新工具将为你解决这些问题&#xff0c;开启高效图片管理新时代&#xff01; 首先&#xff0c;我们进入首助编辑高手主页面&#xff0c;会看到有多种…

1、nmap常用命令

文章目录 1. 主机存活探测2. 常见端口扫描、服务版本探测、服务器版本识别3. 全端口&#xff08;TCP/UDP&#xff09;扫描4. 最详细的端口扫描5. 三种TCP扫描方式&#xff08;1&#xff09;TCP connect 扫描&#xff08;2&#xff09;TCP SYN扫描&#xff08;3&#xff09;TCP …

Python自动化测试学习路线【进阶必看】

软件自动化测试的学习步骤 大概步骤如下&#xff1a; 1. 做好手工测试&#xff08;了解各种测试的知识&#xff09;-> 2. 学习编程语言-> 3. 学习Web基础&#xff08;HTML,HTTP,CSS,DOM,Javascript&#xff09;或者 学习Winform -> 4. 学习自动化测试工具 ->5.…

老师组织课外活动的好处有哪些

亲爱的小伙伴们&#xff0c;不知道你们有没有注意到&#xff0c;老师除了在课堂上教学之外&#xff0c;还会在课外组织各种各样的活动呢&#xff1f;这些活动不仅好玩&#xff0c;而且对我们有很多好处哦&#xff01;今天我就来给大家分享一下老师组织课外活动的好处吧&#xf…

目录树自动生成器 golang+fyne

go tree 代码实现请看 gitee 仓库链接 有很多生成目录树的工具&#xff0c;比如windows自带的tree命令&#xff0c;nodejs的treer&#xff0c;tree-cli等等。这些工具都很成熟、很好用&#xff0c;有较完善的功能。 但是&#xff0c;这些工具全部是命令式的&#xff0c;如果…

Java中wait()方法在synchronized方法中调用的奥秘

作为一名Java程序员&#xff0c;我们深知synchronized关键字和wait()方法在多线程编程中的重要性。 在本文中&#xff0c;我们将探讨为什么wait()方法需要在synchronized方法中调用&#xff0c;以及它们是如何协同工作的。 首先&#xff0c;让我们了解一下synchronized关键字和…

嵌入式硬件电路·电平

目录 1. 电平的概念 1.1 高电平 1.2 低电平 2. 电平的使用场景 2.1 高电平使能 2.2 低电平使能 2.3 失能 1. 电平的概念 电平是指电信号电压的大小或高低状态。在数字电子学中&#xff0c;电平有两种状态&#xff0c;高电平和低电平&#xff0c;用来表示二进制中…

代码随想录算法训练营第四十六天|139.单词拆分、背包问题总结

LeetCode 139. 单词拆分 题目链接&#xff1a;139. 单词拆分 - 力扣&#xff08;LeetCode&#xff09; 这道题使用完全背包来实现&#xff0c;我们首先考虑字符串是否可以由字符串列表组成&#xff0c;因此dp数组大小为n 1 &#xff0c;其意义是&#xff0c;在n个位置时是否能…

前缀和+哈希表——525. 连续数组

文章目录 ⛏1. 题目&#x1f5e1;2. 算法原理⚔解法一&#xff1a;暴力枚举⚔解法二&#xff1a;前缀和哈希表 ⚒3. 代码实现 ⛏1. 题目 题目链接&#xff1a;525. 连续数组 - 力扣&#xff08;LeetCode&#xff09; 给定一个二进制数组 nums , 找到含有相同数量的 0 和 1 的最…

超全整理,银行测试-银行项目贷款业务详细,一篇概全...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 银行测试&#xf…

cuda magma 构建 使用cmake构建的步骤记录

这不是群论代数软件&#xff0c;而是cuda 矩阵计算软件 1. 生成其他精度的源代码 1.1 复制编辑 make.inc cp make.inc-examples/make.inc.openblas ./make.inc 并修改其中的定义&#xff1a; OPENBLASDIR ? /opt/OpenBLAS 这需要实现安装openblas到此处。文件夹解构&…

Linux 网络通信

(一)套接字Socket概念 Socket 中文意思是“插座”&#xff0c;在 Linux 环境下&#xff0c;用于表示进程 x 间网络通信的特殊文件 类型。本质为内核借助缓冲区形成的伪文件。 既然是文件&#xff0c;那么理所当然的&#xff0c;我们可以使用文件描述符引用套接字。Linux 系统…

Royal TSX v6.0.1

Royal TSX是一款基于插件的软件&#xff0c;适用于Windows系统&#xff0c;可以用于远程连接和管理服务器。它支持多种连接类型&#xff0c;如RDP、VNC、基于SSH连接的终端&#xff0c;SFTP/FTP/SCP或基于Web的连接管理。 在安装Royal TSX后&#xff0c;需要进行一些基础配置&…

【新手解答2】深入探索 C 语言:一些常见概念的解析

C语言的相关问题解答 写在最前面问题1变量名是否有可能与变量重名&#xff1f;变量名和变量的关系变量名与变量是否会"重名"举例说明结论 变量则是一个地址不变&#xff0c;值时刻在变的“具体数字”变量的地址和值变量名与数据类型具体化示例结论 问题2关于你给我的…

11.8事务

一.Spring实现事务的两种方式 1.通过代码的方式手动实现事务. 2.通过注解的方式实现声明式事务. 二. 1.mysql事务 2. 手动实现事务 3.注解实现事务 使用注解Transactional,可以写在类上或方法上,如果异常,就自动回滚,正常则自动提交. 注意: 如果在代码中添加了try,catch捕…

408—电子笔记分享

一、笔记下载 链接&#xff1a;https://pan.baidu.com/s/1bFz8IX6EkFMWTfY9ozvVpg?pwddeng 提取码&#xff1a;deng b站视频&#xff1a;408-计算机网络-笔记分享_哔哩哔哩_bilibili 包含了408四门科目&#xff08;数据结构、操作系统、计算机组成原理、计算机网络&#xff09…

灭火器二维码巡检卡制作教程

每个消防器材生成独立二维码&#xff0c;取代传统纸质巡检卡&#xff0c;微信扫码巡检&#xff0c;巡检记录汇总后台&#xff0c;随时登录后台查看导出数据&#xff0c;管理人员绑定凡尔码小程序即可随时了解消防巡检完成情况。 生成灭火器巡检码流程图&#xff1a; 1、开通后…