从头开始构建一个小规模的文生视频模型

news2025/1/10 19:59:34

OpenAI 的 Sora、Stability AI 的 Stable Video Diffusion 以及许多其他已经发布或未来将出现的文本生成视频模型,是继大语言模型 (LLM) 之后 2024 年最流行的 AI 趋势之一。

在这篇博客中,作者将展示如何将从头开始构建一个小规模的文本生成视频模型,涵盖了从理解理论概念、到编写整个架构再到生成最终结果的所有内容。

由于作者没有大算力的 GPU,所以仅编写了小规模架构。以下是在不同处理器上训练模型所需时间的比较。

图片

作者表示,在 CPU 上运行显然需要更长的时间来训练模型。如果你需要快速测试代码中的更改并查看结果,CPU 不是最佳选择。因此建议使用 Colab 或 Kaggle 的 T4 GPU 进行更高效、更快速的训练。

技术交流&资料

技术要学会分享、交流,不建议闭门造车。一个人可以走的很快、一堆人可以走的更远。

成立了算法面试和技术交流群,相关资料、技术交流&答疑,均可加我们的交流群获取,群友已超过2000人,添加时最好的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2040,备注:来自CSDN + 技术交流

构建目标

我们采用了与传统机器学习或深度学习模型类似的方法,即在数据集上进行训练,然后在未见过数据上进行测试。在文本转视频的背景下,假设有一个包含 10 万个狗捡球和猫追老鼠视频的训练数据集,然后训练模型来生成猫捡球或狗追老鼠的视频。

图片

图源:iStock, GettyImages

虽然此类训练数据集在互联网上很容易获得,但所需的算力极高。因此,我们将使用由 Python 代码生成的移动对象视频数据集。同时使用 GAN(生成对抗网络)架构来创建模型,而不是 OpenAI Sora 使用的扩散模型。

我们也尝试使用扩散模型,但内存要求超出了自己的能力。另一方面,GAN 可以更容易、更快地进行训练和测试。

准备条件

我们将使用 OOP(面向对象编程),因此必须对它以及神经网络有基本的了解。此外 GAN(生成对抗网络)的知识不是必需的,因为这里简单介绍它们的架构。

  • OOP:https://www.youtube.com/watch?v=q2SGW2VgwAM

  • 神经网络理论:https://www.youtube.com/watch?v=Jy4wM2X21u0

  • GAN 架构:https://www.youtube.com/watch?v=TpMIssRdhco

  • Python 基础:https://www.youtube.com/watch?v=eWRfhZUzrAc

了解 GAN 架构

什么是 GAN?

生成对抗网络是一种深度学习模型,其中两个神经网络相互竞争:一个从给定的数据集创建新数据(如图像或音乐),另一个则判断数据是真实的还是虚假的。这个过程一直持续到生成的数据与原始数据无法区分。

真实世界应用

  • 生成图像:GAN 根据文本 prompt 创建逼真的图像或修改现有图像,例如增强分辨率或为黑白照片添加颜色。

  • 数据增强:GAN 生成合成数据来训练其他机器学习模型,例如为欺诈检测系统创建欺诈交易数据。

  • 补充缺失信息:GAN 可以填充缺失数据,例如根据地形图生成地下图像以用于能源应用。

  • 生成 3D 模型:GAN 将 2D 图像转换为 3D 模型,在医疗保健等领域非常有用,可用于为手术规划创建逼真的器官图像。

GAN 工作原理

GAN 由两个深度神经网络组成:生成器和判别器。这两个网络在对抗设置中一起训练,其中一个网络生成新数据,另一个网络评估数据是真是假。

图片

GAN 训练示例

让我们以图像到图像的转换为例,解释一下 GAN 模型,重点是修改人脸。

1. 输入图像:输入图像是一张真实的人脸图像。

2. 属性修改:生成器会修改人脸的属性,比如给眼睛加上墨镜。

3. 生成图像:生成器会创建一组添加了太阳镜的图像。

4. 判别器的任务:判别器接收到混合的真实图像(带有太阳镜的人)和生成的图像(添加了太阳镜的人脸)。

5. 评估:判别器尝试区分真实图像和生成图像。

6. 反馈回路:如果判别器正确识别出假图像,生成器会调整其参数以生成更逼真的图像。如果生成器成功欺骗了判别器,判别器会更新其参数以提高检测能力。

