CV | Emotionally Enhanced Talking Face Generation论文详解及代码实现

news2024/12/28 5:08:14

 本博客主要讲解了Emotionally Enhanced Talking Face Generation(情感增强的谈话人脸生成)论文概括与项目实现,以及代码理解。

Emotionally Enhanced Talking Face Generation

Paper :https://arxiv.org/pdf/2303.11548.pdf

Code: GitHub - sahilg06/EmoGen: PyTorch Implementation for Paper "Emotionally Enhanced Talking Face Generation"

(克隆项目下载权重后,可直接进行推理)

目录

论文概括

项目实现

1.环境设置

2.数据处理及项目运行

3.开始训练

3.1.训练专家鉴别器

3.2.训练情绪鉴别器

3.3.训练最终模型

4.推理

过程中遇到的问题及解决(PS)

代码详解(按运行顺序)


论文概括

论文创新点

  • 输入视频,任意人脸+情绪合成
  • 提出了一个新的深度学习模型,可以生成照片般逼真的唇语人脸视频,其中包含了不同的情绪和相关表情。
  • 引入了一个多模态框架,以生成与任何任意身份、语言和情感无关的唇语视频。
  • 开发了一个基于网络的响应式界面,用于实时生成带有情绪的对话脸。

模型框架

 

项目实现

1.环境设置

Ubuntu(docker 容器) ,torch-gpu,cuda11.7

克隆项目,

git clone https://github.com/sahilg06/EmoGen
cd EmoGen

 安装相关库

sudo apt-get install ffmpeg
pip install -r requirements.txt
#相关库
pip install albumentations

 配置下载命令工具,安装git lfs

curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash
apt-get install git-lfs

git lfs install

下载CREMA-D数据集命令 (使用git clone会出错)

git lfs clone https://github.com/CheyneyComputerScience/CREMA-D.git

2.数据处理及项目运行

运行

python convertFPS.py -i /workspace/facegan/EmoGen/CREMA-D/VideoFlash -o /workspace/facegan/EmoGen/CREMA-D/flv-output

将视频文件(.flv)每25帧截图保存为flv-output文件夹。(文件夹名自己设置)

如果出错可参考【PS3】

 

 (等大概25分钟左右)

 接着处理MP4文件,截取数据集的人脸部分,且每个人,每个人的心情等划分为不同的文件夹:

运行

python preprocess_crema-d.py --data_root /workspace/facegan/EmoGen/CREMA-D/flv-output --preprocessed_root preprocessed_dataset/

 如果出错可参考【PS4/5】,运行过程如图

 (大约运行3小时)

3.开始训练

主要分为三个步骤:

  1. 训练专家口型同步鉴别器
  2. 训练情绪鉴别器 
  3. 训练 EmoGen 模型

3.1.训练专家鉴别器

(如果只有一个gpu,或者内存不大需要修改相关代码)

相关参数在hparams.py中,可修改batch_size和num_workers(默认为batch_size=16,num_workers=16)

python color_syncnet_train.py --data_root preprocessed_dataset/ --checkpoint_dir sync-checkpoint

 开始训练

运行16小时后epoch210,保存sync-chrakpoint文件夹下

 这里产生的权重文件是SyncNet网络+残差跳跃连接 训练后的权重,

因为设置epoch数比较高,可按下快捷键Ctrl+C 停止训练。

3.2.训练情绪鉴别器

python emotion_disc_train.py -i preprocessed_dataset/ -o emo-checkpoint

 开始训练

 因为设置epoch数比较高,可按下快捷键Ctrl+C 停止训练。

3.3.训练最终模型

python train.py --data_root preprocessed_dataset/ --checkpoint_dir emogen-checkpoint --syncnet_checkpoint_path sync-checkpoint/checkpoint_step000011256.pth --emotion_disc_path emo-checkpoint/disc_emo_23000.pth

- emogen-checkpoint 训练后的权重文件都保存在这个文件夹 

- sync-checkpoint/checkpoint_step000011256.pth 训练鉴别器时的权重文件

- emo-checkpoint/disc_emo_23000.pth训练情绪鉴别器时的权重文件

4.推理

*推理时需要注释掉以下代码:

model/wav2lip.py中的108行和113行,改为

        #emotion = emotion.unsqueeze(1).repeat(1, 5, 1) #(B, T, 6) 

 #emotion = torch.cat([emotion[:, i] for i in range(emotion.size(1))], dim=0) #(B*T, 6)

 然后进行推理:

