论文复现丨基于ModelArts进行图像风格化绘画

news2024/11/17 6:36:13
摘要:这个 notebook 基于论文「Stylized Neural Painting, arXiv:2011.08114.」提供了最基本的「图片生成绘画」变换的可复现例子。

本文分享自华为云社区《基于ModelArts进行图像风格化绘画》,作者: HWCloudAI 。

项目首页 | GitHub | 论文

ModelArts 项目地址:AI Gallery_Notebook详情_开发者_华为云

下载代码和模型

import os
import moxing as mox
mox.file.copy('obs://obs-aigallery-zc/clf/code/stylized-neural-painting.zip','stylized-neural-painting.zip')
os.system('unzip stylized-neural-painting.zip')
cd stylized-neural-painting
import argparse

import torch
torch.cuda.current_device()
import torch.optim as optim

from painter import *
# 检测运行设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 配置
parser = argparse.ArgumentParser(description='STYLIZED NEURAL PAINTING')
args = parser.parse_args(args=[])
args.img_path = './test_images/sunflowers.jpg' # 输入图片路径
args.renderer = 'oilpaintbrush' # 渲染器(水彩、马克笔、油画笔刷、矩形) [watercolor, markerpen, oilpaintbrush, rectangle]
args.canvas_color = 'black' # 画布底色 [black, white]
args.canvas_size = 512 # 画布渲染尺寸,单位像素
args.max_m_strokes = 500 # 最大笔划数量
args.m_grid = 5 # 将图片分割为 m_grid x m_grid 的尺寸
args.beta_L1 = 1.0 # L1 loss 权重
args.with_ot_loss = False # 设为 True 以通过 optimal transportation loss 提高收敛。但会降低生成速度
args.beta_ot = 0.1 # optimal transportation loss 权重
args.net_G = 'zou-fusion-net' # 渲染器架构
args.renderer_checkpoint_dir = './checkpoints_G_oilpaintbrush' # 预训练模型路径
args.lr = 0.005 # 笔划搜寻的学习率
args.output_dir = './output' # 输出路径

Download pretrained neural renderer.

Define a helper funtion to check the drawing status.

def _drawing_step_states(pt):
    acc = pt._compute_acc().item()
    print('iteration step %d, G_loss: %.5f, step_psnr: %.5f, strokes: %d / %d'
          % (pt.step_id, pt.G_loss.item(), acc,
              (pt.anchor_id+1)*pt.m_grid*pt.m_grid,
              pt.max_m_strokes))
    vis2 = utils.patches2img(pt.G_final_pred_canvas, pt.m_grid).clip(min=0, max=1)

定义优化循环

def optimize_x(pt):

    pt._load_checkpoint()
    pt.net_G.eval()

    pt.initialize_params()
    pt.x_ctt.requires_grad = True
    pt.x_color.requires_grad = True
    pt.x_alpha.requires_grad = True
    utils.set_requires_grad(pt.net_G, False)

    pt.optimizer_x = optim.RMSprop([pt.x_ctt, pt.x_color, pt.x_alpha], lr=pt.lr)

    print('begin to draw...')
    pt.step_id = 0
    for pt.anchor_id in range(0, pt.m_strokes_per_block):
        pt.stroke_sampler(pt.anchor_id)
        iters_per_stroke = 20
        if pt.anchor_id == pt.m_strokes_per_block - 1:
            iters_per_stroke = 40
        for i in range(iters_per_stroke):

            pt.optimizer_x.zero_grad()

            pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
            pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
            pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)

            if args.canvas_color == 'white':
                pt.G_pred_canvas = torch.ones([args.m_grid ** 2, 3, 128, 128]).to(device)
            else:
                pt.G_pred_canvas = torch.zeros(args.m_grid ** 2, 3, 128, 128).to(device)

            pt._forward_pass()
            _drawing_step_states(pt)
            pt._backward_x()
            pt.optimizer_x.step()

            pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
            pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
            pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)

            pt.step_id += 1

    v = pt.x.detach().cpu().numpy()
    pt._save_stroke_params(v)
    v_n = pt._normalize_strokes(pt.x)
    pt.final_rendered_images = pt._render_on_grids(v_n)
    pt._save_rendered_images()