通过这一对抗过程,两个网络都在不断改进。生成器越来越善于生成逼真的图像,而判别器则越来越善于识别假图像,直到达到平衡,判别器再也无法区分真实图像和生成的图像。此时,GAN 已成功学会生成逼真的修改图像。

设置背景

我们将使用一系列 Python 库,让我们导入它们。

# Operating System module for interacting with the operating system
import os

# Module for generating random numbers
import random

# Module for numerical operations
import numpy as np

# OpenCV library for image processing
import cv2

# Python Imaging Library for image processing
from PIL import Image, ImageDraw, ImageFont

# PyTorch library for deep learning
import torch

# Dataset class for creating custom datasets in PyTorch
from torch.utils.data import Dataset

# Module for image transformations
import torchvision.transforms as transforms

# Neural network module in PyTorch
import torch.nn as nn

# Optimization algorithms in PyTorch
import torch.optim as optim

# Function for padding sequences in PyTorch
from torch.nn.utils.rnn import pad_sequence

# Function for saving images in PyTorch
from torchvision.utils import save_image

# Module for plotting graphs and images
import matplotlib.pyplot as plt

# Module for displaying rich content in IPython environments
from IPython.display import clear_output, display, HTML

# Module for encoding and decoding binary data to text
import base64

现在我们已经导入了所有的库,下一步就是定义我们的训练数据,用于训练 GAN 架构。

对训练数据进行编码

我们需要至少 10000 个视频作为训练数据。为什么呢?因为我测试了较小数量的视频,结果非常糟糕,几乎没有任何效果。下一个重要问题是:这些视频内容是什么? 我们的训练视频数据集包括一个圆圈以不同方向和不同运动方式移动的视频。让我们来编写代码并生成 10,000 个视频,看看它的效果如何。

# Create a directory named 'training_dataset'
os.makedirs('training_dataset', exist_ok=True)

# Define the number of videos to generate for the dataset
num_videos = 10000

# Define the number of frames per video (1 Second Video)
frames_per_video = 10

# Define the size of each image in the dataset
img_size = (64, 64)

# Define the size of the shapes (Circle)
shape_size = 10 

设置一些基本参数后,接下来我们需要定义训练数据集的文本 prompt,并据此生成训练视频。

# Define text prompts and corresponding movements for circles
prompts_and_movements = [
 ("circle moving down", "circle", "down"), # Move circle downward
 ("circle moving left", "circle", "left"), # Move circle leftward
 ("circle moving right", "circle", "right"), # Move circle rightward
 ("circle moving diagonally up-right", "circle", "diagonal_up_right"), # Move circle diagonally up-right
 ("circle moving diagonally down-left", "circle", "diagonal_down_left"), # Move circle diagonally down-left
 ("circle moving diagonally up-left", "circle", "diagonal_up_left"), # Move circle diagonally up-left
 ("circle moving diagonally down-right", "circle", "diagonal_down_right"), # Move circle diagonally down-right
 ("circle rotating clockwise", "circle", "rotate_clockwise"), # Rotate circle clockwise
 ("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"), # Rotate circle counter-clockwise
 ("circle shrinking", "circle", "shrink"), # Shrink circle
 ("circle expanding", "circle", "expand"), # Expand circle
 ("circle bouncing vertically", "circle", "bounce_vertical"), # Bounce circle vertically
 ("circle bouncing horizontally", "circle", "bounce_horizontal"), # Bounce circle horizontally
 ("circle zigzagging vertically", "circle", "zigzag_vertical"), # Zigzag circle vertically
 ("circle zigzagging horizontally", "circle", "zigzag_horizontal"), # Zigzag circle horizontally
 ("circle moving up-left", "circle", "up_left"), # Move circle up-left
 ("circle moving down-right", "circle", "down_right"), # Move circle down-right
 ("circle moving down-left", "circle", "down_left"), # Move circle down-left
]

我们已经利用这些 prompt 定义了圆的几个运动轨迹。现在,我们需要编写一些数学公式,以便根据 prompt 移动圆。