python inference.py --checkpoint_path emogen-checkpoint/训练后的checkpoint.pth --face 自己的mp4文件 --audio 一个语音文件  --emotion 想要生成的情绪选择
--checkpoint_path 
--face 自己的mp4文件 
--audio 一个语音文件 : *.wav*.mp3甚至是视频文件,代码会自动从中提取音频
--emotion 想要生成的情绪选择 :从列表中选择分类情绪:[HAP、SAD、FEA、ANG、DIS、NEU]

中间省略

在推理时,要在wav2lip模型中,把语音编码(audio_embedding)和情绪编码(emotion_embedding)进行连接(torch.cat)

也可以不进行训练,直接下载checkpoint,下载地址:

python inference.py --checkpoint_path checkpoint.pth --face temp.mp4 --audio temp.wav  --emotion HAP

过程中遇到的问题及解决(PS)

[PS1]docker容器安装ffmpeg失败,出现Err:1 http://security.ubuntu.com/ubuntu focal-updates/main amd64 libwebp6 amd64 0.6.1-2ubuntu0.20.04.1
  404  Not Found [IP: 185.125.190.39 80]
Err:2 http://security.ubuntu.com/ubuntu focal-updates/main amd64 libwebpmux3 amd64 0.6.1-2ubuntu0.20.04.1
  404  Not Found [IP: 185.125.190.39 80]

 原因分析

linux服务器上ffmpeg版本为4.2.7,且没问题

docker容器安装辅助项

apt-get install yasm
apt-get install libx264-dev
apt-get install libfdk-aac-dev
apt-get install libmp3lame-dev
apt-get install libopus-dev
apt-get install libvpx-dev

apt-get update
apt install ffmpeg

成功后

数据集Flash 样本

【PS2】TypeError: makedirs() got an unexpected keyword argument 'exist_ok'

解决方法:删掉exist_ok=True

【PS3】/workspace/facegan/EmoGen/CREMA-D/VideoFlash/1018_MTI_DIS_XX.flv: Invalid data found when processing input

解决方法:是下载数据时,文件出现问题,重新下载数据后正常。

【PS4】AttributeError: partially initialized module 'cv2' has no attribute 'gapi_wip_gst_GStreamerPipeline' (most likely due to a circular import)

查看opencv_python的版本,是4.7.0.72

把版本降级

pip install opencv-python==4.3.0.36

 【PS5】ImportError: libSM.so.6: cannot open shared object file: No such file or directory

 因为我是docker容器,所以要下载

pip install opencv-python-headless==4.3.0.36

 [PS6]RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 64 but got size 320 for tensor number 1 in the list.

代码详解(按运行顺序)

convertFPS.py

import argparse
import os
import subprocess

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__)
    #定义参数 -i 是 --input-folder的缩写,运行时添加数据集的路径
    parser.add_argument("-i", "--input-folder", type=str, help='Path to folder that contains video files')
    #定义参数 -fps 是指画面每秒传输帧数,即动画或视频每秒切换的图片张数,帧数越大,流畅度越高,
    parser.add_argument("-fps", type=float, help='Target FPS', default=25.0)
    #定义参数 -o 是 --output-folder的缩写,运行时添加数据处理后的路径
    parser.add_argument("-o", "--output-folder", type=str, help='Path to output folder')
    args = parser.parse_args()

     #建立处理后的文件夹名
    #os.makedirs(args.output_folder, exist_ok=True)
    os.makedirs(args.output_folder)
    fileList = []
    #对于数据路径下的(文件格式为:MP4、mpg、mov、flv)的文件,循环切割提取文件名 
    for root, dirnames, filenames in os.walk(args.input_folder):
        for filename in filenames:
            if os.path.splitext(filename)[1] == '.mp4' or os.path.splitext(filename)[1] == '.mpg' or os.path.splitext(filename)[1] == '.mov' or os.path.splitext(filename)[1] == '.flv':
                 #对于所提取的文件进行展平
                fileList.append(os.path.join(root, filename))

    #对于所提取的文件利用ffmpeg库进行视频切片,并存为.MP4文件
     
    for file in fileList:
        subprocess.run("ffmpeg -i {} -r 25 -y {}".format(file, os.path.splitext(file.replace(args.input_folder, args.output_folder))[0]+".mp4"), shell=True)

