LaMa 论文复现:Resolution-robust Large Mask Inpainting with Fourier Convolutions

news2024/10/6 6:05:50

 代码:GitHub - andy971022/auto-lama 

论文:https://arxiv.org/abs/2109.07161

1 LaMa 论文简介

2 LaMa代码复现

2.1 环境部署

 2.1.1 下载源码,创建环境,安装必需库

git clone https://github.com/advimman/lama
cd lama
conda env create -f conda_env.yml
conda activate lama
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -y
pip install pytorch-lightning==1.2.9

2.2  公开数据集训练测试与结果可视化

2.2.1 LaMa 测试数据集和预训练模型下载

(1)预训练模型下载链接:

预训练模型 https://disk.yandex.ru/d/kHJkc7bs7mKIVAicon-default.png?t=N7T8https://disk.yandex.ru/d/kHJkc7bs7mKIVA

 预训练模型下载好后,存放在checkpoints文件夹下。

(2)测试数据集下载:

    # Download data from http://places2.csail.mit.edu/download.html
    # Places365-Standard: Train(105GB)/Test(19GB)/Val(2.1GB) from High-resolution images section
    wget http://data.csail.mit.edu/places/places365/train_large_places365standard.tar
    wget http://data.csail.mit.edu/places/places365/val_large.tar
    wget http://data.csail.mit.edu/places/places365/test_large.tar

http://data.csail.mit.edu/places/places365/val_large.tar
http://data.csail.mit.edu/places/places365/test_large.tar

2.2.2  place365 数据集训练

2.2.3  place365 数据集测试

预测性能,基于big-lama数据集中的LaMa_test_images。

运行以下命令,其中refine=true 表示将运行图像修复器。

(nerf) D:\0A_project\lama-main\bin> python predict.py refine=True model.path=$(pwd)../checkpoint/big-lama indir=$(pwd)../LaMa_test_images outdir=$(pwd)../output

model.path=$(pwd)/big-lama: 这部分是传递给predict.py脚本的命令行参数之一。它设置了一个参数model.path,并将其值设置为当前目录(通过$(pwd)获取)下的big-lama

indir=$(pwd)/LaMa_test_images: 这是另一个命令行参数,用于设置输入目录。它将indir参数的值设置为当前目录下的LaMa_test_images目录。

outdir=$(pwd)/output: 类似地,这是设置输出目录的参数。它将outdir参数的值设置为当前目录下的output目录。

出错如下:

Traceback (most recent call last):
  File "predict.py", line 24, in <module>
    from  saicinpainting.evaluation.utils import move_to_device
ModuleNotFoundError: No module named 'saicinpainting'

代码段引用模块包内容如下:

import logging
import os
import traceback
import sys
from saicinpainting.evaluation.refinement import refine_predict
from saicinpainting.evaluation.utils import move_to_device

文件结构如下,saicinpainting模块包位于lama-main 主文件夹下,predict.py位于bin文件夹中。 

       

因此,出现 ModuleNotFoundError: No module named 'saicinpainting'  错误是该包没有在搜索路径中找到,故需要把该路径添加到搜索路径中,代码更改如下:

import logging
import os
import traceback
import sys
sys.path.append(r'D:\0A_project\lama-main')  # 添加项目根目录到 sys.path
from saicinpainting.evaluation.refinement import refine_predict
from saicinpainting.evaluation.utils import move_to_device

再次运行,又报错(待解决)

(1)不要用GPU预测,尝试无法解决

(2)python predict.py refine=True model.path=$(pwd)../checkpoint/big-lama indir=$(pwd)../LaMa_test_images outdir=$(pwd)../output  HYDRA_FULL_ERROR=1   无法解决

(3)D:\0A_project\lama-main\configs\prediction\default.yaml   添加  HYDRA_FULL_ERROR=1   无法解决

(4)注释掉predict.py line 41    

# register_debug_signal_handlers()  # kill -10 <pid> will result in traceback dumped into log

(2)在predict.py line45 后面加上一句

train_config_path = os.path.join('<your_full_path_to_lama_base_directory>', train_config_path)

如下: 

在lin54 句后面加上
checkpoint_path = os.path.join('<your_full_path_to_lama_base_directory>', checkpoint_path)

2.2.4  测试结果和参数可视化

2.3 制作自己的数据集,训练测试与结果的可视化

2.3.1 制作自己的数据集