# Define function with parameters
def create_image_with_moving_shape(size, frame_num, shape, direction):
 
 # Create a new RGB image with specified size and white background
 img = Image.new('RGB', size, color=(255, 255, 255)) 

 # Create a drawing context for the image
 draw = ImageDraw.Draw(img) 

 # Calculate the center coordinates of the image
 center_x, center_y = size[0] // 2, size[1] // 2 

 # Initialize position with center for all movements
 position = (center_x, center_y) 

 # Define a dictionary mapping directions to their respective position adjustments or image transformations
 direction_map = { 
 # Adjust position downwards based on frame number
 "down": (0, frame_num * 5 % size[1]), 
 # Adjust position to the left based on frame number
 "left": (-frame_num * 5 % size[0], 0), 
 # Adjust position to the right based on frame number
 "right": (frame_num * 5 % size[0], 0), 
 # Adjust position diagonally up and to the right
 "diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]), 
 # Adjust position diagonally down and to the left
 "diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]), 
 # Adjust position diagonally up and to the left
 "diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]), 
 # Adjust position diagonally down and to the right
 "diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]), 
 # Rotate the image clockwise based on frame number
 "rotate_clockwise": img.rotate(frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)), 
 # Rotate the image counter-clockwise based on frame number
 "rotate_counter_clockwise": img.rotate(-frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)), 
 # Adjust position for a bouncing effect vertically
 "bounce_vertical": (0, center_y - abs(frame_num * 5 % size[1] - center_y)), 
 # Adjust position for a bouncing effect horizontally
 "bounce_horizontal": (center_x - abs(frame_num * 5 % size[0] - center_x), 0), 
 # Adjust position for a zigzag effect vertically
 "zigzag_vertical": (0, center_y - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, center_y + frame_num * 5 % size[1]), 
 # Adjust position for a zigzag effect horizontally
 "zigzag_horizontal": (center_x - frame_num * 5 % size[0], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0], center_y), 
 # Adjust position upwards and to the right based on frame number
 "up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]), 
 # Adjust position upwards and to the left based on frame number
 "up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]), 
 # Adjust position downwards and to the right based on frame number
 "down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]), 
 # Adjust position downwards and to the left based on frame number
 "down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]) 
 }

 # Check if direction is in the direction map
 if direction in direction_map: 
 # Check if the direction maps to a position adjustment
 if isinstance(direction_map[direction], tuple): 
 # Update position based on the adjustment
 position = tuple(np.add(position, direction_map[direction])) 
 else: # If the direction maps to an image transformation
 # Update the image based on the transformation
 img = direction_map[direction] 

 # Return the image as a numpy array
 return np.array(img)

上述函数用于根据所选方向在每一帧中移动我们的圆。我们只需在其上运行一个循环,直至生成所有视频的次数。

# Iterate over the number of videos to generate
for i in range(num_videos):
    # Randomly choose a prompt and movement from the predefined list
    prompt, shape, direction = random.choice(prompts_and_movements)
    
    # Create a directory for the current video
    video_dir = f'training_dataset/video_{i}'
    os.makedirs(video_dir, exist_ok=True)
    
    # Write the chosen prompt to a text file in the video directory
    with open(f'{video_dir}/prompt.txt', 'w') as f:
        f.write(prompt)
    
    # Generate frames for the current video
    for frame_num in range(frames_per_video):
        # Create an image with a moving shape based on the current frame number, shape, and direction
        img = create_image_with_moving_shape(img_size, frame_num, shape, direction)
        
        # Save the generated image as a PNG file in the video directory
        cv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)

运行上述代码后,就会生成整个训练数据集。以下是训练数据集文件的结构。

图片

每个训练视频文件夹包含其帧以及对应的文本 prompt。让我们看一下我们的训练数据集样本。

在我们的训练数据集中,我们没有包含圆圈先向上移动然后向右移动的运动。我们将使用这个作为测试 prompt,来评估我们训练的模型在未见过的数据上的表现。

图片

还有一个重要的要点需要注意,我们的训练数据包含许多物体从场景中移出或部分出现在摄像机前方的样本,类似于我们在 OpenAI Sora 演示视频中观察到的情况。

图片

在我们的训练数据中包含此类样本的原因是为了测试当圆圈从角落进入场景时,模型是否能够保持一致性而不会破坏其形状。

现在我们的训练数据已经生成,需要将训练视频转换为张量,这是 PyTorch 等深度学习框架中使用的主要数据类型。此外,通过将数据缩放到较小的范围,执行归一化等转换有助于提高训练架构的收敛性和稳定性。

预处理训练数据