使用ffmpeg分割视频时,指定开始、结束时间。使用以下命令

ffmpeg -ss [start] -i [input] -to [end] -c copy [output]

参数

参数作用
-ss读取位置
-iffmpeg的必要字段
-t持续时间
-to结束位置
-c编解码器
copy源文件编解码器
[start]开始时间
[end]结束时间
[duration]持续时间
[input]输入文件路径
[output]输出文件路径

-r : 每秒帧数(指定帧率,这样达到视频压缩效果)
注意 :-ss 要放在 -i 之前
preprocess_crema-d.py

import sys

if sys.version_info[0] < 3 and sys.version_info[1] < 2:
	raise Exception("Must be using >= Python 3.2")

from os import listdir, path

if not path.isfile('face_detection/detection/sfd/s3fd.pth'):
	raise FileNotFoundError('Save the s3fd model to face_detection/detection/sfd/s3fd.pth \
							before running this script!')

import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import argparse, os, cv2, traceback, subprocess
from tqdm import tqdm
from glob import glob
import audio
from hparams import hparams as hp

import face_detection

parser = argparse.ArgumentParser()
# gpu的数量
parser.add_argument('--ngpu', help='Number of GPUs across which to run in parallel', default=1, type=int)
# 单一gpu人脸检测的批量大小,默认32
parser.add_argument('--batch_size', help='Single GPU Face detection batch size', default=32, type=int)
# 数据集地址
parser.add_argument("--data_root", help="Root folder of the LRS2 dataset", required=True)
# 处理后的数据集地址
parser.add_argument("--preprocessed_root", help="Root folder of the preprocessed dataset", required=True)

args = parser.parse_args()
#识别视频数据集中的人脸
fa = [face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, 
									device='cuda:{}'.format(id)) for id in range(args.ngpu)]

#识别人脸后利用ffmpeg库处理
template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'
# template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}'

def process_video_file(vfile, args, gpu_id):
	video_stream = cv2.VideoCapture(vfile)
	
	frames = []
	while 1:
		still_reading, frame = video_stream.read()
		if not still_reading:
			video_stream.release()
			break
		frames.append(frame)
	
	vidname = os.path.basename(vfile).split('.')[0]
	#dirname = vfile.split('/')[-2]

	fulldir = path.join(args.preprocessed_root, vidname)
	os.makedirs(fulldir, exist_ok=True)

	batches = [frames[i:i + args.batch_size] for i in range(0, len(frames), args.batch_size)]

	i = -1
	for fb in batches:
		preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb))

		for j, f in enumerate(preds):
			i += 1
			if f is None:
				continue
            # 截取的人脸保存四点为一张照片
			x1, y1, x2, y2 = f
			cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2, x1:x2])

def process_audio_file(vfile, args):
	vidname = os.path.basename(vfile).split('.')[0]
	#dirname = vfile.split('/')[-2]

	fulldir = path.join(args.preprocessed_root, vidname)
	os.makedirs(fulldir, exist_ok=True)

	wavpath = path.join(fulldir, 'audio.wav')

	command = template.format(vfile, wavpath)
	subprocess.call(command, shell=True)

	
def mp_handler(job):
	vfile, args, gpu_id = job
	try:
		process_video_file(vfile, args, gpu_id)
	except KeyboardInterrupt:
		exit(0)
	except:
		traceback.print_exc()
		
def main(args):
	print('Started processing for {} with {} GPUs'.format(args.data_root, args.ngpu))
    
    #
	filelist = glob(path.join(args.data_root, '*.mp4'))

	jobs = [(vfile, args, i%args.ngpu) for i, vfile in enumerate(filelist)]
	p = ThreadPoolExecutor(args.ngpu)
	futures = [p.submit(mp_handler, j) for j in jobs]
	_ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))]

	print('Dumping audios...')

	for vfile in tqdm(filelist):
		try:
			process_audio_file(vfile, args)
		except KeyboardInterrupt:
			exit(0)
		except:
			traceback.print_exc()
			continue

if __name__ == '__main__':
	main(args)

 color_syncnet_train.py

from os.path import dirname, join, basename, isfile, isdir
from tqdm import tqdm

from models import SyncNet_color as SyncNet
import audio

import torch
from torch import nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np

from glob import glob

