DAGA项目 代码阅读笔记1——LSTM-LM部分代码

news2025/1/11 19:47:42

DAGA 代码阅读笔记1——LSTM-LM部分代码

文章目录

  • DAGA 代码阅读笔记1——LSTM-LM部分代码
      • 概述
      • main函数逻辑分析
      • 设置训练参数
      • fields初始化
      • 训练数据读入
      • 模型建立
      • 优化器

概述

​ 学习人工智能的必经之路——读代码。目前阅读的代码来自于github数据增强项目DAGA,这个项目的原论文可以从这里获取。

​ 这个项目主要将标记的句子线性化,然后在线性化数据上训练语言模型(LM),并用于生成合成标记数据,统一了句子生成和使用LM标记的过程。使用该方法,可以有效为序列标记任务生成高质量的合成数据,在低资源条件下,有效提升序列标记模型的性能。

​ 语言模型部分 代码文件如图所示

在这里插入图片描述

main函数逻辑分析

def main():
    """Main workflow"""
    ##### 运行参数读取 #####
    args = utils.build_args(argparse.ArgumentParser())
	##### 从模型文件中读取日志 #####
    utils.init_logger(args.model_file)
	##### 调用gpu训练 #####
    assert torch.cuda.is_available()
    torch.cuda.set_device(args.gpuid)
	##### 确定随机种子 #####
    utils.init_random(args.seed)
	##### 设置训练参数 #####
    utils.set_params(args)
    logger.info("Config:\n%s", pformat(vars(args)))
	##### field初始化 #####
    fields = utils.build_fields()
    logger.info("Fields: %s", fields.keys())
	##### 训练数据读入 #####
    logger.info("Load %s", args.train_file)
    train_data = LMDataset(fields, args.train_file, args.sent_length_trunc)
    logger.info("Training sentences: %d", len(train_data))
    logger.info("Load %s", args.valid_file)
    ##### 测试数据读入 #####
    val_data = LMDataset(fields, args.valid_file, args.sent_length_trunc)
    logger.info("Validation sentences: %d", len(val_data))
	##### 将数据以数值方式存储 #####
    fields["sent"].build_vocab(train_data)
	##### 迭代器 #####
    train_iter = utils.build_dataset_iter(train_data, args)
    val_iter = utils.build_dataset_iter(val_data, args, train=False)
	##### 读取训练断点继续训练 #####
    if args.resume and os.path.isfile(args.checkpoint_file):
        logger.info("Resume training")
        logger.info("Load checkpoint %s", args.checkpoint_file)
        checkpoint = torch.load(
            args.checkpoint_file, map_location=lambda storage, loc: storage
        )
        es_stats = checkpoint["es_stats"]
        args = utils.set_args(args, checkpoint)
    else:
        checkpoint = None
        es_stats = ESStatistics(args)
	##### 模型建立 #####
    model = utils.build_model(fields, args, checkpoint)
    logger.info("Model:\n%s", model)
	##### 优化器 #####
    optimizer = utils.build_optimizer(model, args, checkpoint)
	##### 训练效果输出 #####
    try_train_val(fields, model, optimizer, train_iter, val_iter, es_stats, args)

下面按照main函数中的执行顺序,选择主要代码进行分析

设置训练参数

def set_params(args):
    """Set some params."""
    args.checkpoint_file = "{}.checkpoint".format(args.model_file)
	##### encoder层和decoder层的层数设置 #####
    if args.num_layers != -1:
        args.num_enc_layers = args.num_layers
        args.num_dec_layers = args.num_layers
        logger.info(
            "Set number of encoder/decoder layers uniformly to %d", args.num_layers
        )
	##### 校验encoder层和decoder层合法性 #####
    if args.num_enc_layers < args.num_dec_layers:
        raise RuntimeError("Expected num_enc_layers >= num_dec_layers")
	##### z维输入确认 #####
    if args.num_z_samples == 0:
        args.z_dim = 0
        args.z_cat = False
        args.warmup = 0
	
    args.beta = 1.0 if args.warmup == 0 else 0.0

    args.device = "cuda" if args.gpuid > -1 else "cpu"