我们必须为文本转视频任务编写一个数据集类,它可以从训练数据集目录中读取视频帧及其相应的文本 prompt,使其可以在 PyTorch 中使用。

# Define a dataset class inheriting from torch.utils.data.Dataset
class TextToVideoDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        # Initialize the dataset with root directory and optional transform
        self.root_dir = root_dir
        self.transform = transform
        # List all subdirectories in the root directory
        self.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        # Initialize lists to store frame paths and corresponding prompts
        self.frame_paths = []
        self.prompts = []

        # Loop through each video directory
        for video_dir in self.video_dirs:
            # List all PNG files in the video directory and store their paths
            frames = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')]
            self.frame_paths.extend(frames)
            # Read the prompt text file in the video directory and store its content
            with open(os.path.join(video_dir, 'prompt.txt'), 'r') as f:
                prompt = f.read().strip()
            # Repeat the prompt for each frame in the video and store in prompts list
            self.prompts.extend([prompt] * len(frames))

    # Return the total number of samples in the dataset
    def __len__(self):
        return len(self.frame_paths)

    # Retrieve a sample from the dataset given an index
    def __getitem__(self, idx):
        # Get the path of the frame corresponding to the given index
        frame_path = self.frame_paths[idx]
        # Open the image using PIL (Python Imaging Library)
        image = Image.open(frame_path)
        # Get the prompt corresponding to the given index
        prompt = self.prompts[idx]

        # Apply transformation if specified
        if self.transform:
            image = self.transform(image)

        # Return the transformed image and the prompt
        return image, prompt

在继续编写架构代码之前,我们需要对训练数据进行归一化处理。我们使用 16 的 batch 大小并对数据进行混洗以引入更多随机性。

实现文本嵌入层

你可能已经看到,在 Transformer 架构中,起点是将文本输入转换为嵌入,从而在多头注意力中进行进一步处理。类似地,我们在这里必须编写一个文本嵌入层。基于该层,GAN 架构训练在我们的嵌入数据和图像张量上进行。

# Define a class for text embedding
class TextEmbedding(nn.Module):
    # Constructor method with vocab_size and embed_size parameters
    def __init__(self, vocab_size, embed_size):
        # Call the superclass constructor
        super(TextEmbedding, self).__init__()
        # Initialize embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size)

    # Define the forward pass method
    def forward(self, x):
        # Return embedded representation of input
        return self.embedding(x) 

词汇量将基于我们的训练数据,在稍后进行计算。嵌入大小将为 10。如果使用更大的数据集,你还可以使用 Hugging Face 上已有的嵌入模型。

实现生成器层

现在我们已经知道生成器在 GAN 中的作用,接下来让我们对这一层进行编码,然后了解其内容。

class Generator(nn.Module):
 def __init__(self, text_embed_size):
 super(Generator, self).__init__()
 
 # Fully connected layer that takes noise and text embedding as input
 self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8)
 
 # Transposed convolutional layers to upsample the input
 self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
 self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
 self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1) # Output has 3 channels for RGB images
 
 # Activation functions
 self.relu = nn.ReLU(True) # ReLU activation function
 self.tanh = nn.Tanh() # Tanh activation function for final output

 def forward(self, noise, text_embed):
 # Concatenate noise and text embedding along the channel dimension
 x = torch.cat((noise, text_embed), dim=1)
 
 # Fully connected layer followed by reshaping to 4D tensor
 x = self.fc1(x).view(-1, 256, 8, 8)
 
 # Upsampling through transposed convolution layers with ReLU activation
 x = self.relu(self.deconv1(x))
 x = self.relu(self.deconv2(x))
 
 # Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)
 x = self.tanh(self.deconv3(x))
 
 return x

该 Generator 类负责根据随机噪声和文本嵌入的组合创建视频帧,旨在根据给定的文本描述生成逼真的视频帧。该网络从完全连接层 (nn.Linear) 开始,将噪声向量和文本嵌入组合成单个特征向量。然后,该向量被重新整形并经过一系列的转置卷积层 (nn.ConvTranspose2d),这些层将特征图逐步上采样到所需的视频帧大小。

这些层使用 ReLU 激活 (nn.ReLU) 实现非线性,最后一层使用 Tanh 激活 (nn.Tanh) 将输出缩放到 [-1, 1] 的范围。因此,生成器将抽象的高维输入转换为以视觉方式表示输入文本的连贯视频帧。

实现判别器层

