2024Datawhale AI夏令营---Inclusion・The Global Multimedia Deepfake Detection--学习笔记

news2024/10/30 9:42:21

赛题背景:

        其实总结起来就是一句话,这个项目是基于目前的深度伪装技术,就是通过大量人脸的原数据集进行模型训练之后,能够生成伪造的人脸视频。这项目就是教我们如何去实现这个DeepFake技术。

Task1:了解Deepfake和跑通baseline

代码架构如下:

  1. 模型定义:使用timm库创建一个预训练的resnet18模型。

  2. 训练/验证数据加载:使用torch.utils.data.DataLoader来加载训练集和验证集数据,并通过定义的transforms进行数据增强。

  3. 训练与验证过程

    1. 定义了train函数来执行模型在一个epoch上的训练过程,包括前向传播、损失计算、反向传播和参数更新。

    2. 定义了validate函数来评估模型在验证集上的性能,计算准确率。

  4. 性能评估:使用准确率(Accuracy)作为性能评估的主要指标,并在每个epoch后输出验证集上的准确率。

  5. 提交:最后,将预测结果保存到CSV文件中,准备提交到Kaggle比赛。

代码解释如下:

详见代码注释吧

这份代码后续还是要好好精读理解一下的吧,好好分析一下,顺便提升一下代码能力。--7.15

from PIL import Image
Image.open('/kaggle/input/deepfake/phase1/trainset/63fee8a89581307c0b4fd05a48e0ff79.jpg')

import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
import timm
import time

import pandas as pd
import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm_notebook

train_label = pd.read_csv('/kaggle/input/deepfake/phase1/trainset_label.txt')
val_label = pd.read_csv('/kaggle/input/deepfake/phase1/valset_label.txt')

train_label['path'] = '/kaggle/input/deepfake/phase1/trainset/' + train_label['img_name']
val_label['path'] = '/kaggle/input/deepfake/phase1/valset/' + val_label['img_name']

train_label['target'].value_counts()

val_label['target'].value_counts()

train_label.head(10)

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

class ProgressMeter(object):
    def __init__(self, num_batches, *meters):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = ""


    def pr2int(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

def validate(val_loader, model, criterion):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(val_loader), batch_time, losses, top1)

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):
            input = input.cuda()
            target = target.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100
            losses.update(loss.item(), input.size(0))
            top1.update(acc, input.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f}'
              .format(top1=top1))
        return top1

def predict(test_loader, model, tta=10):
    # switch to evaluate mode
    model.eval()
    
    test_pred_tta = None
    for _ in range(tta):
        test_pred = []
        with torch.no_grad():
            end = time.time()
            for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):
                input = input.cuda()
                target = target.cuda()

                # compute output
                output = model(input)
                output = F.softmax(output, dim=1)
                output = output.data.cpu().numpy()

                test_pred.append(output)
        test_pred = np.vstack(test_pred)
    
        if test_pred_tta is None:
            test_pred_tta = test_pred
        else:
            test_pred_tta += test_pred
    
    return test_pred_tta

