layoutlmv3训练CDLA数据集

news2024/11/18 3:39:49

一.LayoutLMv3介绍

LayoutLMv3(文档基础模型)
自监督预训练技术在文档人工智能方面取得了显着的进步。大多数多模态预训练模型使用掩码语言建模目标来学习文本模态的双向表示,但它们在图像模态的预训练目标上有所不同。这种差异增加了多模态表示学习的难度。

在本文中,微软提出LayoutLMv3来通过统一的文本和图像掩码来预训练文档 AI 的多模态 Transformer。此外,LayoutLMv3 还使用单词补丁对齐目标进行了预训练,通过预测文本单词的相应图像补丁是否被屏蔽来学习跨模态对齐。简单的统一架构和训练目标使 LayoutLMv3 成为适用于以文本为中心和以图像为中心的文档 AI 任务的通用预训练模型。实验结果表明,LayoutLMv3 不仅在以文本为中心的任务(包括表单理解、收据理解和文档视觉问答)中实现了最先进的性能,而且在以图像为中心的任务(如文档图像分类和文档布局)中也实现了最先进的性能分析。
在这里插入图片描述

二.环境配置

conda create --name layoutlmv3 python=3.7
conda activate layoutlmv3
git clone https://github.com/microsoft/unilm.git
cd unilm/layoutlmv3
pip install -r requirements.txt
# install pytorch, torchvision refer to https://pytorch.org/get-started/locally/
pip install torch==1.10.0+cu111 torchvision==0.11.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
# install detectron2 refer to https://detectron2.readthedocs.io/en/latest/tutorials/install.html
python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.10/index.html
pip install -e .

参考官网下载预训练:
在这里插入图片描述

三.CDLA数据转成coco格式

CDLA数据集介绍:CDLA数据集
在这里插入图片描述
由于CDLA的数据是由Labeme标注的,先将数据集格式转成coco的:

#!/usr/bin/env python

import argparse
import collections
import datetime
import glob
import json
import os
import os.path as osp
import sys
import uuid

import imgviz
import numpy as np

import labelme

try:
    import pycocotools.mask
except ImportError:
    print("Please install pycocotools:\n\n    pip install pycocotools\n")
    sys.exit(1)


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("input_dir", help="input annotated directory")
    parser.add_argument("output_dir", help="output dataset directory")
    parser.add_argument("--labels", help="labels file", required=True)
    parser.add_argument(
        "--noviz", help="no visualization", action="store_true"
    )
    args = parser.parse_args()

    if osp.exists(args.output_dir):
        print("Output directory already exists:", args.output_dir)
        sys.exit(1)
    os.makedirs(args.output_dir)
    os.makedirs(osp.join(args.output_dir, "JPEGImages"))
    if not args.noviz:
        os.makedirs(osp.join(args.output_dir, "Visualization"))
    print("Creating dataset:", args.output_dir)

    now = datetime.datetime.now()

    data = dict(
        info=dict(
            description=None,
            url=None,
            version=None,
            year=now.year,
            contributor=None,
            date_created=now.strftime("%Y-%m-%d %H:%M:%S.%f"),
        ),
        licenses=[dict(url=None, id=0, name=None,)],
        images=[
            # license, url, file_name, height, width, date_captured, id
        ],
        type="instances",
        annotations=[
            # segmentation, area, iscrowd, image_id, bbox, category_id, id
        ],
        categories=[
            # supercategory, id, name
        ],
    )

    class_name_to_id = {}
    for i, line in enumerate(open(args.labels).readlines()):
        class_id = i - 1  # starts with -1
        class_name = line.strip()
        if class_id == -1:
            assert class_name == "__ignore__"
            continue
        class_name_to_id[class_name] = class_id
        data["categories"].append(
            dict(supercategory=None, id=class_id, name=class_name,)
        )

    out_ann_file = osp.join(args.output_dir, "annotations.json")
    label_files = glob.glob(osp.join(args.input_dir, "*.json"))
    for image_id, filename in enumerate(label_files):
        print("Generating dataset from:", filename)

        label_file = labelme.LabelFile(filename=filename)

        base = osp.splitext(osp.basename(filename))[0]
        out_img_file = osp.join(args.output_dir, "JPEGImages", base + ".jpg")

        img = labelme.utils.img_data_to_arr(label_file.imageData)
        imgviz.io.imsave(out_img_file, img)
        data["images"].append(
            dict(
                license=0,
                url=None,
                file_name=osp.relpath(out_img_file, osp.dirname(out_ann_file)),
                height=img.shape[0],
                width=img.shape[1],
                date_captured=None,
                id=image_id,
            )
        )

        masks = {}  # for area
        segmentations = collections.defaultdict(list)  # for segmentation
        for shape in label_file.shapes:
            points = shape["points"]
            label = shape["label"]
            group_id = shape.get("group_id")
            shape_type = shape.get("shape_type", "polygon")
            mask = labelme.utils.shape_to_mask(
                img.shape[:2], points, shape_type
            )

            if group_id is None:
                group_id = uuid.uuid1()

            instance = (label, group_id)

            if instance in masks:
                masks[instance] = masks[instance] | mask
            else:
                masks[instance] = mask

            if shape_type == "rectangle":
                (x1, y1), (x2, y2) = points
                x1, x2 = sorted([x1, x2])
                y1, y2 = sorted([y1, y2])
                points = [x1, y1, x2, y1, x2, y2, x1, y2]
            else:
                points = np.asarray(points).flatten().tolist()

            segmentations[instance].append(points)
        segmentations = dict(segmentations)

        for instance, mask in masks.items():
            cls_name, group_id = instance
            if cls_name not in class_name_to_id:
                continue
            cls_id = class_name_to_id[cls_name]

            mask = np.asfortranarray(mask.astype(np.uint8))
            mask = pycocotools.mask.encode(mask)
            area = float(pycocotools.mask.area(mask))
            bbox = pycocotools.mask.toBbox(mask).flatten().tolist()

            data["annotations"].append(
                dict(
                    id=len(data["annotations"]),
                    image_id=image_id,
                    category_id=cls_id,
                    segmentation=segmentations[instance],
                    area=area,
                    bbox=bbox,
                    iscrowd=0,
                )
            )

        if not args.noviz:
            labels, captions, masks = zip(
                *[
                    (class_name_to_id[cnm], cnm, msk)
                    for (cnm, gid), msk in masks.items()
                    if cnm in class_name_to_id
                ]
            )
            viz = imgviz.instances2rgb(
                image=img,
                labels=labels,
                masks=masks,
                captions=captions,
                font_size=15,
                line_width=2,
            )
            out_viz_file = osp.join(
                args.output_dir, "Visualization", base + ".jpg"
            )
            imgviz.io.imsave(out_viz_file, viz)

    with open(out_ann_file, "w") as f:
        json.dump(data, f)