在编写完生成器层之后,我们需要实现另一半,即判别器部分。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        # Convolutional layers to process input images
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)   # 3 input channels (RGB), 64 output channels, kernel size 4x4, stride 2, padding 1
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1) # 64 input channels, 128 output channels, kernel size 4x4, stride 2, padding 1
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1) # 128 input channels, 256 output channels, kernel size 4x4, stride 2, padding 1
        
        # Fully connected layer for classification
        self.fc1 = nn.Linear(256 * 8 * 8, 1)  # Input size 256x8x8 (output size of last convolution), output size 1 (binary classification)
        
        # Activation functions
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)  # Leaky ReLU activation with negative slope 0.2
        self.sigmoid = nn.Sigmoid()  # Sigmoid activation for final output (probability)

    def forward(self, input):
        # Pass input through convolutional layers with LeakyReLU activation
        x = self.leaky_relu(self.conv1(input))
        x = self.leaky_relu(self.conv2(x))
        x = self.leaky_relu(self.conv3(x))
        
        # Flatten the output of convolutional layers
        x = x.view(-1, 256 * 8 * 8)
        
        # Pass through fully connected layer with Sigmoid activation for binary classification
        x = self.sigmoid(self.fc1(x))
        
        return x

判别器类用作二元分类器,区分真实视频帧和生成的视频帧。目的是评估视频帧的真实性,从而指导生成器产生更真实的输出。该网络由卷积层 (nn.Conv2d) 组成,这些卷积层从输入视频帧中提取分层特征, Leaky ReLU 激活 (nn.LeakyReLU) 增加非线性,同时允许负值的小梯度。

然后,特征图被展平并通过完全连接层 (nn.Linear),最终以 S 形激活 (nn.Sigmoid) 输出指示帧是真实还是假的概率分数。

通过训练判别器准确地对帧进行分类,生成器同时接受训练以创建更令人信服的视频帧,从而骗过判别器。

编写训练参数

我们必须设置用于训练 GAN 的基础组件,例如损失函数、优化器等。

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a simple vocabulary for text prompts
all_prompts = [prompt for prompt, _, _ in prompts_and_movements]  # Extract all prompts from prompts_and_movements list
vocab = {word: idx for idx, word in enumerate(set(" ".join(all_prompts).split()))}  # Create a vocabulary dictionary where each unique word is assigned an index
vocab_size = len(vocab)  # Size of the vocabulary
embed_size = 10  # Size of the text embedding vector

def encode_text(prompt):
    # Encode a given prompt into a tensor of indices using the vocabulary
    return torch.tensor([vocab[word] for word in prompt.split()])

# Initialize models, loss function, and optimizers
text_embedding = TextEmbedding(vocab_size, embed_size).to(device)  # Initialize TextEmbedding model with vocab_size and embed_size
netG = Generator(embed_size).to(device)  # Initialize Generator model with embed_size
netD = Discriminator().to(device)  # Initialize Discriminator model
criterion = nn.BCELoss().to(device)  # Binary Cross Entropy loss function
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Discriminator
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Generator

这是我们必须转换代码以在 GPU 上运行的部分(如果可用)。我们已经编写了代码来查找 vocab_size,并且我们正在为生成器和判别器使用 ADAM 优化器。你可以选择自己的优化器。在这里,我们将学习率设置为较小的值 0.0002,嵌入大小为 10,这比其他可供公众使用的 Hugging Face 模型要小得多。

编写训练 loop

就像其他神经网络一样,我们将以类似的方式对 GAN 架构训练进行编码。

# Number of epochs
num_epochs = 13