fields初始化

def build_fields():
    """Build fields."""
    fields = {}
    fields["sent"] = torchtext.data.Field(
        ##### 规定example数据的句首标记、句尾标记、填充标记 #####
        init_token=BOS_WORD, eos_token=EOS_WORD, pad_token=PAD_WORD
    )
    return fields

训练数据读入

##### 这个文件主要用于读入文本数据 #####
"""Language modeling dataset"""
import io
import torchtext

##### 使用的是torchtext.data中的Dataset结构来储存数据 #####
class LMDataset(torchtext.data.Dataset):
    """Define a dataset class."""
	##### 构造函数,传参fields对象,文件名,限制句子长度 #####
    def __init__(self, fields, filename, truncate=0):
        sents = []
        with io.open(filename, "r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                line = line.strip().split(" ")
                if truncate:
                    line = line[:truncate]
                sents += [line]
        ##### fields和examples在下方说明 #####
        fields = [(k, fields[k]) for k in fields]
        examples = [torchtext.data.Example.fromlist([sent], fields) for sent in sents]
        super(LMDataset, self).__init__(examples, fields)
	##### 定义排序键:句子长度 #####
    def sort_key(self, ex):
        """Sort by sentence length."""
        return len(ex.sent)

这段代码中比较抽象的部分是torchtext的dataset中有两个变量:fields和examples

examples即torchtext中的example对象构造的列表,而example就是对数据集中一条数据的抽象

fields即torchtext中的field对象构造的列表,field对象可以理解为数据表中的列标题,其定义了列数据的处理形式

TorchText使用一个声明式的方法来加载数据:你可以告诉TorchText你想要的数据类型,它会根据声明处理数据。这一方式是通过 声明那个一个Field对象来实现的。Field就是你定义的数据处理形式。

下面使用样例代码直观了解field和example这两个类

##### 对field对象的参数进行设置,后续处理数据时会按照设置的形式对数据进行处理 #####
TEXT = Field(sequential=True, tokenize=tokenize, lower=True)
LABEL = Field(sequential=False, use_vocab=False)

def get_dataset(csv_data, text_field, label_field, test=False):
    fields = [('id', None), ('comment_text', text_field), ('toxic', label_field)]
    examples = []
    if test:
        for text in tqdm(csv_data['comment_text']):
            examples.append(data.Example.fromlist([None, text, None], fields))
    else:
        for text, label in tqdm(zip(csv_data['comment_text'], csv_data['toxic'])):
            examples.append(data.Example.fromlist([None, text, label], fields))
    return examples, fields

train_examples, train_fields = get_dataset(train_data, TEXT, LABEL)
valid_examples, valid_fields = get_dataset(valid_data, TEXT, LABEL)
test_examples, test_fields = get_dataset(test_data, TEXT, None, True)

train = data.Dataset(train_examples, train_fields)
valid = data.Dataset(valid_examples, valid_fields)
test = data.Dataset(test_examples, test_fields)

模型建立

def build_model(fields, args, checkpoint=None):
    """Build model."""
    ##### 具体模型建立过程在下一篇学习笔记中分析 #####
    model = LMModel(fields, args)
    if checkpoint is not None:
        logger.info("Set model using saved checkpoint")
        model.load_state_dict(checkpoint["model"])
    return model.to(args.device)

优化器

def build_optimizer(model, args, checkpoint=None):
    """Build optimizer."""
    params = [p for p in model.parameters() if p.requires_grad]
    n_params = sum([p.nelement() for p in params])
    logger.info("Trainable parameters: %d", n_params)
	##### 优化器定义两种方法,SGD和Adam #####
    method = {"sgd": torch.optim.SGD, "adam": torch.optim.Adam}
    optimizer = method[args.optim](params, lr=args.lr)
    logger.info("Use %s with lr %f", args.optim, args.lr)
	##### 保存点optimizer参数读取 #####
    if checkpoint is not None:
        logger.info("Set optimizer states using saved checkpoint")
        optimizer.load_state_dict(checkpoint["optimizer"])
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(args.device)
    return optimizer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/392029.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

socket 编程实战(编写客户端程序 )

编写客户端程序 接着上一篇&#xff1a;实战服务端程序 接下来我们再编写一个简单地客户端应用程序&#xff0c;客户端的功能是连接上小节所实现的服务器&#xff0c;连接成功之后向服务器发送数据&#xff0c;发送的数据由用户输入。示例代码如下所示&#xff1a; #include…

大数据开发治理平台 DataWorks

序言学习下阿里DataWorks的设计理念以及要做的事情cuiyaonan2000163.com参考文档:https://www.aliyun.com/product/bigdata/idehttps://help.aliyun.com/document_detail/73015.htmlhttps://help.aliyun.com/document_detail/324149.html ----数据治理LaunchDataWorks基于阿里云…

ECCV 2022|面向精确的主动相机定位算法

标题&#xff1a;ECCV 2022,山东大学、北大、腾讯AILab、斯坦福和三维家联合提出&#xff0c;面向精确的主动相机定位算法项目地址&#xff1a;https://github.com/qhFang/AccurateACL.文章&#xff1a;Towards Accurate Active Camera Localization&#xff08;ECCV 2022&…

oneblog_justauth_三方登录配置【Gitee】

文章目录oneblog添加第三方平台gitee中创建三方应用完善信息oneblog添加第三方平台 1.oneblog管理端&#xff0c;点击左侧菜单 网站管理——>社会化登录配置管理 ,添加一个社会化登录 2.编辑信息如下&#xff0c;选择gitee平台后复制redirectUri,然后去gitee获取clientId和…

vulnhub five86-2

总结&#xff1a;sudo -l&#xff0c;抓流量包&#xff0c;搜索引擎。。 目录 下载地址 漏洞分析 信息收集 网站渗透 ​编辑 反弹shell提权 下载地址 Five86-2.zip (Size: 1.7 GB)Download (Mirror): https://download.vulnhub.com/five86/Five86-2.zip使用&#xff1a;下…

电子台账:模板制作之一——列过滤(水平过滤)

1 简介列过滤即水平过滤。一般情况下&#xff0c;企业数据源文件中有很多数据列&#xff0c;其中大部分数据列中的数据对电子台账来说是没有用的。列过滤就是确定企业数据文件的哪几列有用&#xff0c;以及有用的列分别对应到台账&#xff08;模板&#xff09;的哪一列。列过滤…

FreeSql使用

目的: 1.方库分表 2.主从分离 3.分布式事务 过程&#xff1a; 官网&#xff1a;指南 | FreeSql 官方文档 1.Startup.cs 添加配置&#xff08;本地数据库MySql&#xff09; ConfigureServices&#xff1a; Func<IServiceProvider, IFreeSql> fsql r >{IFreeSql …

吉利银河L7、长城哈弗B07、比亚迪宋Plus DM-i,自主品牌决战混动

2月23日&#xff0c;吉利推出全新的中高端新能源产品序列——吉利银河。当日&#xff0c;吉利推出了首款智能电混SUV「银河L7」&#xff0c;新车将在二季度交付。本月10日&#xff0c;长城汽车也计划举办智能新能源干货大会&#xff0c;其「颠覆技术」等宣传直面新一代的新能源…

SUDO(CVE-2021-3156复现)提权rsync未授权访问提权

一、SUDO(CVE-2021-3156复现 判断漏洞存在&#xff1a; 1.版本 sudo: 1.8.2 - 1.8.31p2 sudo: 1.9.0 - 1.9.5p1 2.报错存在次漏洞 sudoedit -s / 不是报错信息&#xff1a; 复现&#xff1a; 环境&#xff1a;docker的centos7 需要新建一个用户 docker pull chenaot…

k8s学习之路 | Day20 k8s 工作负载 Deployment(下)

文章目录3. HPA 动态扩缩容3.1 HPA3.2 安装 metrics-server3.3 验证指标收集3.4 扩缩容的实现3.5 增加负载3.6 降低负载3.7 更多的度量指标4. 金丝雀部署4.1 蓝绿部署4.2 金丝雀部署4.3 金丝雀部署的实现5. Deployment 状态与排查5.1 进行中的 Deployment5.2 完成的 Deployment…

wordpress更新文章后总是向文章内连接发送GET请求

通过观察wordpress请求发现&#xff0c;wordpress在更新文章后会向文章发送GET请求。在发送请求之前会执行一个调用定时的一个请求POST /wp-cron.php?doing_wp_cron1678081385.6844499111175537109375 HTTP/1.1执行这个定时后&#xff0c;这篇文章的所有链接都会发送HEAD和GET…

源码阅读笔记 InputFormat、FileInputFormat、CombineTextInputFormat

1. InputFormat InputFormat是MapReduce框架提供的用来处理job输入的基类 它主要定义了三个功能&#xff1a; 1.验证job输入是否合法 2.对输入文件进行逻辑切片(InputSplit)&#xff0c;然后将每个切片分发给单独的MapTask 3.提供切片读取器(Re…

Java的注解(Annotation)

Java 注解&#xff08;Annotation&#xff09;又称 Java 标注&#xff0c;是 JDK5.0 引入的一种注释机制。Java 中的类、构造器、方法、成员变量、参数等都可以被注解进行标注。例如JUnit单元测试中的Test方法&#xff0c;可以使得方法直接运行。JUnit单元测试Test单元测试是针…

2023年湖北助理工程师在哪里申报?助理工程师的五大作用你知道吗

2023年湖北助理工程师在哪里申报&#xff1f;助理工程师的五大作用你知道吗 助理工程师申报条件&#xff1a; 大学本科毕业&#xff1a;毕业满一年&#xff0c;工科类专业&#xff0c;6个月以上社保证明 大学专科毕业&#xff1a;毕业满三年&#xff0c;工科类专业&#xff0…

贝塞尔曲线与B样条曲线

文章目录0.参考1.问题起源与插值法的曲线拟合1.1.问题起源1.2.拉格朗日插值1.3.“基”的概念1.4.插值存在的Runge现象2.贝塞尔曲线2.1.控制点的思想2.2.由控制点生成贝塞尔曲线2.3.多个控制点时的贝塞尔曲线公式2.4.贝塞尔曲线的递推公式2.5.贝塞尔曲线的性质3.B样条曲线3.1.B样…

项目设计原则

单一设计原则 做过管理系统项目的同学肯定都接触过用户、机构、角色管理这些模块&#xff0c;实现方式都是基于RBAC模型&#xff08;Role-Based Access Control&#xff0c;基于角色的访问控制&#xff0c;通过分配和取消角色来完成用户权限的授予和取消&#xff0c;使动作主体…

web开发 用idea创建一个新项目

这个写着就是给自己当备忘录用的QAQ 这个老师上课一通操作啥也没看清…卑微搞了半天看样子是成功了 记录一下省的以后忘了怎么创建&#xff08;&#xff1f; zufe lxy 2023.3 先行条件是已经自己装好了Tomcat和idea&#xff01;&#xff01;&#xff08;我的idea是申请了教育…

MSDP实验配置

目录 配置MSDP 配置PIM SM协议 配置各PIM SM域内的静态RP 配置MSDP对等体 配置域内的MSDP对等体 AR8和AR9建立EBGP邻居 配置域间的MSDP对等体 进行实验验证 什么是MSDP MSDP&#xff08;Multicast Source Discovery Protocol&#xff09;组播源发现协议的简称 用来传递…

帆船结构3D线上展示教学的亮点有哪些?

由广州华锐互动开发的帆船结构3D线上展示教学系统&#xff0c;是一种创新的教学方式&#xff0c;基于虚拟现实技术&#xff0c;通过3D模型、交互式模拟等技术手段&#xff0c;可以让学生在虚拟环境中进行帆船组装和调试训练&#xff0c;以达到实践教学的目的。不同于传统的实践…

Python绘图

1.二维绘图 a. 一维数据集 用 Numpy ndarray 作为数据传入 ply 1. import numpy as np import matplotlib as mpl import matplotlib.pyplot as pltnp.random.seed(1000) y np.random.standard_normal(10) print "y %s"% y x range(len(y)) print "x%s&q…