import os, random, cv2, argparse
import albumentations as A
from hparams import hparams, get_image_list

parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator')

# 数据集路径
parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True)

parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
parser.add_argument('--checkpoint_path', help='Resumed from this checkpoint', default=None, type=str)

args = parser.parse_args()


global_step = 0
global_epoch = 0
os.environ['CUDA_VISIBLE_DEVICES']='2'
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))

syncnet_T = 5
emonet_T = 5
syncnet_mel_step_size = 16

class Dataset(object):
    def __init__(self, split):
        #self.all_videos = get_image_list(args.data_root, split)
        self.all_videos = [join(args.data_root, f) for f in os.listdir(args.data_root) if isdir(join(args.data_root, f))]
        print('Num files: ', len(self.all_videos))

        # to apply same augmentation for all the frames
        target = {}
        for i in range(1, emonet_T):
            target['image' + str(i)] = 'image'
        
        self.augments = A.Compose([
                        A.RandomBrightnessContrast(p=0.2),    
                        A.RandomGamma(p=0.2),    
                        A.CLAHE(p=0.2),
                        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=50, val_shift_limit=50, p=0.2),  
                        A.ChannelShuffle(p=0.2), 
                        A.RGBShift(p=0.2),
                        A.RandomBrightness(p=0.2),
                        A.RandomContrast(p=0.2),
                        A.GaussNoise(var_limit=(10.0, 50.0), p=0.25),
                    ], additional_targets=target, p=0.8)
    
    def augmentVideo(self, video):
        args = {}
        args['image'] = video[0, :, :, :]
        for i in range(1, emonet_T):
            args['image' + str(i)] = video[i, :, :, :]
        result = self.augments(**args)
        video[0, :, :, :] = result['image']
        for i in range(1, emonet_T):
            video[i, :, :, :] = result['image' + str(i)]
        return video

    def get_frame_id(self, frame):
        return int(basename(frame).split('.')[0])

    def get_window(self, start_frame):
        start_id = self.get_frame_id(start_frame)
        vidname = dirname(start_frame)

        window_fnames = []
        for frame_id in range(start_id, start_id + syncnet_T):
            frame = join(vidname, '{}.jpg'.format(frame_id))
            if not isfile(frame):
                return None
            window_fnames.append(frame)
        return window_fnames

    def crop_audio_window(self, spec, start_frame):
        # num_frames = (T x hop_size * fps) / sample_rate
        start_frame_num = self.get_frame_id(start_frame)
        start_idx = int(80. * (start_frame_num / float(hparams.fps)))

        end_idx = start_idx + syncnet_mel_step_size

        return spec[start_idx : end_idx, :]


    def __len__(self):
        return len(self.all_videos)

    def __getitem__(self, idx):
        while 1:
            idx = random.randint(0, len(self.all_videos) - 1)
            vidname = self.all_videos[idx]
            #print(vidname)

            img_names = list(glob(join(vidname, '*.jpg')))
            if len(img_names) <= 3 * syncnet_T:
                continue
            img_name = random.choice(img_names)
            wrong_img_name = random.choice(img_names)
            while wrong_img_name == img_name:
                wrong_img_name = random.choice(img_names)

            if random.choice([True, False]):
                y = torch.ones(1).float()
                chosen = img_name
            else:
                y = torch.zeros(1).float()
                chosen = wrong_img_name

            window_fnames = self.get_window(chosen)
            if window_fnames is None:
                continue

            window = []
            all_read = True
            for fname in window_fnames:
                img = cv2.imread(fname)
                if img is None:
                    all_read = False
                    break
                try:
                    img = cv2.resize(img, (hparams.img_size, hparams.img_size))
                except Exception as e:
                    all_read = False
                    break

                window.append(img)

            if not all_read: continue

            try:
                wavpath = join(vidname, "audio.wav")
                wav = audio.load_wav(wavpath, hparams.sample_rate)

                orig_mel = audio.melspectrogram(wav).T
            except Exception as e:
                continue

            mel = self.crop_audio_window(orig_mel.copy(), img_name)

            if (mel.shape[0] != syncnet_mel_step_size):
                continue

            # H x W x 3 * T
            window = np.asarray(window)
            aug_results = self.augmentVideo(window)
            window = np.split(aug_results, syncnet_T, axis=0)

            x = np.concatenate(window, axis=3) / 255.
            x = np.squeeze(x, axis=0).transpose(2, 0, 1)
            # print(x.shape)
            x = x[:, x.shape[1]//2:]

            x = torch.FloatTensor(x)
            mel = torch.FloatTensor(mel.T).unsqueeze(0)

            return x, mel, y

logloss = nn.BCELoss()
def cosine_loss(a, v, y):
    d = nn.functional.cosine_similarity(a, v)
    loss = logloss(d.unsqueeze(1), y)

    return loss

def train(device, model, train_data_loader, test_data_loader, optimizer,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    
    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=6)

    global global_step, global_epoch
    resumed_step = global_step
    num_batches = len(train_data_loader)
    
    while global_epoch < nepochs:
        print('Epoch: {}'.format(global_epoch))
        running_loss = 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, mel, y) in prog_bar:
            model.train()
            optimizer.zero_grad()
            
            # Transform data to CUDA device
            x = x.to(device)
            mel = mel.to(device)

            a, v = model(mel, x)
            y = y.to(device)

            loss = cosine_loss(a, v, y)
            loss.backward()
            optimizer.step()

            global_step += 1
            cur_session_steps = global_step - resumed_step
            running_loss += loss.item()

            # if global_step == 1 or global_step % checkpoint_interval == 0:
            #     save_checkpoint(
            #         model, optimizer, global_step, checkpoint_dir, global_epoch)

            # if global_step % hparams.syncnet_eval_interval == 0:
            #     with torch.no_grad():
            #         eval_loss = eval_model(test_data_loader, global_step, device, model, checkpoint_dir)

            prog_bar.set_description('Loss: {}'.format(running_loss / (step + 1)))

        writer.add_scalar("Loss/train", running_loss/num_batches, global_epoch)

        with torch.no_grad():
            eval_loss = eval_model(test_data_loader, global_step, device, model, checkpoint_dir)
            if(global_epoch % 50 == 0):
                save_checkpoint(model, optimizer, global_step, checkpoint_dir, global_epoch)

        global_epoch += 1