处理图片,可能需要一些时间,建议使用 32 GB+ 显存

pt = Painter(args=args)
optimize_x(pt)

Check out your results at args.output_dir. Before you download that folder, let’s first have a look at what the generated painting looks like.

plt.subplot(1,2,1)
plt.imshow(pt.img_), plt.title('input')
plt.subplot(1,2,2)
plt.imshow(pt.final_rendered_images[-1]), plt.title('generated')
plt.show()

请下载 args.output_dir 目录到本地查看高分辨率的生成结果/

# 将渲染进度用动交互画形式展现

import matplotlib.animation as animation
from IPython.display import HTML

fig = plt.figure(figsize=(4,4))
plt.axis('off')
ims = [[plt.imshow(img, animated=True)] for img in pt.final_rendered_images[::10]]

ani = animation.ArtistAnimation(fig, ims, interval=50)

# HTML(ani.to_jshtml())
HTML(ani.to_html5_video())

Next, let’s play style-transfer. Since we frame our stroke prediction under a parameter searching paradigm, our method naturally fits the neural style transfer framework.

接下来,让我们尝试风格迁移,由于我们是在参数搜索范式下构建的笔画预测,因此我们的方法自然的适用于神经风格迁移框架

# 配置
args.content_img_path = './test_images/sunflowers.jpg' # 输入图片的路径(原始的输入图片)
args.style_img_path = './style_images/fire.jpg' # 风格图片路径
args.vector_file = './output/sunflowers_strokes.npz' # 预生成笔划向量文件的路径 
args.transfer_mode = 1 # 风格迁移模式,0:颜色迁移,1:迁移颜色和纹理
args.beta_L1 = 1.0 # L1 loss 权重
args.beta_sty = 0.5 # vgg style loss 权重
args.net_G = 'zou-fusion-net' # 渲染器架构
args.renderer_checkpoint_dir = './checkpoints_G_oilpaintbrush' # 预训练模型路径
args.lr = 0.005 # 笔划搜寻的学习率
args.output_dir = './output' # 输出路径

Again, Let’s define a helper funtion to check the style transfer status.

def _style_transfer_step_states(pt):
      acc = pt._compute_acc().item()
      print('running style transfer... iteration step %d, G_loss: %.5f, step_psnr: %.5f'
            % (pt.step_id, pt.G_loss.item(), acc))
      vis2 = utils.patches2img(pt.G_final_pred_canvas, pt.m_grid).clip(min=0, max=1)

定义优化循环

def optimize_x(pt):

    pt._load_checkpoint()
    pt.net_G.eval()

    if args.transfer_mode == 0: # transfer color only
        pt.x_ctt.requires_grad = False
        pt.x_color.requires_grad = True
        pt.x_alpha.requires_grad = False
    else: # transfer both color and texture
        pt.x_ctt.requires_grad = True
        pt.x_color.requires_grad = True
        pt.x_alpha.requires_grad = True

    pt.optimizer_x_sty = optim.RMSprop([pt.x_ctt, pt.x_color, pt.x_alpha], lr=pt.lr)

    iters_per_stroke = 100
    for i in range(iters_per_stroke):
        pt.optimizer_x_sty.zero_grad()

        pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
        pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
        pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)

        if args.canvas_color == 'white':
            pt.G_pred_canvas = torch.ones([pt.m_grid*pt.m_grid, 3, 128, 128]).to(device)
        else:
            pt.G_pred_canvas = torch.zeros(pt.m_grid*pt.m_grid, 3, 128, 128).to(device)

        pt._forward_pass()
        _style_transfer_step_states(pt)
        pt._backward_x_sty()
        pt.optimizer_x_sty.step()

        pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
        pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
        pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)

        pt.step_id += 1

    print('saving style transfer result...')
    v_n = pt._normalize_strokes(pt.x)
    pt.final_rendered_images = pt._render_on_grids(v_n)

    file_dir = os.path.join(
        args.output_dir, args.content_img_path.split('/')[-1][:-4])
    plt.imsave(file_dir + '_style_img_' +
               args.style_img_path.split('/')[-1][:-4] + '.png', pt.style_img_)
    plt.imsave(file_dir + '_style_transfer_' +
               args.style_img_path.split('/')[-1][:-4] + '.png', pt.final_rendered_images[-1])