(1)创建数据集图片对应的mask图,命名为images_name_maskxxx.png, 将images原图与对应的masks原图放在同一文件夹下。数据集文件格式如下:
    ```    
    image1_mask001.png
    image1.png
    image2_mask001.png
    image2.png
    ```
(2)利用(https://github.com/advimman/lama/blob/main/bin/gen_mask_dataset.py) 生成随机的mask图片。

将自己图像的数据集存放在myown_dataset文件夹下面。

将configs/prediction/default.yaml 文件中的`image_suffix` 声明为png或jpg或_input.jpg,如下
indir: no  # 将在CLI中被覆盖
outdir: no  # 将在CLI中被覆盖

model:
  path: no  # 将在CLI中被覆盖
  checkpoint: best.ckpt

dataset:
  kind: default
  img_suffix: .png
  pad_out_to_modulo: 8  # 输出图像将被填充到8的倍数

device: cuda  # 使用CUDA设备
out_key: inpainted  # 输出键:inpainted

refine: False  # 如果为True,将运行图像修复器
refiner:
  gpu_ids: 0,1  # 使用的GPU编号。如果只使用单个GPU,使用:"0,"
  modulo: ${dataset.pad_out_to_modulo}  # 与数据集的填充模数一致
  n_iters: 15  # 每个尺度的迭代修复次数
  lr: 0.002  # 学习率
  min_side: 512  # 所有尺度的图像边缘都应 >= min_side / sqrt(2)
  max_scales: 3  # 图像-掩码金字塔的最大降尺度数量
  px_budget: 1800000  # 像素预算。任何图像都将调整大小以满足高*宽 <= px_budget
运行命令
python3 bin/gen_mask_dataset.py indir=$(pwd)/myown_dataset outdir=$(pwd)/myown_dataset   

gen_mask_dataset.py解读如下
#!/usr/bin/env python3

import glob  # 用于查找文件
import os  # 提供文件和目录操作的功能
import shutil  # 用于文件复制和移动
import traceback  # 用于处理异常信息

import PIL.Image as Image  # 用于处理图像的Python库
import numpy as np  # 用于数值计算的Python库
from joblib import Parallel, delayed  # 用于并行处理任务的库

from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop  # 导入特定的图像处理功能
from saicinpainting.evaluation.utils import load_yaml, SmallMode  # 导入加载YAML配置和小模式处理的功能
from saicinpainting.training.data.masks import MixedMaskGenerator  # 导入混合掩码生成器

# 创建一个包装器,用于生成多个掩码变体
class MakeManyMasksWrapper:
    def __init__(self, impl, variants_n=2):
        self.impl = impl
        self.variants_n = variants_n

    def get_masks(self, img):
        img = np.transpose(np.array(img), (2, 0, 1))
        return [self.impl(img)[0] for _ in range(self.variants_n)]

# 处理图像
def process_images(src_images, indir, outdir, config):
    # 根据配置选择掩码生成器
    if config.generator_kind == 'segmentation':
        mask_generator = SegmentationMask(**config.mask_generator_kwargs)
    elif config.generator_kind == 'random':
        variants_n = config.mask_generator_kwargs.pop('variants_n', 2)
        mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**config.mask_generator_kwargs),
                                              variants_n=variants_n)
    else:
        raise ValueError(f'Unexpected generator kind: {config.generator_kind}')

    max_tamper_area = config.get('max_tamper_area', 1)

    for infile in src_images:
        try:
            # 获取文件相对路径
            file_relpath = infile[len(indir):]
            img_outpath = os.path.join(outdir, file_relpath)
            os.makedirs(os.path.dirname(img_outpath), exist_ok=True)

            # 打开输入图像并转换为RGB格式
            image = Image.open(infile).convert('RGB')

            # 将输入图像缩放到输出分辨率,并过滤小图像
            if min(image.size) < config.cropping.out_min_size:
                handle_small_mode = SmallMode(config.cropping.handle_small_mode)
                if handle_small_mode == SmallMode.DROP:
                    continue
                elif handle_small_mode == SmallMode.UPSCALE:
                    factor = config.cropping.out_min_size / min(image.size)
                    out_size = (np.array(image.size) * factor).round().astype('uint32')
                    image = image.resize(out_size, resample=Image.BICUBIC)
            else:
                factor = config.cropping.out_min_size / min(image.size)
                out_size = (np.array(image.size) * factor).round().astype('uint32')
                image = image.resize(out_size, resample=Image.BICUBIC)

            # 生成和选择掩码
            src_masks = mask_generator.get_masks(image)

            filtered_image_mask_pairs = []
            for cur_mask in src_masks:
                if config.cropping.out_square_crop:
                    (crop_left,
                     crop_top,
                     crop_right,
                     crop_bottom) = propose_random_square_crop(cur_mask,
                                                               min_overlap=config.cropping.crop_min_overlap)
                    cur_mask = cur_mask[crop_top:crop_bottom, crop_left:crop_right]
                    cur_image = image.copy().crop((crop_left, crop_top, crop_right, crop_bottom))
                else:
                    cur_image = image

                if len(np.unique(cur_mask)) == 0 or cur_mask.mean() > max_tamper_area:
                    continue

                filtered_image_mask_pairs.append((cur_image, cur_mask))

            mask_indices = np.random.choice(len(filtered_image_mask_pairs),
                                            size=min(len(filtered_image_mask_pairs), config.max_masks_per_image),
                                            replace=False)

            # 剪裁掩码并保存掩码和输入图像
            mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
            for i, idx in enumerate(mask_indices):
                cur_image, cur_mask = filtered_image_mask_pairs[idx]
                cur_basename = mask_basename + f'_crop{i:03d}'
                Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
                                mode='L').save(cur_basename + f'_mask{i:03d}.png')
                cur_image.save(cur_basename + '.png')
        except KeyboardInterrupt:
            return
        except Exception as ex:
            print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')

