Python使用AI animegan2-pytorch制作属于你的漫画头像
- 1. 效果图
- 2. 原理
- 3. 源码
- 参考
git clone https://github.com/bryandlee/animegan2-pytorch
cd ./animegan2-pytorch
python test.py --input_dir samples/inputs --output_dir samples/results
1. 效果图
官方效果图如下:
效果图v2 512模型如下:
效果图v1 512模型如下:
效果图v1 效果不太好如下:
效果图cele(celeba_distill)模型如下:
人物会有一种病态的美,过于白了,风景上效果更好一些;
人物与photo2cartoon的效果图有点像;
效果图paprika 模型如下
人物纹理痕迹太过明显,更适合风景
下一张明兰的效果还不错,不同的模型在不同的图像上也会有些微差别;
origin vs v1Res vs v2Res vs paprikaRes vs celedistillResAll 风景效果对比图如下:
origin vs v1Res vs v2Res vs paprikaRes vs celedistillResAll 人物效果对比图如下:
2. 原理
人像/风景卡通风格渲染的目标是,在保持原图像 ID 信息和纹理细节的同时,将真实照片转换为卡通风格的非真实感图像。
3. 源码
源码及示例文件模型等见资源:https://download.csdn.net/download/qq_40985985/87739198
# animegan2-pytorch 生成漫画头像或者风景图
# python test.py --checkpoint weights/face_paint_512_v2.pt --input_dir samples/faces/ --device cpu --output_dir samples/resv2
# model loaded: weights/face_paint_512_v2.pt
import os
import argparse
from PIL import Image
import numpy as np
import torch
from torchvision.transforms.functional import to_tensor, to_pil_image
from model import Generator
# Global cuDNN configuration, applied once when the script is loaded:
# cuDNN is disabled, its benchmark mode is turned off, and the
# deterministic flag is set.
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def load_image(image_path, x32=False):
    """Open an image file and return it as an RGB PIL image.

    Args:
        image_path: Path of the image file to open.
        x32: When true, resize the image so that each side is a multiple
            of 32; a side shorter than 256 pixels is set to 256.

    Returns:
        The RGB image, resized only when ``x32`` is set.
    """
    picture = Image.open(image_path).convert("RGB")
    if not x32:
        return picture

    # A side below 256 becomes 256; any other side is rounded down to the
    # nearest multiple of 32.
    width, height = picture.size
    target_size = tuple(
        256 if side < 256 else side - side % 32 for side in (width, height)
    )
    return picture.resize(target_size)
def test(args):
    """Run the generator on every supported image in ``args.input_dir``.

    Loads the checkpoint named by ``args.checkpoint`` into a ``Generator``,
    moves it to ``args.device`` in eval mode, and writes one output image
    per input into ``args.output_dir`` under the same file name.

    Args:
        args: Namespace providing ``device``, ``checkpoint``, ``input_dir``,
            ``output_dir``, ``x32`` and ``upsample_align``.
    """
    allowed_suffixes = (".jpg", ".png", ".bmp", ".tiff")
    target = args.device

    generator = Generator()
    # Weights are always deserialized on the CPU first, then moved.
    state = torch.load(args.checkpoint, map_location="cpu")
    generator.load_state_dict(state)
    generator.to(target).eval()
    print(f"model loaded: {args.checkpoint}")

    os.makedirs(args.output_dir, exist_ok=True)

    for file_name in sorted(os.listdir(args.input_dir)):
        suffix = os.path.splitext(file_name)[-1].lower()
        if suffix not in allowed_suffixes:
            continue

        source = load_image(os.path.join(args.input_dir, file_name), args.x32)

        with torch.no_grad():
            # Add a batch dimension and rescale the tensor with `* 2 - 1`.
            batch = to_tensor(source).unsqueeze(0) * 2 - 1
            result = generator(batch.to(target), args.upsample_align).cpu()
            # Clamp to [-1, 1], then map back to [0, 1] before conversion.
            result = result.squeeze(0).clip(-1, 1) * 0.5 + 0.5
            result = to_pil_image(result)

        result.save(os.path.join(args.output_dir, file_name))
        print(f"image saved: {file_name}")