运行风格迁移

pt = NeuralStyleTransfer(args=args)
optimize_x(pt)

高分辨率生成文件保存在 args.output_dir

让我们预览一下输出结果:

plt.subplot(1,3,1)
plt.imshow(pt.img_), plt.title('input')
plt.subplot(1,3,2)
plt.imshow(pt.style_img_), plt.title('style')
plt.subplot(1,3,3)
plt.imshow(pt.final_rendered_images[-1]), plt.title('generated')
plt.show()

点击关注,第一时间了解华为云新鲜技术~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/107720.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

敏捷、分散式的数据治理,该如何实现?

01 数据资产的生产和消费现状 —— 孤岛就在那里! 在大数据时代,企业数据资产的生产和消费,实际现状大概是这样的。 一方面,每个业务部门都产生并存储了大量的数据。这些数据存储在不同的系统中。每个业务部门都是数据的生产者…

安卓逆向-某音乐软件

初学安卓逆向,如有错误请指教 某我音乐9.3.4.4版本,需要资源的请流言(也可以自行去下载) 直接将apk拖入到AndroidKiller里面(该工具自行下载) 首先去除广告 直接全局搜索KEY_EXTRA_AUTH,可以只在.smali文件里面搜索…

【gitlab wiki】git首次上传本地文档操作步骤

1.在gitLab中创建一个项目 2.进入本地电脑中的你要上传文件的文件目录,右击鼠标选择“ Git Bash Here” git命令窗口(本机电脑要安装好git) 3.在远程git项目中,复制出项目http地址。 4.在“ Git Bash Here” git命令窗口输入命令:git clone h…

Java+MYSQL基于ssm的网上出差审批与费用报销管理系统

全新的时代,新的技术推动着公司管理制度的改革,在管理层面加入了先进的科学技术做到了与时俱进,所以企业创建自己的网上出差审批与费用报销系统是迫切需要的。在新时代的背景下,传统管理方式的缺点被暴露出来,传统管理方式的不足的地方有管理及时性不够,下达一个指令以后需要层…

AuthLab权限在线靶场通关记录

AuthLab通关记录 一个在线的权限靶场:https://authlab.digi.ninja/ 靶场内容比较简单,包括了JWT以及一些基本情况的权限Bypass IP Based Auth Bypass 该关卡根据提示有一个ip在192.168.0.100-200范围里可以直接bypass 拦截请求包修改X-Forwarded-For爆…

python网络副业有哪些?以自身经历分享怎么做副业挣钱

网络副业我个人比较看好的是Python,至少我是真实体会到了Python做副业真香,疫情被关在家那段时间也没耽误赚钱,反而比平常赚的还多一点,下图是我疫情期间在家做Python副业收入的部分截图,那会儿我用Excel表格每天记了一…

BEPU物理引擎碰撞系统的架构与设计

前面我们讲解了如何监听物理引擎的碰撞事件, 在物理引擎内核中如何架构与设计碰撞规则,使得物理Entity与周围的物理环境产生碰撞时,如何灵活的控制物理碰撞,本节給大家详细的讲解BEPUphysicsint 物理引擎内部是如何管理与控制碰撞规则的。本文主要讲解3个…

彻底删除的文件如何恢复?误删数据恢复,四种方法就可以解决

电脑磁盘中存储了许多文件,我们不可避免地会误删一些文件,但是我们中的许多人不知道在文件被错误删除后如何恢复它们。事实上,误删数据恢复没有想象中那么难,我们自己也可以操作完成。到底是什么方法?接下来我们将详细…

Java项目:springboot中小医院信息管理系统

作者主页:源码空间站2022 简介:Java领域优质创作者、Java项目、学习资料、技术互助 文末获取源码 项目介绍 1. 基于SpringBoot的中小医院信息管理系统,实现了部分核心功能。 2. 就诊卡提供了手动和读卡两种方式录入,其中IC读卡器…

