文章目录
- 概要
- 权重下载
- 光流估计
- 结果预览
概要
RIFE,一种用于视频帧插值(VFI)的实时中间流估计算法。许多最近基于流动的VFI方法首先估计双向光流,然后将它们缩放和反转为近似的中间流,从而导致运动边界和复杂管道上的伪影。RIFE使用一个名为IFNet的神经网络,该神经网络可以直接估计从粗到细的中间流量,速度要快得多。我们设计了一种用于训练 IFNet 的特权蒸馏方案,从而大大提高了性能。RIFE 不依赖于预训练的光流模型,并且可以支持使用时间编码输入进行任意时间步长帧插值。实验表明,RIFE 在几个公共基准测试中实现了最先进的性能。与流行的 SuperSlomo 和 DAIN 方法相比,RIFE 快 4–27 倍,产生更好的结果。该代码可在 https://github.com/hzwer/arXiv2020-RIFE 上找到。
权重下载
https://drive.google.com/file/d/147XVsDXBfJPlyct2jfo9kpbL944mNeZr/view?usp=sharing
光流估计
import torch
from torch.nn import functional as F
from model.RIFE import Model
import warnings
warnings.filterwarnings("ignore")
import argparse
import cv2
from utils.flow_viz import flow_to_image
import matplotlib.pyplot as plt
import os
import glob
import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
def viz(img1, img2, flo, outPath):
# img1 = img1[0].permute(1,2,0).cpu().numpy()
# img2 = img2[0].permute(1,2,0).cpu().numpy()
flo = flo[0].permute(1,2,0).cpu().numpy()
# map flow to rgb image
flo = flow_to_image(flo)
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 4))
ax1.set_title('input image1')
ax1.imshow(img1.astype(int))
ax2.set_title('input image2')
ax2.imshow(img2.astype(int))
ax3.set_title('estimated optical flow')
ax3.imshow(flo)
# plt.show()
plt.savefig(outPath, bbox_inches='tight') # 'optical_flow_comparison.png' 可以替换为您想要的文件名
plt.close()
if __name__ == "__main__":
model = Model(arbitrary=True)
model.load_model('RIFE_m_train_log')
model.eval()
model.device()
images = glob.glob(os.path.join("../dataset", '*.png')) + \
glob.glob(os.path.join("../dataset", '*.jpg'))
images = sorted(images)
for i, (imfile1, imfile2) in tqdm.tqdm(enumerate(zip(images[:-1], images[1:]))):
img0_ = cv2.imread(imfile1, cv2.IMREAD_UNCHANGED)
img1_ = cv2.imread(imfile2, cv2.IMREAD_UNCHANGED)
img0 = torch.from_numpy(img0_.copy()).permute(2, 0, 1) / 255.0
img1 = torch.from_numpy(img1_.copy()).permute(2, 0, 1) / 255.0
img = torch.cat((img0, img1), 0).to(torch.float).unsqueeze(0).cuda()
n, c, h, w = img.shape
ph = ((h - 1) // 32 + 1) * 32
pw = ((w - 1) // 32 + 1) * 32
padding = (0, pw - w, 0, ph - h)
img = F.pad(img, padding)
# print(f"img size {img.size()}")
with torch.no_grad():
flow = model.flownet(img, timestep=1.0, returnflow=True)[:, :2] # will get flow1->0
# flow = flow[0].permute(1,2,0).cpu().numpy()
print(f"flow size : {flow.size()}")
viz(img0_, img1_, flow, f"{i}_.png")