def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
    eval_steps = 1400
    print('Evaluating for {} steps'.format(eval_steps))
    losses = []
    while 1:
        for step, (x, mel, y) in enumerate(test_data_loader):

            model.eval()

            # Transform data to CUDA device
            x = x.to(device)

            mel = mel.to(device)

            a, v = model(mel, x)
            y = y.to(device)

            loss = cosine_loss(a, v, y)
            losses.append(loss.item())

            if step > eval_steps: break

        averaged_loss = sum(losses) / len(losses)
        print(averaged_loss)
        writer.add_scalar("Loss/val", averaged_loss, global_step)

        return averaged_loss

def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):

    checkpoint_path = join(
        checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
    optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
    torch.save({
        "state_dict": model.state_dict(),
        "optimizer": optimizer_state,
        "global_step": step,
        "global_epoch": epoch,
    }, checkpoint_path)
    print("Saved checkpoint:", checkpoint_path)

def _load(checkpoint_path):
    if use_cuda:
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint

def load_checkpoint(path, model, optimizer, reset_optimizer=False):
    global global_step
    global global_epoch

    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    model.load_state_dict(checkpoint["state_dict"])
    if not reset_optimizer:
        optimizer_state = checkpoint["optimizer"]
        if optimizer_state is not None:
            print("Load optimizer state from {}".format(path))
            optimizer.load_state_dict(checkpoint["optimizer"])
    global_step = checkpoint["global_step"]
    global_epoch = checkpoint["global_epoch"]

    return model

if __name__ == "__main__":
    checkpoint_dir = args.checkpoint_dir
    checkpoint_path = args.checkpoint_path

    if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir)

    # Dataset and Dataloader setup
    #train_dataset = Dataset('train')
    #test_dataset = Dataset('val')

    full_dataset = Dataset('train')
    train_size = int(0.95 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42))

    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=hparams.syncnet_batch_size, shuffle=True,
        num_workers=hparams.num_workers)

    test_data_loader = data_utils.DataLoader(
        test_dataset, batch_size=hparams.syncnet_batch_size,
        num_workers=8)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Model
    model = SyncNet().to(device)
    #model = nn.DataParallel(SyncNet(), device_ids=[1,2]).to(device)

    print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.syncnet_lr,betas=(0.5,0.999))

    if checkpoint_path is not None:
        load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False)

    writer = SummaryWriter('runs/crema-d_disc_exp2_data_aug')

    train(device, model, train_data_loader, test_data_loader, optimizer,
          checkpoint_dir=checkpoint_dir,
          checkpoint_interval=hparams.syncnet_checkpoint_interval,
          nepochs=hparams.nepochs)

    writer.flush()