Simulink 自动代码生成电机控制:基于霍尔FOC模型和代码生成

目录 霍尔角度估算原理 霍尔角度估算FOC模型和仿真 代码生成软件调试 总结 霍尔角度估算原理 PMSM在定子侧以互差120电角度的位置安装3个霍尔元件Ha, Hb, Hc。当转子转动时,霍尔元件会产生3个相位差120电角度的高低电平信号。霍尔信号会将一个电周期划分为6个扇…

可以快速搭建的免费开源项目:直播带货、富文本笔记、思维导图、声音克隆、消息推送服务、文档协作等等

可以快速搭建的免费开源项目:直播带货、富文本笔记、思维导图、声音克隆、消息推送服务、文档协作等等。 01 Pure Live 一个想让直播回归纯粹的项目,没有礼物、粉丝团、弹窗,只有直播和弹幕。这是国人在GitHub上制作的一个开源的直播系统&am…

一个转行者的自述,大学生做职业规划要趁早

这篇文章写给对自己的职业规划不清晰、想从第一份工作就找准职业方向的应届大学生。 作为21年毕业的职场新人,算不上建议,也不写鸡汤,只是在这里认真分享我走过的弯路。文章略长,预计阅读时间8分钟。 先说一个关于海的小故事 人们…

CSS -- CSS使用过渡(transition)添加动画

CSS 3过渡 过渡(transition)是CSS3中具有颠覆性的特征之一,我们可以在不使用 Flash 动画或JavaScript 的情况下,当元素从一种样式变换为另一种样式时为元素添加效果。 过渡动画: 是从一个状态 渐渐的过渡到另外一个状态 可以让…

js什么是闭包?简单理解

闭包 作用域链和执行上下文 理解闭包前,先引入一个概念,作用域链 用我自己理解的讲:在一段程序中,程序内的变量、函数等都被串在这条链上,当我们使用这些变量、函数时,程序就会在这条链中搜索&#xff0…

【最新】滤器完整性检测各国规定

中国 用于直接接触无菌药液或无菌设备表面的气体的过滤器,应在每批或多批次连续生产结束后对其进行完整性测试。对于其他的应用,可以根据风险评估的结果,制定完整性测试的频率。 ——除菌过滤技术与应用指南 2018 美国 We recommend that …

系统中的安全架构

系统中的安全架构目录概述需求:设计思路实现思路分析1.shiro2.多模块下的安全架构参考资料和推荐阅读Survive by day and develop by night. talk for import biz , show your perfect code,full busy,skip hardness,make a better result,wait for chan…

【网安神器篇】——searchsploit漏洞利用搜索工具

作者名:Demo不是emo 主页面链接:主页传送门 创作初心:舞台再大,你不上台,永远是观众,没人会关心你努不努力,摔的痛不痛,他们只会看你最后站在什么位置,然后羡慕或鄙夷座…

MyBatis是如何初始化的?

摘要:我们知道MyBatis和数据库的交互有两种方式有Java API和Mapper接口两种,所以MyBatis的初始化必然也有两种;那么MyBatis是如何初始化的呢?本文分享自华为云社区《MyBatis详解 - 初始化基本过程》,作者:龙…

golang 协程的实现原理

核心概念 要理解协程的实现, 首先需要了解go中的三个非常重要的概念, 它们分别是G, M和P, 没有看过golang源代码的可能会对它们感到陌生, 这三项是协程最主要的组成部分, 它们在golang的源代码中无处不在. G (goroutine) G是goroutine的头文字, goroutine可以解释为受管理的…

Java+MySQL基于ssm的学生宿舍管理系统

随着我国教育制度的改革,各大高校一直在不断的扩招相对应的学生的数量也在不断的增加。在学生数量增加之后学校后勤人员就需要对后勤部分更加精准的进行管理,其中宿舍管理就是后勤管理中比较重要的一个组成部分。如何能够对学生的宿舍信息进行更加科学合理的管理是当前大多数高…