def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(train_loader), batch_time, losses, top1)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        losses.update(loss.item(), input.size(0))

        acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100
        top1.update(acc, input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 100 == 0:
            progress.pr2int(i)

class FFDIDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None
    
    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')
        
        if self.transform is not None:
            img = self.transform(img)
        
        return img, torch.from_numpy(np.array(self.img_label[index]))
    
    def __len__(self):
        return len(self.img_path)

import timm
model = timm.create_model('resnet18', pretrained=True, num_classes=2)
model = model.cuda()

train_loader = torch.utils.data.DataLoader(
    FFDIDataset(train_label['path'].head(1000), train_label['target'].head(1000), 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomVerticalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=40, shuffle=True, num_workers=4, pin_memory=True
)

val_loader = torch.utils.data.DataLoader(
    FFDIDataset(val_label['path'].head(1000), val_label['target'].head(1000), 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=40, shuffle=False, num_workers=4, pin_memory=True
)

criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), 0.005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)
best_acc = 0.0
for epoch in range(2):
    scheduler.step()
    print('Epoch: ', epoch)

    train(train_loader, model, criterion, optimizer, epoch)
    val_acc = validate(val_loader, model, criterion)
    
    if val_acc.avg.item() > best_acc:
        best_acc = round(val_acc.avg.item(), 2)
        torch.save(model.state_dict(), f'./model_{best_acc}.pt')

test_loader = torch.utils.data.DataLoader(
    FFDIDataset(val_label['path'], val_label['target'], 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=40, shuffle=False, num_workers=4, pin_memory=True
)

val_label['y_pred'] = predict(test_loader, model, 1)[:, 1]
val_label[['img_name', 'y_pred']].to_csv('submit.csv', index=None)

        本来是想直接在本地运行的,但是这个数据集实在是太大了,受限于操作和设备,只能在kaggle云运行这个代码咯,结果如下:

        提交上kaggle进行评分:

Inclusion・The Global Multimedia Deepfake Detection | Kaggle

        结果挺差的,毕竟这就是个普通的原始代码,啥都没优化的,参数也没调,只能先这样咯,后续task再来优化调整咔咔上分吧。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1929647.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电脑鼠标连点工具哪家强?2024最新鼠标连点器工具分享

在现代计算机使用中,鼠标作为最主要的输入设备之一,在日常操作和特定应用中发挥着关键作用。然而,某些任务可能要求用户频繁点击鼠标,这不仅繁琐乏味,还可能导致手部疲劳。为了解决这一问题,自动鼠标点击工…

初始c语言(2)运算符与表达式

一 c语言提供的运算符类型 以上会后续介绍 二 现阶段我们掌握如下的基本操作符 注意!计算机的除法只会保留整数部分(若被除数未负则不同的软件取整的结果不唯一) 三 自加()自减(--)符号 若为…

【区块链 + 智慧政务】区块链 +ETC 下一代公路联网收费关键技术优化项目 | FISCO BCOS应用案例

2020 年,我国取消省界收费站项目完成后,随着收费模式与收费方式的变化,形成了以门架为计费单元的新收 费体系:按照车辆通行门架数,RSU 天线读取 ETC 卡、电子标签 OBU 或 CPC 卡内标识的车型信息,车型门架计…

SQL去重的四种方法

去重是指:查询的时候, 不显示重复,并不是删除表中的重复项 数据表: 方法1:distinct去重 作用:只能一列去重,当distinct后跟大于1个参数时,他们之间的关系是&&(逻辑与)关系,…

vue2学习笔记7 - Vue中的MVVM模型

MVVM Model-View-viewModel是一种软件架构模式,用于将用户界面(View)与业务逻辑(Model)分离,并通过ViewModel进行连接和协调。MVVM模式的目标是实现视图与模型的解耦,提高代码的可读性、可维护…

Qt | 绘制椭圆、弧、弦、扇形、圆角矩形

点击上方"蓝字"关注我们 01、简介 1、需要使用到的 QPainter 类中的函数 2、绘制椭圆的方法有 绘制给定矩形的内接椭圆和根据中心点与椭圆 x 方向和 y 方向的半径绘制,原理见下图 3、绘制弧、弦、扇形的原理: 1)、弧是椭圆上的一段曲线,因此其绘制方法就是首先…

【Apache Doris】周FAQ集锦:第 14 期

【Apache Doris】周FAQ集锦:第 14 期 SQL问题数据操作问题运维常见问题其它问题关于社区 欢迎查阅本周的 Apache Doris 社区 FAQ 栏目! 在这个栏目中,每周将筛选社区反馈的热门问题和话题,重点回答并进行深入探讨。旨在为广大用户…

企业全历史行为数据助ToB企业决策层开启营销的上帝视角

“上帝视角”是每个企业家都渴望拥有的。上帝视角的能力有多么吸引人呢?通常,一个企业家在技术、产品、营销中的任何一个领域拥有上帝视角的能力,就足可以让他的企业大杀四方,甚至创造历史。 在技术或产品领域,靠“上…

10.1 标注、注记图层和注记整体说明

文章目录 前言标注、注记图层和注记QGis中的标注QGis中的注释(Annotation)图层QGis中的注记 总结 前言 介绍标注、注记图层和注记说明:文章中的示例代码均来自开源项目qgis_cpp_api_apps 标注、注记图层和注记 有时地图需要使用一些文字信息说明其中的地理要素或其…

Android 性能优化之卡顿优化

文章目录 Android 性能优化之卡顿优化卡顿检测TraceView配置缺点 StricktMode配置违规代码 BlockCanary配置问题代码缺点 ANRANR原因ANRWatchDog监测解决方案 Android 性能优化之卡顿优化 卡顿检测 TraceViewStricktModelBlockCanary TraceView 配置 Debug.startMethodTra…

Python中的数据结构:五彩斑斓的糖果盒

在Python编程的世界里,数据结构就像是一个个五彩斑斓的糖果盒,每一种糖果都有其独特的味道和形状。这些多姿多彩,形状和味道各异的糖果盒子包括了:List(列表)、Tuple(元组)、Diction…

Redis主从部署

主从部署 整体架构图 需要再建两个CentOs7,过程重复单机部署 查看自己ip地址命令 ifconfig 192.168.187.137 进入redis所在目录 cd /opt/software/redis cd redis-stable 进入配置文件 vim redis.conf 修改分身1、2的配置文件 搜索replicaof replicaof 192.168.187.137 63…

笔记 2 : 课本第 3 章开始,记录 arm 的汇编指令的格式

(13) 介绍 arm 中的第一个汇编指令的用法 mov : (14)立即数的概念: (15) 汇编中的移位写法: 举例 : (16) 学习一个新的指令 cmp &a…

二叉树相关理论知识

二叉树是计算机科学中一种基础且重要的数据结构,它属于树形结构的一个重要类型。以下是二叉树的理论基础,包括定义、基本形态、特殊类型、性质以及遍历方式等方面的内容。 一、定义 二叉树(Binary Tree)是n(n≥0&…

【实战系列】PostgreSQL 专栏,基于 PostgreSQL 16 版本

我的 PostgreSQL 专栏介绍及进度 20240715:目前整体进度已完成 85%,完成 16 万字,还有近 5 万字就截稿了。 (venv312) ➜ mypostgres git:(dev) sh scripts/word_statistics_pg_style.sh Filename …

15分钟快速了解图新地球能做什么,解决什么问题,快速入门

1.图新地球桌面端是什么 1.1官方定义 图新地球桌面端(LSV)是一款集多源数据加载、应用分析、演示汇报为一体的三维GIS 软件。采用了中科图新自主研发的国产三维地图引擎,支持各类无人机航测、CAD、BIM、规划成果等多源数据的加载融合;实现了BIMGIS 技术在实际业务…

所有权与生命周期:Rust 内存管理的哲学

所有权与生命周期:Rust内存管理的哲学 博主寄语引言:编程语言的内存管理困境与 Rust 的解决方案。所有权基本概念:资源的绝对主权生命周期的理解与应用:编译时的守护神借用与引用的精妙设计:安全与效率的和谐共舞Rust …

Golang | Leetcode Golang题解之第231题2的幂

题目&#xff1a; 题解&#xff1a; func isPowerOfTwo(n int) bool {const big 1 << 30return n > 0 && big%n 0 }

C# Opencv实现本地以图搜图

地址&#xff1a;冯腾飞/本地以图搜图

shell脚本变量和运算

一、shell变量及赋值 1.1、shell的变量 变量是用来临时保存数据的&#xff0c;并且该数据时可以变化的&#xff0c;任何一个语言都离不开变量&#xff0c;如果某个内容需要多次使用并且会重复出现&#xff0c;这样就可以使用变量了&#xff0c;如果需要修改直接修改变量就可以…