# Iterate over each epoch
for epoch in range(num_epochs):
    # Iterate over each batch of data
    for i, (data, prompts) in enumerate(dataloader):
        # Move real data to device
        real_data = data.to(device)
        
        # Convert prompts to list
        prompts = [prompt for prompt in prompts]

        # Update Discriminator
        netD.zero_grad()  # Zero the gradients of the Discriminator
        batch_size = real_data.size(0)  # Get the batch size
        labels = torch.ones(batch_size, 1).to(device)  # Create labels for real data (ones)
        output = netD(real_data)  # Forward pass real data through Discriminator
        lossD_real = criterion(output, labels)  # Calculate loss on real data
        lossD_real.backward()  # Backward pass to calculate gradients
       
        # Generate fake data
        noise = torch.randn(batch_size, 100).to(device)  # Generate random noise
        text_embeds = torch.stack([text_embedding(encode_text(prompt).to(device)).mean(dim=0) for prompt in prompts])  # Encode prompts into text embeddings
        fake_data = netG(noise, text_embeds)  # Generate fake data from noise and text embeddings
        labels = torch.zeros(batch_size, 1).to(device)  # Create labels for fake data (zeros)
        output = netD(fake_data.detach())  # Forward pass fake data through Discriminator (detach to avoid gradients flowing back to Generator)
        lossD_fake = criterion(output, labels)  # Calculate loss on fake data
        lossD_fake.backward()  # Backward pass to calculate gradients
        optimizerD.step()  # Update Discriminator parameters

        # Update Generator
        netG.zero_grad()  # Zero the gradients of the Generator
        labels = torch.ones(batch_size, 1).to(device)  # Create labels for fake data (ones) to fool Discriminator
        output = netD(fake_data)  # Forward pass fake data (now updated) through Discriminator
        lossG = criterion(output, labels)  # Calculate loss for Generator based on Discriminator's response
        lossG.backward()  # Backward pass to calculate gradients
        optimizerG.step()  # Update Generator parameters
    
    # Print epoch information
    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {lossD_real + lossD_fake}, Loss G: {lossG}")

通过反向传播,我们的损失将针对生成器和判别器进行调整。我们在训练 loop 中使用了 13 个 epoch。我们测试了不同的值,但如果 epoch 高于这个值,结果并没有太大差异。此外,过度拟合的风险很高。如果我们的数据集更加多样化,包含更多动作和形状,则可以考虑使用更高的 epoch,但在这里没有这样做。

当我们运行此代码时,它会开始训练,并在每个 epoch 之后 print 生成器和判别器的损失。

## OUTPUT ##

Epoch [1/13] Loss D: 0.8798642754554749, Loss G: 1.300612449645996
Epoch [2/13] Loss D: 0.8235711455345154, Loss G: 1.3729925155639648
Epoch [3/13] Loss D: 0.6098687052726746, Loss G: 1.3266581296920776

...

保存训练的模型

训练完成后,我们需要保存训练好的 GAN 架构的判别器和生成器,这只需两行代码即可实现。

# Save the Generator model's state dictionary to a file named 'generator.pth'
torch.save(netG.state_dict(), 'generator.pth')

# Save the Discriminator model's state dictionary to a file named 'discriminator.pth'
torch.save(netD.state_dict(), 'discriminator.pth')

生成 AI 视频

正如我们所讨论的,我们在未见过的数据上测试模型的方法与我们训练数据中涉及狗取球和猫追老鼠的示例类似。因此,我们的测试 prompt 可能涉及猫取球或狗追老鼠等场景。

在我们的特定情况下,圆圈向上移动然后向右移动的运动在训练数据中不存在,因此模型不熟悉这种特定运动。但是,模型已经在其他动作上进行了训练。我们可以使用此动作作为 prompt 来测试我们训练过的模型并观察其性能。

# Inference function to generate a video based on a given text promptdef generate_video(text_prompt, num_frames=10):    # Create a directory for the generated video frames based on the text prompt    os.makedirs(f'generated_video_{text_prompt.replace(" ", "_")}', exist_ok=True)        # Encode the text prompt into a text embedding tensor    text_embed = text_embedding(encode_text(text_prompt).to(device)).mean(dim=0).unsqueeze(0)        # Generate frames for the video    for frame_num in range(num_frames):        # Generate random noise        noise = torch.randn(1, 100).to(device)                # Generate a fake frame using the Generator network        with torch.no_grad():            fake_frame = netG(noise, text_embed)                # Save the generated fake frame as an image file        save_image(fake_frame, f'generated_video_{text_prompt.replace(" ", "_")}/frame_{frame_num}.png')# usage of the generate_video function with a specific text promptgenerate_video('circle moving up-right')

当我们运行上述代码时,它将生成一个目录,其中包含我们生成视频的所有帧。我们需要使用一些代码将所有这些帧合并为一个短视频。

# Define the path to your folder containing the PNG frames
folder_path = 'generated_video_circle_moving_up-right'


# Get the list of all PNG files in the folder
image_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]

# Sort the images by name (assuming they are numbered sequentially)
image_files.sort()

# Create a list to store the frames
frames = []