# 主函数
def main(args):
    if not args.indir.endswith('/'):
        args.indir += '/'

    os.makedirs(args.outdir, exist_ok=True)

    config = load_yaml(args.config)

    in_files = list(glob.glob(os.path.join(args.indir, '**', f'*.{args.ext}'), recursive=True))
    if args.n_jobs == 0:
        process_images(in_files, args.indir, args.outdir, config)
    else:
        in_files_n = len(in_files)
        chunk_size = in_files_n // args.n_jobs + (1 if in_files_n % args.n_jobs > 0 else 0)
        Parallel(n_jobs=args.n_jobs)(
            delayed(process_images)(in_files[start:start+chunk_size], args.indir, args.outdir, config)
            for start in range(0, len(in_files), chunk_size)
        )

# 如果这个脚本被直接执行
if __name__ == '__main__':
    import argparse

    aparser = argparse.ArgumentParser()
    aparser.add_argument('config', type=str, help='Path to config for dataset generation')
    aparser.add_argument('indir', type=str, help='Path to folder with images')
    aparser.add_argument('outdir', type=str, help='Path to folder to store aligned images and masks to')
    aparser.add_argument('--n-jobs', type=int, default=0, help='How many processes to use')
    aparser.add_argument('--ext', type=str, default='jpg', help='Input image extension')

    main(aparser.parse_args())
    用于处理图像并生成掩码。它包含了一些配置选项、图像处理功能以及处理多个图像的能力。主要的功能包括处理输入图像,生成掩码,剪裁图像和掩码,然后将它们保存到指定的输出目录。这个脚本还支持多进程处理,可以加快处理大量图像。

在上述代码中,"掩码" 是指一个二值图像,通常表示了一些区域的存在或缺失。数学形式表示掩码通常是一个矩阵(或图像),其中每个元素可以是二进制值(0或1),表示相应位置是否包含某种特征或信息。

具体地,如果我们考虑一个二维掩码矩阵,其中每个元素 (i, j) 的值为 1 表示该位置被覆盖或包含信息,值为 0 表示该位置没有信息或被遮挡。掩码通常用于图像处理和计算机视觉任务中,用于标识感兴趣的区域或对象。

例如,一个简单的数学形式的表示可以是:

  • 对于一个 2D 图像,M(i, j) 表示掩码矩阵中的元素,其中 (i, j) 是矩阵的坐标,M(i, j) 的值为 1 表示该位置包含信息,M(i, j) 的值为 0 表示该位置不包含信息。

掩码通常用于图像分割、遮挡区域检测、图像处理等任务,以便识别和操作图像中的感兴趣区域。在代码中,掩码用二维数组(NumPy数组)来表示,其中元素的值为0或1,这样可以方便地进行图像处理操作。

2.3.2  训练自己的数据集