以Syncnet网络为基础,训练一个鉴别器,关于Syncnet,详细可查看

emotion_disc_train.py

import argparse
import json
import os
from tqdm import tqdm
import random as rn
import shutil

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score
from torch.utils.tensorboard import SummaryWriter

from models import emo_disc
from datagen_aug import Dataset

def initParams():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("-i", "--in-path", type=str, help="Input folder containing train data", default=None, required=True)
    # parser.add_argument("-v", "--val-path", type=str, help="Input folder containing validation data", default=None, required=True)
    parser.add_argument("-o", "--out-path", type=str, help="output folder", default='../models/def', required=True)

    parser.add_argument('--num_epochs', type=int, default=10000)
    parser.add_argument("--batch-size", type=int, default=64)

    parser.add_argument('--lr_emo', type=float, default=1e-06)

    parser.add_argument("--gpu-no", type=str, help="select gpu", default='1')
    parser.add_argument('--seed', type=int, default=9)

    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_no

    args.batch_size = args.batch_size * max(int(torch.cuda.device_count()), 1)
    args.steplr = 200

    args.filters = [64, 128, 256, 512, 512]
    #-----------------------------------------#
    #           Reproducible results          #
    #-----------------------------------------#
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    rn.seed(args.seed)
    torch.manual_seed(args.seed)
    #-----------------------------------------#
   
    if not os.path.exists(args.out_path):
        os.makedirs(args.out_path)
    else:
        shutil.rmtree(args.out_path)
        os.mkdir(args.out_path)

    with open(os.path.join(args.out_path, 'args.txt'), 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    args.cuda = torch.cuda.is_available() 
    print('Cuda device available: ', args.cuda)
    args.device = torch.device("cuda" if args.cuda else "cpu") 
    args.kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}

    return args

def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d or type(m) == nn.Conv1d:
        torch.nn.init.xavier_uniform_(m.weight)

def enableGrad(model, requires_grad):
    for p in model.parameters():
        p.requires_grad_(requires_grad)
  

def train():
    args = initParams()
    
    trainDset = Dataset(args)

    train_loader = torch.utils.data.DataLoader(trainDset,
                                               batch_size=args.batch_size, 
                                               shuffle=True,
                                               drop_last=True,
                                               **args.kwargs)
    
    device_ids = list(range(torch.cuda.device_count()))
    
    disc_emo = emo_disc.DISCEMO().to(args.device)
    disc_emo.apply(init_weights)
    #disc_emo = nn.DataParallel(disc_emo, device_ids)

    emo_loss_disc = nn.CrossEntropyLoss()

    num_batches = len(train_loader)
    print(args.batch_size, num_batches)

    global_step = 0
    
    for epoch in range(args.num_epochs):
        print('Epoch: {}'.format(epoch))
        prog_bar = tqdm(enumerate(train_loader))
        running_loss = 0.
        for step, (x, y) in prog_bar:
            video, emotion = x.to(args.device), y.to(args.device)

            disc_emo.train()

            disc_emo.opt.zero_grad() # .module is because of nn.DataParallel 

            class_real = disc_emo(video)

            loss = emo_loss_disc(class_real, torch.argmax(emotion, dim=1))

            running_loss += loss.item()

            loss.backward()
            disc_emo.opt.step() # .module is because of nn.DataParallel 
            
            #每隔1000打印并保存权重文件
            if global_step % 1000 == 0:
                print('Saving the network')
                torch.save(disc_emo.state_dict(), os.path.join(args.out_path, f'disc_emo_{global_step}.pth'))
                print('Network has been saved')
            
            prog_bar.set_description('classification Loss: {}'.format(running_loss / (step + 1)))

            global_step += 1

        writer.add_scalar("classification Loss", running_loss/num_batches, epoch)
        
        disc_emo.scheduler.step() # .module is because of nn.DataParallel 

if __name__ == "__main__":

    writer = SummaryWriter('runs/emo_disc_exp4')
    train()

下载CREMA-D数据集命令