# Read each image and append it to the frames list
for image_file in image_files:
 image_path = os.path.join(folder_path, image_file)
 frame = cv2.imread(image_path)
 frames.append(frame)

# Convert the frames list to a numpy array for easier processing
frames = np.array(frames)

# Define the frame rate (frames per second)
fps = 10

# Create a video writer object
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('generated_video.avi', fourcc, fps, (frames[0].shape[1], frames[0].shape[0]))

# Write each frame to the video
for frame in frames:
 out.write(frame)

# Release the video writer
out.release()

确保文件夹路径指向你新生成的视频所在的位置。运行此代码后,你将成功创建 AI 视频。让我们看看它是什么样子。

图片

我们进行了多次训练,训练次数相同。在两种情况下,圆圈都是从底部开始,出现一半。好消息是,我们的模型在两种情况下都尝试执行直立运动。

例如,在尝试 1 中,圆圈沿对角线向上移动,然后执行向上运动,而在尝试 2 中,圆圈沿对角线移动,同时尺寸缩小。在两种情况下,圆圈都没有向左移动或完全消失,这是一个好兆头。

最后,作者表示已经测试了该架构的各个方面,发现训练数据是关键。通过在数据集中包含更多动作和形状,你可以增加可变性并提高模型的性能。由于数据是通过代码生成的,因此生成更多样的数据不会花费太多时间;相反,你可以专注于完善逻辑。

此外,文章中讨论的 GAN 架构相对简单。你可以通过集成高级技术或使用语言模型嵌入 (LLM) 而不是基本神经网络嵌入来使其更复杂。此外,调整嵌入大小等参数会显著影响模型的有效性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1879901.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Web后端开发之前后端交互

http协议 http ● 超文本传输协议 (HyperText Transfer Protocol)服务器传输超文本到本地浏览器的传送协议 是互联网上应用最为流行的一种网络协议,用于定义客户端浏览器和服务器之间交换数据的过程。 HTTP是一个基于TCP/IP通信协议来传递数据. HTT…

成绩发布背后:老师的无奈与痛点

在教育的广阔天地里,教师这一角色承载着无数的期望与责任。他们不仅是知识的传播者,更是学生心灵的引路人。而对于班主任老师来说,他们的角色更加多元,他们不仅是老师,还必须是“妈妈”。除了像其他老师一样备课、上课…

Web3 前端攻击:原因、影响及经验教训

DeFi的崛起引领了一个创新和金融自由的新时代。然而,这种快速增长也吸引了恶意行为者的注意,他们试图利用漏洞进行攻击。尽管很多焦点都集中在智能合约安全上,但前端攻击也正在成为一个重要的威胁向量。 前端攻击的剖析 理解攻击者利用前端漏…

MaxKb/open-webui+Ollama运行模型

准备:虚拟机:centos7 安装Docker:首先,需要安装Docker,因为Ollama和MaxKB都是基于Docker的容器。使用以下命令安装Docker: sudo yum install -y yum-utils device-mapper-persistent-data lvm2 sudo yum…

Keil汇编相关知识

一、汇编的组成 1.汇编指令:在内存中占用内存,执行一条汇编指令会让处理器进行相关运算 分类:数据处理指令,跳转指令,内存读写指令,状态寄存器传送指令,软中断产生指令,协助处理器…

生成式AI如何赋能教育?商汤发布《2024生成式AI赋能教育未来》白皮书

生成式AI正在各个行业中展现出巨大的应用前景。在关系国计民生的教育行业,生成式AI能够催生哪些创新模式? 6月28日,商汤科技受邀参加2024中国AIGC应用与发展峰会,并在会上发布《2024生成式AI赋能教育未来》白皮书,提出…

Django之阿里云短信

短信验证 短信验证,首先得选择一个短信发送服务器上,本文档使用阿里云实现短信发送功能 阿里云短信网 网址:短信服务_企业短信营销推广_验证码通知-阿里云 注册账号 新账号赠送100条,可以不用充值,即可进行测试 接入 短信 进行 个人实名认证 编写代码执行 安装依赖模块 p…

html5 video去除边框

video的属性: autoplay 视频在就绪后自动播放。 controls 显示控件,比如播放按钮。 height 设置视频播放器的高度。 width 设置视频播放器的宽度。 loop 循环播放 muted 视频的音频输出静音。 poster 视频加载时显示的图像,或者在用户点击播…