python3 bin/predict.py model.path=$(pwd)/big-lama indir=$(pwd)/LaMa_test_images outdir=$(pwd)/output

2.3.3  测试自己的数据集  

python3 bin/predict.py model.path=$(pwd)/big-lama indir=$(pwd)/LaMa_test_images outdir=$(pwd)/output

2.3.4  测试结果参数及可视化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1189183.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JAVASSMmysql面向高校校园体育用品租借管理系统94593-计算机毕业设计项目选题推荐(附源码)

摘 要 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;微信小程序的面向高校校园体育用品租借管理系统被用户普遍…

Git->git简介,git的常用命令,git命令的常用理论

git简介git的常用命令git命令的常用理论 1.git简介 Git是什么&#xff1f; Git是一个开源的分布式&#xff0c;用于敏捷高效地处理任何或小或大的项目 Git 是 Linus Torvalds 为了帮助管理 Linux 内核开发而开发的一个开放源码的版本控制软件。 Git 与常用的版本控制工具 CVSI…

在任何机器人上实施 ROS 导航堆栈的指南

文章目录 路径规划参考 路径规划 路径规划是导航的最终目标。这允许用户向机器人给出目标姿势&#xff0c;并让它在给定的环境中自主地从当前位置导航到目标位置。这是我们迄今为止所做的一切&#xff08;地图绘制和本地化&#xff09;的汇集点。ROS 导航堆栈已经为我们完成了…

教培管理系统源码 教育培训机构系统源码 教务系统源码

教培管理系统源码 教育培训机构系统源码 教务系统源码 功能介绍&#xff1a; 教务中心: 学员管理 班级管理 课表管理 教师管理 课程/收费 上课记录 家校互动: 课后作业 课后点评 成绩单 成绩档案 通知管理 营销中心&#xff1a; 活动模板 我的活动 销售中心&am…

双十一数码好物推荐,盘点那些错过等一年的好物!

双十一购物狂欢节马上到来&#xff0c;对于热爱数码产品的人来说&#xff0c;双十一无疑是一个绝佳的时机&#xff0c;因为许多知名品牌和零售商都会推出各种令人心动的数码好物促销活动。从佩戴服饰到大件智能装备&#xff0c;再到健康科技产品&#xff0c;市场上的选择多种多…

竞赛 身份证识别系统 - 图像识别 深度学习

文章目录 0 前言1 实现方法1.1 原理1.1.1 字符定位1.1.2 字符识别1.1.3 深度学习算法介绍1.1.4 模型选择 2 算法流程3 部分关键代码 4 效果展示5 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 毕业设计 图像识别 深度学习 身份证识别…

四川芸鹰蓬飞商务信息咨询有限公司电商服务引领潮流

在今天的数字时代&#xff0c;抖音带货已成为一种新型的、高效的营销方式。许多公司都在寻找可靠的抖音带货服务&#xff0c;以扩大其品牌影响力并增加销售额。在这方面&#xff0c;四川芸鹰蓬飞商务信息咨询有限公司以其专业的知识和经验&#xff0c;成为行业内的佼佼者。 四…

画家尹星,美术界的扫地僧

尹星 简历&#xff1a; 1944年1月25日出生于山西省阳高县&#xff0c;内蒙古师范学院艺术系美术专业&#xff0c;师从水彩之父李剑晨&#xff0c;北京京华美术学院创立者邱石冥&#xff0c;徐坚。与吴冠中&#xff0c;朱德群&#xff0c;赵无极&#xff0c;杨飞云是同门。擅长…

【EI会议征稿】JPCS独立出版-第五届新材料与清洁能源国际学术会议(ICAMCE 2024)

JPCS独立出版-第五届新材料与清洁能源国际学术会议&#xff08;ICAMCE 2024&#xff09; 2024 5th International Conference on Advanced Material and Clean Energy 第五届新材料与清洁能源国际学术会议&#xff08;ICAMCE 2024&#xff09;将于2024年2月23-25日在中国▪长沙…

采集Prestashop独立站采集Prestashop独立站

import java.net.URL 这一行导入了Java.net包中的URL类&#xff0c;这个类在处理URL链接时非常有用。 import org.jsoup.Jsoup 这一行导入了Jsoup库&#xff0c;它是一个强大的HTML和XML文档解析库&#xff0c;我们可以使用它来解析网页内容。 import org.jsoup.nodes.Docume…