git clone https://github.com/CheyneyComputerScience/CREMA-D
#文件会出错,MP3文件,wav文件,flv文件等克隆后为二进制文件

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/610063.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ROS:服务数据(srv)的定义与使用

目录 一、服务模型二、创建功能包三、自定义服务数据3.1定义srv文件3.2在package.xml中添加功能包依赖3.3在CMakeLists.txt中添加编译选项3.4编译生成语言相关文件 四、创建代码并编译运行&#xff08;C&#xff09;4.1创建代码4.2编译4.3运行 一、服务模型 Client发布显示某个…

价值8800元SEO自动化养权重流量站课程分享(升级版)!

本来想做培训收8800&#xff0c;但是我怕大伙骂我&#xff08;说我割韭菜&#xff09;&#xff0c;所以我决定免费把这套自动化批量养站的技术和流程详细给大家分享出来。有些朋友可能是手动养&#xff0c;我觉得这种思路是没错的&#xff0c;但是有点鸡肋&#xff0c;先说下缺…

电子科技大学计算机系统结构复习笔记(三):流水线技术

目录 前言 重点一览 流水线定义 基本概念 流水线分类 流水线特点 流水线时空图 流水线性能分析 流水线特点 经典5段流水线RISC处理器 流水线的三种冒险 冒险分类 停顿流水线 结构冒险 数据冒险 控制冒险 流水线处理机的指令系统 流水线指令系统与格式 流水…

nvm安装并配置环境变量使用nvm安装、切换nodejs

目录 第一章 准备工作 1.1 卸载nodejs 1.2 安装nvm 第二章 nvm环境配置 第三章 nodejs安装以及环境配置 3.1 会用nvm常用命令 3.2 nodejs安装 3.3 node环境配置 3.4 遇到的问题 第一章 准备工作 1.1 卸载nodejs 找到自己对应的nodejs文件所在路径 where node 通过控…

Python 异常类型捕获( try ... except 用法浅析)——Don‘t bare except (不要让 except 裸奔)

不要让 except 裸奔&#xff01;裸奔很爽&#xff0c;但有隐忧。 (本笔记适合学完 Python 五大基本数据类型&#xff0c;有了些 Python 基础的 coder 翻阅) 【学习的细节是欢悦的历程】 Python 官网&#xff1a;https://www.python.org/ Free&#xff1a;大咖免费“圣经”教程…

大模型时代的来临:AI如何改变人类生活和经济?

大模型时代的来临&#xff1a;AI如何改变人类生活和经济&#xff1f; 第三次AI浪潮之下&#xff0c;人类面临着前所未有的机遇和挑战。随着人工智能的快速发展&#xff0c;我们开始高度重视其可能带来的负面影响。 最近&#xff0c;AI领域再次引起了全球范围内的关注&#xff0…

机器视觉_HALCON_图像采集接口编程手册_1.第一章节介绍

文章目录 一、前言二、图像采集接口编程第一章2.1 HALCON的通用图像采集接口2.2 图像采集基础2.3 同步抓取 vs. 异步抓取⭐2.4 缓冲策略⭐2.5 A/D转换和多路复用2.6 HALCON图像采集算子⭐2.6.1 open_framgrabber2.6.2 close_framegrabber2.6.3 info_framegrabber2.6.4 grab_ima…

chatgpt赋能python:Python均值函数介绍

Python均值函数介绍 Python是一种高级编程语言&#xff0c;非常适合数据处理和分析。在数据分析中&#xff0c;均值通常被用来代表一组数据的平均水平。Python提供了多种方式来计算均值&#xff0c;其中最常用的是使用均值函数来计算。 什么是均值函数&#xff1f; 均值函数…

高通 Camera HAL3:添加一个VendorTag

一.概述 MetadataTag在CamX中有两种体现&#xff0c;可以是预定义的AndroidTag或是自定义VendorTag VendorTag在HAL中定义&#xff0c;用来支持Camx和Chi所需的额外metadata VendorTag类型有三种&#xff1a; hwVendorTagInfocomponentvendortaginfocoreVendorTagInfo 根据不…

「HTML和CSS入门指南」img 标签详解

什么是 img 标签? 在 HTML 中,img 标签用于插入图像。它是一个独立的标签,没有结束标记,并且可以设置多种属性来改变图片的大小、位置、样式等。使用 img 标记可以帮助您更好地展示您的内容,并让浏览器更快地加载网页。 img 标签的基本语法 以下是 img 标签的基本语法: …