if __name__ == "__main__":
    main()

四.修改配置文件

配置文件路径:/unilm/layoutlmv3/examples/object_detection/cascade_layoutlmv3.yaml
在这里插入图片描述
需要修改的位置:

WEIGHTS: "/microsoft/layoutlmv3-base-chinese/pytorch_model.bin" #预训练权重
NUM_CLASSES: 10   #标签数
IMS_PER_BATCH: 1  #batch_size
CHECKPOINT_PERIOD: 5000 # 每5000个epoch进行一次测试
PUBLAYNET_DATA_DIR_TRAIN: "/layoutlmv3/cdla_data/coco_data/train" #train
PUBLAYNET_DATA_DIR_TEST: "/layoutlmv3/cdla_data/coco_data/val" #val

相关参数的介绍:detectron2(目标检测框架):配置config解析

五.模型训练

train:

python train_net.py --config-file cascade_layoutlmv3.yaml MODEL.WEIGHTS /path/to/microsoft/layoutlmv3-base/pytorch_model.bin OUTPUT_DIR /path/to/layoutlmv3_train/

在这里插入图片描述

在这里插入图片描述

更多参考:
1.微调LayoutLM v3进行票据数据的处理和内容识别
2.yolov8训练CDLA数据文版版面分析
3.layoutlmv3 在中文文档上的应用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1508087.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JAVA全面基础知识(第七部分)

大家好我是程序员阿存,混迹在java圈的辛苦码农。今天要和大家聊的是一款,项目源码以及部署相关请联系存哥,文末附上联系信息 。 这篇文章给大家分享的是JAVA的基础知识, 💕💕作者:程序员阿存 &…

qemu快速入门

前提: 我们做嵌入式软件的时候,往往可能会缺少嵌入式的硬件,那我们希望提前开始准备代码的话,就需要qemu这个开源软件,它可以模拟各种型号的芯片 。那么我们可以提前在这个模拟器上面去开发代码、验证、调试。 正片开始…

什么是高级编程语言?——跟老吕学Python编程

什么是高级编程语言?——跟老吕学Python编程 高级编程语言简介高级编程语言发展历程高级编程语言特点高级编程语言分类命令式语言函数式语言逻辑式语言面向对象语言 常见的高级编程语言及其特点和应用领域高级编程语言性能分析高级编程语言的工作方式 高级编程语言简…

深耕版本控制、代码质量与安全等领域,龙智荣获“Perforce 2023年度合作伙伴”奖项

在近日举行的Perforce 2024合作伙伴峰会上,龙智被评选为“Perforce 2023年度合作伙伴”。这一奖项不仅是对龙智在中国市场开拓中的进取精神与丰硕成果的高度认可,也是Perforce公司对于龙智持续创新精神及专业技术与服务的表彰。 自2012年成为Perforce中…

中探:事件循环相关内容(因为不仅仅是初步认识,但也不至于是深入探讨,所以命名为“中探”)

下面内容写于 2022 年,文本描述过多,可能不适合有经验的人看。新的文章在 个人网站 中。 对了,说到事件循环,怎么可以离开这个最知名的视频呢!视频是英文的,但即使你听不懂,单纯看他的操作&…

使用gin框架,编写一个接收数据的api接口