安卓数据恢复工具哪个强? 10 个最佳 Android 数据恢复应用程序

如果您是 Android 用户并且已经使用您的设备一段时间&#xff0c;那么您很可能遇到过与数据相关的问题。这可能是由于软件问题导致文件被意外删除或损坏。许多人不经常备份数据&#xff0c;从而丢失了重要的文档、图像、视频文件等。最糟糕的是&#xff0c;数据丢失可能随时发生…

AI智能雷达名片平台版小程序源码系统 带完整的搭建教程

大家好啊&#xff0c;今天源码小编来给大家分享一款AI智能雷达名片平台版小程序源码系统。人工智能技术的不断发展和普及&#xff0c;越来越多的企业开始应用AI技术来提高业务效率和提升用户体验。AI智能雷达名片平台版小程序源码系统就是利用人工智能技术&#xff0c;帮助企业…

WPS的JS宏基础

一、基础知识 1、简单的第一个宏 //注意function只能全部用小写 function demo(){alert("你好!") }2、录制宏生成工资条 function 使用录制宏自动生成代码以JS宏为例()//使用相对引用 {Selection.Copy(undefined);ActiveCell.Offset(5, 0).Range("A1:M4"…

基于springboot实现福聚苑社区团购平台系统项目【项目源码】

基于springboot实现福聚苑社区团购平台系统演示 Javar技术 Java是一种网络脚本语言&#xff0c;广泛运用于web应用开发&#xff0c;可以用来添加网页的格式动态效果&#xff0c;该语言不用进行预编译就直接运行&#xff0c;可以直接嵌入HTML语言中&#xff0c;写成js语言&…

智慧油气推动能源行业的绿色转型和可持续发展

智慧油气推动能源行业的绿色转型和可持续发展 随着技术的不断进步和创新的推动&#xff0c;智慧油气正成为引领能源行业发展的重要趋势。通过融合物联网、云计算、人工智能等先进技术&#xff0c;智慧油气实现了油气资源的高效管理和利用&#xff0c;为能源行业带来了巨大的变革…

Spring Cloud智慧工地管理平台源码,智慧工地APP源码,实现对劳务人员、施工进度、工地安全、材料设备、环境监测等方面的实时监控和管理

智慧工地管理平台源码&#xff0c;智慧工地APP源码&#xff0c; 智慧工地管理平台实现对人员管理、施工进度、安全管理、材料管理、设备管理、环境监测等方面的实时监控和管理&#xff0c;提高施工效率和质量&#xff0c;降低安全风险和环境污染。智慧工地平台支持项目级、公司…

STM32-EXTI中断

EXTI简介 EXTI&#xff08;Extern Interrupt&#xff09;外部中断 EXTI可以监测指定GPIO口的电平信号&#xff0c;当其指定的GPIO口产生电平变化时&#xff0c;EXTI将立即向NVIC发出中断申请&#xff0c;经过NVIC裁决后即可中断CPU主程序&#xff0c;使CPU执行EXTI对应的中断程…

站在创新视角理解美的集团“全球突破”

全球化&#xff0c;对于企业发展的意义毋庸赘言。 作为一家年营收3000多亿的科技集团&#xff0c;美的集团有超过四成收入来自海外市场。 可以预见的是&#xff0c;未来海外市场的重要性还会不断提升。因为国内家电市场正在从增量周期转入存量周期&#xff0c;市场增长趋稳。…

《开箱元宇宙》:认识香港麦当劳通过 The Sandbox McNuggets Land 的 Web3 成功经验

McNuggets Land 是 The Sandbox 于 2023 年发布的最受欢迎的体验之一。在本期的《开箱元宇宙》系列中&#xff0c;我们采访了香港麦当劳数位顾客体验暨合作伙伴资深总监 Kai Tsang&#xff0c;来了解这一成功案例背后的策略。 在不断发展的市场营销和品牌推广领域&#xff0c;不…

每条价格仅1美分,美国军人敏感信息正被低价售卖

杜克大学于11月6日发布的的一项新研究报告表明&#xff0c;网络攻击者可以轻松地从数据经纪人手中&#xff0c;以低廉的价格获取有关美国军人的敏感信息。 数据经纪人收集和汇总信息&#xff0c;然后直接或通过利用数据的服务出售、许可或共享信息。数据经纪人包括 Equifax 和 …