卡尔曼滤波与组合导航原理笔记(一)卡尔曼滤波方程的推导 第二部分

文章目录 三、卡尔曼滤波1、随机系统状态空间模型2、状态预测3、状态量测4、增益矩阵K与状态估计5、Kalman滤波公式汇总6、Kalman滤波流程图1.划分为左右两部分&#xff08;一阶矩和二阶矩&#xff09;2.划分为上下两部分&#xff08;时间更新、量测更新&#xff09; 7、Kalman…

ESP8266开发阶段无线WIFI本地烧录升级 -- FOTA

【本文发布于https://blog.csdn.net/Stack_/article/details/130448713&#xff0c;未经允许不得转载&#xff0c;转载须注明出处】 前言 因为正在DIY一个WiFi计量插座&#xff0c;采用非隔离的方案&#xff0c;烧录时要拔掉220V插头&#xff0c;测试时要拔掉USB线&#xff0c;…

php获取文件的权限信息(获取权限信息、返回字符串涵义、二进制的转换方式、权限修改)

php获取文件的权限信息 说明1.获取文件的权限信息2.返回文件权限字符的解读3.转为二进制权限4.修改权限 说明 &#xff08;图片来源于网络&#xff09; 文件权限是指文件或目录对用户和其他进程的访问许可。在 Unix 和 Linux 系统中&#xff0c;文件和目录都有三个权限&#x…

高通 Camera HAL3:CamX、Chi-CDK 详解

网上关于高通CameraHAL3的介绍文档不多&#xff0c;之前做高通CameraHAL3的一些收集、总结&#xff0c;杂乱了一点&#xff0c;将就着看吧。 一.初步认知 高通CameraHAL3的架构很庞大&#xff0c;代码量也很巨大。 先对CamX、Chi-CDK的关键术语、目录等有个初步认知 1.1 术…

Servlet与Mybatis-2

过滤器 过滤器是一种代码重用的技术&#xff0c;它可以改变 HTTP 请求的内容&#xff0c;响应&#xff0c;及 header 信息。过滤器通常不产生响应或像 servlet 那样对请求作出响应&#xff0c;而是修改或调整到资源的请求&#xff0c;修改或调整来自资源的响应。 作用&#x…

Linux基础篇 使用SSH远程Ubuntu-03

目录 1.安装ssh服务器 2.启用SSH服务器 3.查看SSH服务运行状态 4.在Windows的CMD下进行验证 在默认情况下&#xff0c;外部设备是无法通过SSH远程Ubuntu的&#xff0c;因为Ubuntu没有启用ssh服务。 说明&#xff1a;当前Ubuntu系统为20.04 1.安装ssh服务器 sudo apt-get …

chatgpt赋能python:Python在一组数据中抽取数的方法

Python在一组数据中抽取数的方法 Python是一种非常流行的编程语言&#xff0c;因为它简单易学&#xff0c;可读性高&#xff0c;功能强大&#xff0c;适用于各种不同的应用场景。在数据科学领域&#xff0c;Python也非常受欢迎&#xff0c;因为它拥有广泛的数据处理和分析库。…

【Go LeetDay】总目录(1~88)

Leetcode Golang Day1~10 Golang每日一练(leetDay0001) 1. 两数之和 Two Sum 2. 两数相加 Add Two Numbers 3. 无重复字符的最长子串 Longest-substring-without-repeating-characters Golang每日一练(leetDay0002) 4. 寻找两个正序数组的中位数 Median of two sorted arra…

使用RP2040自制的树莓派pico—— [1/100] 烧录micropython固件

目录 开发环境烧录模式简介固件下载固件烧录验证阶段micropython初步了解 开发环境 软件&#xff1a;Thonny 烧录固件&#xff1a;micropython 烧录模式简介 正常插上电就启动&#xff0c;这是树莓派pico开发板的正常启动模式。 如果按住 bootset 按键再插上数据线&#xf…

Vue 设计模式

一、什么是设计模式&#xff1f; 设计模式是一套被反复使用、多数人知晓、经过分类编目的、代码设计经验的总结。它是为了可重用代码&#xff0c;让代码更容易的被他人理解并保证代码的可靠性。 设计模式实际上是“拿来主义”在软件领域的贯彻实践&#xff0c;它是一套现成的工…