功能:这里主要编写一个接口,将其json 数据存入对应的redis队列中,并统计每天的每小时请求数量 环境: go version go1.22.0 linux/amd64 平台 linux X64 步骤一 新建目录 命令如下: mkdir FormData 步骤二 新增…

【Linux】Linux上的一些软件安装与环境配置(Centos7配置JDK、Hadoop)

文章目录 安装JDK配置环境变量1. 卸载已安装的JDK查询已安装的 jdk 列表删除已经安装的 jdk 2. 上传安装包3. 创建 /usr/local/java 文件夹4. 将 jdk 压缩包解压到 /usr/local/java 目录下5. 配置 jdk 的环境变量6. 让配置文件生效7. 校验8.拍个快照吧,免得后面哪里…

2024 年系统规划与管理师(全套资料)

2024年11月系统规划与管理师全套视频、历年真题及解析、章节分类真题及解析、论文写作及范文、教材、模拟题、答题卡等资料 1、2023年5月、2022年5月、2021年5月、2020年5月四套基础精讲视频,案例分析及论文答题套路视频讲解。 2、系统规划与管理师2017-2023年真题…

【go语言开发】redis简单使用

本文主要介绍redis安装和使用。首先安装redis依赖库,这里是v8版本;然后连接redis,完成基本配置;最后测试封装的工具类 文章目录 安装redis依赖库连接redis和配置工具类封装代码测试 欢迎大家访问个人博客网址:https://…

labview的常用小技巧

1.切换:labview中控件函数与函数选板的使用非常频繁,而使用菜单来调用他们非常不方便。最简单的调用方法是:右击前面板,弹出控件选板;右击程序框图,弹出函数选板。然后按住CtrlE组合键,即可快速…

工业数据采集网关的功能与应用-天拓四方

工业数据采集网关是一种专门用于采集、处理、传输工业现场数据的设备。它能够实时收集来自各种传感器、仪表和设备的数据,并通过网络将这些数据传输到云端或数据中心。同时,数据采集网关还具备数据清洗、转换和压缩等功能,确保数据的质量和传…

MySQL将两条记录根据相同条件合并

知识点:在MySQL中,可以使用GROUP BY子句和聚合函数如CONCAT或CONCAT_WS来将多条记录基于相同条件合并为一条记录 【主要是GROUP_CONCAT这个函数的运用】 例如将员工信息表中相同门店的员工信息合并为一条记录 MySQL语句如下: SELECT dept_…

一文读懂:公网IP地址证书

公网IP证书是一种SSL证书,用于验证和确认特定的公网IP地址是否实际属于申请者。如果验证通过,证书颁发机构将向该IP地址持有人颁发一个以IP地址为主题的SSL证书。使用公网IP证书可以有效提升IP身份的辨识度,减少网站链接被假冒的风险&#xf…

建模杂谈系列237 使用FSM进行状态管理

说明 使用FSM来对状态的变化进行管理,一方面有助于我们将问题定义的更清晰,同时也让程序设计更可靠、可读性(事后)更强。 内容 1 问题描述 假设有一笔投资用于证券交易,随着交易、市场价格变化,投资的状态也随之改变。我们需要…

巫蛊之祸——汉武帝后期的一次重大事件

引 言 “巫蛊之祸”是汉武帝在位后期发生的一次重大政治事件,也是西汉历史上最大的冤案,此案导致皇后卫子夫和太子刘据自杀,数万人头落地,几十万人被牵连。 一、巫蛊之术的由来 《汉书》记载,巫蛊之术起源自胡巫&am…

解决Iterm2升级后遇到“Stashed changes“的问题

<<<<<<< Updated upstream ...... >>>>>>> Stashed changes冲突标记符的代码如题,最近有升级Item2…

ROS的消息发布者与订阅者示例

前言 Topic话题,是节点之间信息交换的方法,在向话题发生送消息的节点叫做发布者,接收消息的节点叫做订阅者。 一个ROS程序中话题可以有很多个,一个话题中也可以有多个发布者和订阅者。一个订阅者可以订阅多个话题。同样 &#xff…

Linux 地址空间

目录 一、程序地址空间 1、虚拟地址 Makefile新写法 2、进程地址空间分布 3、栈&堆 4、static修饰局部变量 5、字符串常量不可修改 6、虚拟地址与物理地址的联系 二、CPU读取程序全过程 1、形成可执行程序 2、生成虚拟地址 3、程序的启动 4、创建进程 5、地…

MyBatis-Plus生成sql语句时怎么知道表名和表的字段名,表的主键名的

MyBatis-Plus通过反射获取实体类的信息。 实体类的类名驼峰转下划线为表名 实体类的属性名驼峰转下划线为字段名 表的主键名默认为id selectById就是基于这个id,select 查询字段 from user where id ? 自定义告诉mybatisplus数据库的表名&#xff0c…

SpringMVC10、拦截器

10、拦截器 10.1、概述 SpringMVC的处理器拦截器类似于Servlet开发中的过滤器Filter,用于对处理器进行预处理和后处理。开发者可以自己定义一些拦截器来实现特定的功能。 过滤器与拦截器的区别:拦截器是AOP思想的具体应用。 过滤器 servlet规范中的一部分&…