def str2bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string as true,
    so ``--upsample_align False`` used to enable the option. This parser
    interprets the text instead.

    Args:
        value: The raw option value (a string, or already a bool).

    Returns:
        ``True`` for true/t/yes/y/1/on, ``False`` for false/f/no/n/0/off or
        an empty string (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays false, as it was under ``type=bool``.
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


def build_parser():
    """Build the command-line parser for the inference script.

    Returns:
        An ``argparse.ArgumentParser`` with the ``--checkpoint``,
        ``--input_dir``, ``--output_dir``, ``--device``,
        ``--upsample_align`` and ``--x32`` options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--checkpoint',
        type=str,
        default='./weights/paprika.pt',
    )
    parser.add_argument(
        '--input_dir',
        type=str,
        default='./samples/inputs',
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        default='./samples/results',
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda:0',
    )
    parser.add_argument(
        '--upsample_align',
        # Not ``type=bool``: bool("False") is True.
        type=str2bool,
        default=False,
        help="Align corners in decoder upsampling layers"
    )
    parser.add_argument(
        '--x32',
        action="store_true",
        help="Resize images to multiple of 32"
    )
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    test(args)
# 原图VS效果图绘制
# python plot_sample.py
# 获取输入路径的所有图像
import cv2
import imutils
import numpy as np
from imutils import paths
import os

# Directory scanned for images, and the path fragment marking input photos.
SAMPLES_DIR = "samples"
INPUT_TAG = "inputs"
# Path fragments of the result folders, in the order their rows are stacked.
RESULT_TAGS = ("resv1", "resv2", "paprika", "cele")
# Result folders that additionally get their own preview windows.
PREVIEW_TAGS = ("resv1",)
# Only inputs whose path contains this fragment are plotted.
TARGET_FRAGMENT = "ml2.jpg"
# Width in pixels of every "origin | result" row.
ROW_WIDTH = 300
SUMMARY_TITLE = 'origin vs v1Res vs v2Res vs paprikaRes vs celedistillResAll'


def select_paths(image_paths, tag):
    """Return the paths that contain ``tag`` somewhere after position 0."""
    return [p for p in image_paths if p.find(tag) > 0]


def matching_results(input_path, result_paths):
    """Return the result paths whose file name equals the input's file name.

    ``os.path.basename`` is used instead of splitting on a literal
    backslash, so matching works with the platform's own separator.
    """
    wanted = os.path.basename(input_path)
    return [p for p in result_paths if os.path.basename(p) == wanted]


def model_label(result_path):
    """Return the result's parent folder name with "res" removed."""
    folder = os.path.basename(os.path.dirname(result_path))
    return folder.replace("res", "")


def comparison_row(input_path, result_path):
    """Build one "origin | result" row image, ``ROW_WIDTH`` pixels wide.

    Returns:
        The row as an array, or ``None`` when either file cannot be read.
    """
    origin = cv2.imread(input_path)
    res = cv2.imread(result_path)
    # cv2.imread signals an unreadable file by returning None.
    if origin is None or res is None:
        return None
    # The result is resized to the original's size so the two can be
    # placed side by side.
    if origin.shape[:2] != res.shape[:2]:
        res = cv2.resize(res, (origin.shape[1], origin.shape[0]))
    return imutils.resize(np.hstack([origin, res]), width=ROW_WIDTH)


def build_figure(input_path, results_by_tag, preview_tags=PREVIEW_TAGS):
    """Stack one comparison row per matching result image.

    Args:
        input_path: Path of the original photo.
        results_by_tag: Sequence of ``(tag, result_paths)`` pairs, in
            stacking order.
        preview_tags: Tags whose rows are also shown in their own windows.

    Returns:
        The stacked figure, or ``None`` when no result could be used.
    """
    figure = None
    for tag, result_paths in results_by_tag:
        for result_path in matching_results(input_path, result_paths):
            row = comparison_row(input_path, result_path)
            if row is None:
                print(f"skipped unreadable image pair: {input_path} / {result_path}")
                continue
            figure = row if figure is None else np.vstack([figure, row])
            if tag in preview_tags:
                label = model_label(result_path)
                cv2.imshow('origin vs ' + label + 'Res', row)
                cv2.imshow('origin vs ' + label + 'ResAll', figure)
    return figure


def main():
    """Show the original-vs-results comparison for the target input image."""
    image_paths = sorted(paths.list_images(SAMPLES_DIR))
    inputs = select_paths(image_paths, INPUT_TAG)
    print(inputs)
    results_by_tag = [
        (tag, select_paths(image_paths, tag)) for tag in RESULT_TAGS
    ]

    for input_path in inputs:
        if TARGET_FRAGMENT not in input_path:
            continue
        figure = build_figure(input_path, results_by_tag)
        if figure is None:
            # Nothing to display; calling imshow with None would fail.
            print(f"no result images found for: {input_path}")
            continue
        cv2.imshow(SUMMARY_TITLE, figure)
        cv2.waitKey(0)


if __name__ == "__main__":
    main()
参考
- https://alltodata.blog.csdn.net/article/details/125183830
- https://github.com/bryandlee/animegan2-pytorch