全球最大智能立体书库|北京:3万货位,715万册,自动出库、分拣、搬运

导语 大家好,我是社长,老K。专注分享智能制造和智能仓储物流等内容。 新书《智能物流系统构成与技术实践》 北京城市图书馆的立体书库采用了先进的WMS(仓库管理系统)和WCS(仓库控制系统),与图书…

【机器学习】机器学习的重要技术——生成对抗网络:理论、算法与实践

引言 生成对抗网络(Generative Adversarial Networks, GANs)由Ian Goodfellow等人在2014年提出,通过生成器和判别器两个神经网络的对抗训练,成功实现了高质量数据的生成。GANs在图像生成、数据增强、风格迁移等领域取得了显著成果…

老师期末工作怎么减负?

期末,一个学期的尾声,也是老师们最为忙碌的时刻。在这段时间里,我们不仅要完成教学任务,还要准备期末考试、批改试卷、撰写学生评语、制定假期计划等一系列繁重的工作。那么,如何在这样紧张的期末工作中为自己减负呢&a…

Android高级面试_2_IPC相关

Android 高级面试-3:语言相关 1、Java 相关 1.1 缓存相关 问题:LruCache 的原理? 问题:DiskLruCache 的原理? LruCache 用来实现基于内存的缓存,LRU 就是最近最少使用的意思,LruCache 基于L…

RocketMQ源码学习笔记:Producer启动流程

这是本人学习的总结,主要学习资料如下 马士兵教育rocketMq官方文档 目录 1、Overview1.1、创建MQClientInstance1.1.1、检查1.1.1、MQClientInstance的ID 1.2、MQClientInstance.start() 1、Overview 这是发送信息的代码样例, DefaultMQProducer produ…

百强韧劲,进击新局 2023年度中国医药工业百强系列榜单发布

2024年,经济工作坚持稳中求进、以进促稳、先立后破等工作要求。医药健康行业以不懈进取的“韧劲”,立身破局,迎变启新。通过创新和迭代应对不确定性,进化韧性力量,坚持高质量发展,把握新时代经济和社会给予…

Python之三大基本库——Numpy(2)

接着上次的内容接着讲,连续号都续上哈 七、numpu中random的随机生成函数 以下总结的是比较常用到的函数: 下面分别介绍一下不用的用法: 首先导入创建函数 import numpy as np np.random.seed(666)1、 rand(d0,d1,d2,...,dn):返…

Keysight是德 N9912A 手持式射频分析仪

Keysight是德 N9912A 手持式频谱分析仪 N9912A FieldFox 手持射频分析仪,4 GHz 和 6 GHz 轻盈耐用的电缆与天线分析仪、频谱分析仪、网络分析仪等等。 Keysight FieldFox 便携式分析仪可以在非常恶劣的工作环境中,轻松完成从日常维护到深入故障诊断的…

【Python】 模型训练数据归一化的原理

那年夏天我和你躲在 这一大片宁静的海 直到后来我们都还在 对这个世界充满期待 今年冬天你已经不在 我的心空出了一块 很高兴遇见你 让我终究明白 回忆比真实精彩 🎵 王心凌《那年夏天宁静的海》 在机器学习和深度学习中,数据归一化…

如何用DCA1000持续采集雷达数据

摘要:本文介绍一下如何通过mmwave studio软件,搭配DCA1000数据采集卡,对AWR1843BOOST进行不间断的数据采集。本文要求读者已经掌握了有关基础知识。 本文开放获取,无需关注。 到SensorConfig页面下,一步步操作&#xf…

吉时利 Keithley2601B-PULSE 脉冲数字源表

Keithley2601B-PULSE吉时利脉冲SMU数字源表 无需手动脉冲调整即可实现高脉冲保真度 通过 2601B-PULSE 控制回路系统,高达 3μH 的负载变化无需手动调整,从而确保在任何电流水平(最高 10 安培)下输出 10 μs 至 500 μs 脉冲时&a…

柯桥法语学习|学点黑话!法语中的「钱」可不止“argent”

法语中有哪些关于钱的“黑话”?一起来和法语君看一下吧! bl 之所以繁杂,是因为这些词在诞生之初,不止涉及一个故事,而是一大堆小轶事,以“bl”指钱的起源如迷宫般复杂。 根据Trsor de la langue frana15857…