【Datawhale AI 夏令营】CV图像竞赛——Deepfake攻防

news2024/9/23 0:23:19

【Datawhale AI 夏令营】CV图像竞赛——Deepfake攻防

从零入门CV图像竞赛(Deepfake攻防) 是 Datawhale 2024 年 AI 夏令营第二期 的学习活动(“CV图像”方向),基于蚂蚁集团举办的“外滩大会-全球Deepfake攻防挑战赛”开展的实践学习

​ 这几天参加了Datawhale AI 夏令营的CV图像竞赛,跟随DataWhale的学习指南,跑通Baseline,在Kaggle平台完成整个Deepfake攻防挑战赛。下面来介绍本次活动的内容及收获。

1 Deepfake攻防任务介绍

​ 随着人工智能技术的迅猛发展,深度伪造技术(Deepfake)正成为数字世界中的一把双刃剑。Deepfake技术可以通过人工智能算法生成高度逼真的图像、视频和音频内容,这些内容看起来与真实的毫无二致。

1.1 赛题任务

​ Deepfake是一种使用人工智能技术生成的伪造媒体,特别是视频和音频,它们看起来或听起来非常真实,但实际上是由计算机生成的。这种技术通常涉及到深度学习算法,特别是生成对抗网络(GANs),它们能够学习真实数据的特征,并生成新的、逼真的数据。

深度伪造技术通常可以分为四个主流研究方向:

  • 面部交换专注于在两个人的图像之间执行身份交换;
  • 面部重演强调转移源运动和姿态;
  • 说话面部生成专注于在角色生成中实现口型与文本内容的自然匹配;
  • 面部属性编辑旨在修改目标图像的特定面部属性;

1721448504326

​ 本次Deepfake任务目标即为训练模型,判断一张人脸图像是否为Deepfake图像,并输出其为Deepfake图像的概率评分。

1721444215400

Deepfake图片示例

1.2 赛题数据集

  1. 第一阶段

    在第一阶段,主办方将发布训练集和验证集。参赛者将使用训练集 (train_label.txt) 来训练模型,而验证集 (val_label.txt) 仅用于模型调优。文件的每一行包含两个部分,分别是图片文件名和标签值(label=1 表示Deepfake图像,label=0 表示真实人脸图像)。例如:

    train_label.txt

    img_name,target
    3381ccbc4df9e7778b720d53a2987014.jpg,1
    63fee8a89581307c0b4fd05a48e0ff79.jpg,0
    7eb4553a58ab5a05ba59b40725c903fd.jpg,0
    …
    

    val_label.txt

    img_name,target
    cd0e3907b3312f6046b98187fc25f9c7.jpg,1
    aa92be19d0adf91a641301cfcce71e8a.jpg,0
    5413a0b706d33ed0208e2e4e2cacaa06.jpg,0
    …
    
  2. 第二阶段

    在第一阶段结束后,主办方将发布测试集。在第二阶段,参赛者需要在系统中提交测试集的预测评分文件 (prediction.txt),主办方将在线反馈测试评分结果。文件的每一行包含两个部分,分别是图片文件名和模型预测的Deepfake评分(即样本属于Deepfake图像的概率值)。例如:

    prediction.txt

    img_name,y_pred
    cd0e3907b3312f6046b98187fc25f9c7.jpg,1
    aa92be19d0adf91a641301cfcce71e8a.jpg,0.5
    5413a0b706d33ed0208e2e4e2cacaa06.jpg,0.5
    …
    
  3. 第三阶段

    在第二阶段结束后,前30名队伍将晋级到第三阶段。在这一阶段,参赛者需要提交代码docker和技术报告。Docker要求包括原始训练代码和测试API(函数输入为图像路径,输出为模型预测的Deepfake评分)。主办方将检查并重新运行算法代码,以重现训练过程和测试结果。

1.3 评价指标

​ 比赛的性能评估主要使用ROC曲线下的AUC(Area under the ROC Curve)作为指标。AUC的取值范围通常在0.5到1之间。若AUC指标不能区分排名,则会使用TPR@FPR=1E-3作为辅助参考。

  1. 真阳性率 (TPR):
    T P R = T P / ( T P + F N ) TPR = TP / (TP + FN) TPR=TP/(TP+FN)

  2. 假阳性率 (FPR):
    F P R = F P / ( F P + T N ) FPR = FP / (FP + TN) FPR=FP/(FP+TN)
    其中:

    • TP:攻击样本被正确识别为攻击;
    • TN:真实样本被正确识别为真实;
    • FP:真实样本被错误识别为攻击;
    • FN:攻击样本被错误识别为真实。

2 Baseline实现

​ 本次项目在Kaggle平台进行,DataWhale为学习者提供了跑通整个项目的 baseline 代码,并给出了详细的学习指南。下面结合学习指南内容,介绍基础Baseline代码及训练步骤。

2.1 代码介绍

​ Baseline代码,采用了 timm 库来进行图像模型的训练和推理。

  1. 指标计算与显示

    • AverageMeter

      AverageMeter类用于计算和存储指标的平均值和当前值。它通常用于跟踪训练过程中每个epoch或batch的损失值、精度等。

      class AverageMeter(object):
          """计算和存储指标的平均值和当前值"""
          def __init__(self, name, fmt=':f'):
              self.name = name
              self.fmt = fmt
              self.reset()
      	# 重置所有值
          def reset(self):
              self.val = 0
              self.avg = 0
              self.sum = 0
              self.count = 0
      	# 更新当前值
          def update(self, val, n=1):
              self.val = val
              self.sum += val * n
              self.count += n
              self.avg = self.sum / self.count
      	# 返回格式化字符串,显示当前值和平均值
          def __str__(self):
              fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
              return fmtstr.format(**self.__dict__)
      
    • ProgressMeter

      ProgressMeter类用于显示训练过程中各个batch的进度和指标。它通常与AverageMeter类一起使用,方便地显示和跟踪多个指标。

      class ProgressMeter(object):
          def __init__(self, num_batches, *meters):
              self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
              self.meters = meters
              self.prefix = ""
      	# 打印当前batch的进度和所有指标的状态
          def pr2int(self, batch):
              entries = [self.prefix + self.batch_fmtstr.format(batch)]
              entries += [str(meter) for meter in self.meters]
              print('\t'.join(entries))
      	# 根据总batch数生成格式字符串
          def _get_batch_fmtstr(self, num_batches):
              num_digits = len(str(num_batches // 1))
              fmt = '{:' + str(num_digits) + 'd}'
              return '[' + fmt + '/' + fmt.format(num_batches) + ']'
      
  2. 验证、预测和训练神经网络模型

    • validate 函数

      validate函数用于在验证集上评估模型性能。

      def validate(val_loader, model, criterion):
          # 使用AverageMeter类创建计量器,用于跟踪时间、损失和准确度
          batch_time = AverageMeter('Time', ':6.3f')
          losses = AverageMeter('Loss', ':.4e')
          top1 = AverageMeter('Acc@1', ':6.2f')
          progress = ProgressMeter(len(val_loader), batch_time, losses, top1)
      
          # 切换到评估模式
          model.eval()
      
          with torch.no_grad():
              end = time.time()
              # 遍历验证集中的每个batch,计算输出和损失,并更新计量器
              for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):
                  input = input.cuda()
                  target = target.cuda()
      
                  # 计算输出与loss
                  output = model(input)
                  loss = criterion(output, target)
      
                  # 计算accuracy并更新loss
                  acc = (output.argmax(1).view(-1) == target).float().mean() * 100
                  losses.update(loss.item(), input.size(0))
                  top1.update(acc, input.size(0))
                  # 计算运行时间
                  batch_time.update(time.time() - end)
                  end = time.time()
      
              print(' * Acc@1 {top1.avg:.3f}'
                    .format(top1=top1))
              return top1
      
    • predict 函数

      predict函数用于在测试集上进行预测,支持Test-Time Augmentation (TTA)。

      def predict(test_loader, model, tta=10):
          # 切换到评估模式
          model.eval()
          
          test_pred_tta = None
          for _ in range(tta):
              test_pred = []
              with torch.no_grad():
                  for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):
                      input = input.cuda()
                      target = target.cuda()
      
                      # 计算输出
                      output = model(input)
                      output = F.softmax(output, dim=1)
                      output = output.data.cpu().numpy()
      
                      test_pred.append(output)
              test_pred = np.vstack(test_pred)
          
              if test_pred_tta is None:
                  test_pred_tta = test_pred
              else:
                  test_pred_tta += test_pred
          
          return test_pred_tta / tta
      
    • train 函数

      train函数用于在训练集上训练模型。

      def train(train_loader, model, criterion, optimizer, epoch):
          # 使用AverageMeter类创建计量器
          batch_time = AverageMeter('Time', ':6.3f')
          losses = AverageMeter('Loss', ':.4e')
          top1 = AverageMeter('Acc@1', ':6.2f')
          progress = ProgressMeter(len(train_loader), batch_time, losses, top1)
      
          # 切换到训练模式
          model.train()
      
          end = time.time()
          for i, (input, target) in enumerate(train_loader):
              input = input.cuda(non_blocking=True)
              target = target.cuda(non_blocking=True)
      
              # 计算输出
              output = model(input)
              loss = criterion(output, target)
      
              # 计算accuracy并更新loss
              losses.update(loss.item(), input.size(0))
      
              acc = (output.argmax(1).view(-1) == target).float().mean() * 100
              top1.update(acc, input.size(0))
      
              # 计算梯度并更新模型参数
              optimizer.zero_grad()
              loss.backward()
              optimizer.step()
      
              # 计算运行时间
              batch_time.update(time.time() - end)
              end = time.time()
      
              if i % 100 == 0:
                  progress.pr2int(i)
      
  3. 自定义数据集类

    用于加载图像数据及其对应的标签,并在获取数据时进行必要的转换。

    class FFDIDataset(Dataset):
        def __init__(self, img_path, img_label, transform=None):
            self.img_path = img_path
            self.img_label = img_label
            
            if transform is not None:
                self.transform = transform
            else:
                self.transform = None
        # 根据索引获取图像及其对应的标签
        def __getitem__(self, index):
            img = Image.open(self.img_path[index]).convert('RGB')
            
            if self.transform is not None:
                img = self.transform(img)
            
            return img, torch.from_numpy(np.array(self.img_label[index]))
        # 返回数据集的大小
        def __len__(self):
            return len(self.img_path)
    
  4. 使用预训练模型进行训练

    • 创建模型

      创建一个预训练的ResNet-18模型,用于二分类任务,并将其移动到GPU上。

      model = timm.create_model('resnet18', pretrained=True, num_classes=2)
      model = model.cuda()
      
    • 创建数据加载器

      # 定义训练集的数据加载器
      train_loader = torch.utils.data.DataLoader(
          FFDIDataset(
              train_label['path'].head(1000), 
              train_label['target'].head(1000), 
              transforms.Compose([
                  transforms.Resize((256, 256)),
                  transforms.RandomHorizontalFlip(),
                  transforms.RandomVerticalFlip(),
                  transforms.ToTensor(),
                  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
              ])
          ), 
          batch_size=40, 
          shuffle=True, 
          num_workers=4, 
          pin_memory=True
      )
      
      # 定义验证集的数据加载器
      val_loader = torch.utils.data.DataLoader(
          FFDIDataset(
              val_label['path'].head(1000), 
              val_label['target'].head(1000), 
              transforms.Compose([
                  transforms.Resize((256, 256)),
                  transforms.ToTensor(),
                  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
              ])
          ), 
          batch_size=40, 
          shuffle=False, 
          num_workers=4, 
          pin_memory=True
      )
      
    • 定义损失函数、优化器和学习率调度器

      # 定义损失函数,并将其移动到GPU上
      criterion = nn.CrossEntropyLoss().cuda()
      
      # 定义优化器
      optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
      
      # 定义学习率调度器,每4个epoch后学习率乘以0.85
      scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)
      
    • 训练和验证模型

      best_acc = 0.0
      
      for epoch in range(2):  # 进行2个epoch的训练
          scheduler.step()  # 更新学习率
          print('Epoch: ', epoch)
      
          # 训练模型
          train(train_loader, model, criterion, optimizer, epoch)
      
          # 在验证集上验证模型
          val_acc = validate(val_loader, model, criterion)
          
          # 如果当前验证准确率超过最佳准确率,保存模型参数
          if val_acc.avg.item() > best_acc:
              best_acc = round(val_acc.avg.item(), 2)
              torch.save(model.state_dict(), f'./model_{best_acc}.pt')
      
    • 预测并保存结果

      # 定义测试集的数据加载器
      test_loader = torch.utils.data.DataLoader(
          FFDIDataset(
              val_label['path'], 
              val_label['target'], 
              transforms.Compose([
                  transforms.Resize((256, 256)),
                  transforms.ToTensor(),
                  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
              ])
          ), 
          batch_size=40, 
          shuffle=False, 
          num_workers=4, 
          pin_memory=True
      )
      
      # 进行预测,并将预测结果存储在val_label的y_pred列中
      val_label['y_pred'] = predict(test_loader, model, 1)[:, 1]
      
      # 将结果保存为CSV文件
      val_label[['img_name', 'y_pred']].to_csv('submit.csv', index=None)
      

2.2 Baseline训练

​ 本项目利用Kaggle平台进行训练,训练过程与提交结果在DataWhale学习手册 ‌‌‬‬‌‬‌⁠‍‌‍‌‌‬⁠‌‍⁠‍⁠‬从零入门CV图像竞赛(Deepfake攻防) 中详细给出。

  1. 训练过程

    按照手册进行训练,跑通基础的baseline代码。

    51bb3faaf755e6f0b3638bf0b0395a7

  2. 训练结果

    在Kaggle平台提交训练结果,跑通Baseline得到0.571的得分。

    29ab41852829be93898c7609a965fd5

3 代码优化

​ 通过学习九月大佬的代码 九月0.98\Deepfake-FFDI-Ways to Defeat 0.86 Beseline (kaggle.com),来学习代码优化。

  1. 更换预训练模型

    import timm
    model = timm.create_model('efficientnet_b1', pretrained=True, num_classes=2)
    model = model.cuda()
    
    batch_size_value = 32
    epochs = 2
    

    将Baseline的预训练模型从 resnet18 模型改为 efficientnet_b1 ,训练效果提升明显得到0.97的得分。

    image-20240720234319859

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1938544.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Mysql深入讲解(索引、事务、锁机制)

一、MySQL索引 1、何为索引? MySQL中的索引是一种数据结构,用于加快对数据库表中数据的查询速度【查询速度提升】。它类似于书本目录,使得用户可以根据特定字段快速定位到所需的数据行,而无需扫描整个表。 2、索引分类 Hash索…

C 语言回调函数

回调函数的概念 您的理解是正确的。pFunCallBack 是一种函数指针类型,它定义了函数的签名(即函数的参数类型和返回类型)。当我们说 pFunCallBack pFun,我们是在声明一个变量 pFun,其类型是 pFunCallBack —— 即一个函…

【D3.js in Action 3 精译_018】2.4 向选择集添加元素

当前内容所在位置 第一部分 D3.js 基础知识 第一章 D3.js 简介(已完结) 1.1 何为 D3.js?1.2 D3 生态系统——入门须知1.3 数据可视化最佳实践(上)1.3 数据可视化最佳实践(下)1.4 本章小结 第二章…

RNN循环递归网络讲解与不掉包python实现

1.算法简介 参考论文:Elman J L. Finding structure in time[J]. Cognitive science, 1990, 14(2): 179-211.,谷歌被引次数超16000! 说到循环递归结构就不得不提到其鼻祖RNN网络。首先我们先对RNN有个初步的概念:想象一下,你正在…

[紧急!!!]20240719全球Windows10/11蓝屏问题,CrowdStrike导致的错误解决方案

文章目录 前言一、CrowdStrike是什么?二、PC解决方式(网路上大神的方式,虚拟机测试过)1.Windows PC 上 CrowdStrike BSOD 问题的官方解决方法:2.阻止CrowdStrick启动-命令行法3.阻止CrowdStrick启动-注册表法 三、AWS …

基于Matlab的数据可视化

基于Matlab的数据可视化 一、二维图形的绘制(一)基本图形函数(1)plot函数(2)fplot函数(3)其他坐标系的二维曲线 (二)图形属性设置(1)线…

对某次应急响应中webshell的分析

文章前言 在之前处理一起应急事件时发现攻击者在WEB应用目录下上传了webshell,但是webshell似乎使用了某种加密混淆手法,无法直观的看到其中的木马连接密码,而客户非要让我们连接webshell来证实此文件为后门文件且可执行和利用(也是很恼火&a…

数据结构与算法04二叉树|二叉排序树|AVL树

目录 一、二叉树(binary tree) 1、二叉树常见术语 2、二叉树常用的操作 2.1、初始化:与链表十分相似,先创建节点,然后构造引用/指针关系. 2.2、插入和删除操作 3、常见二叉树类型 3.1、满二叉树 3.2、完全二叉树(complete b…

跳跃游戏Ⅱ - vector

55. 跳跃游戏 - 力扣&#xff08;LeetCode&#xff09; class Solution { public:bool canJump(vector<int>& nums) {int n nums.size();int reach 0;for(int i 0; i < n; i){if(i > reach){return false;}reach max(inums[i], reach);}return true;} }; …

SpringBoot3 + Vue3 学习 Day 2

登入接口 和 获取用户详细信息的开发 学习视频登入接口的开发1、登入主逻辑2、登入认证jwt 介绍生成 JWT① 导入依赖② 编写代码③ 验证JWT 登入认证接口的实现① 导入 工具类② controller 类实现③ 存在的问题及优化① 编写拦截器② 注册拦截器③ 其他接口直接提供服务 获取用…

JVM(day4)类加载机制

类加载过程 加载 通过一个类的全限定名来获取定义此类的二进制字节流。 将这个字节流所代表的静态存储结构转化为方法区的运行时数据结构。 在内存中生成一个代表这个类的java.lang.Class对象&#xff0c;作为方法区这个类的各种数据的访问入口。 验证 文件格式验证 元数…

LeetCode做题记录(第二天)647. 回文子串

题目&#xff1a; 647. 回文子串 标签&#xff1a;双指针 字符串 动态规划 题目信息&#xff1a; 思路一&#xff1a;暴力实现 我们直接for套for分割成一个个子串再判断&#xff0c;如果子串是回文子串&#xff0c;就1&#xff0c;最后得出结果 代码实现&#xff1a; cl…

C语言实例-约瑟夫生者死者小游戏

问题&#xff1a; 30个人在一条船上&#xff0c;超载&#xff0c;需要15人下船。于是人们排成一队&#xff0c;排队的位置即为他们的编号。报数&#xff0c;从1开始&#xff0c;数到9的人下船&#xff0c;如此循环&#xff0c;直到船上仅剩15人为止&#xff0c;问都有哪些编号…

Missing script:‘dev‘

场景&#xff1a; npm run dev 原因&#xff1a;没有安装依赖&#xff0c;可用镜像安装&#xff08;详见下图ReadMe 蓝色字体&#xff09;&#xff0c;没安装依赖可从package-lock.json文件是否存在看出&#xff0c;存在则有依赖 解决&#xff1a;

KMP算法(算法篇)

算法之KMP算法 KMP算法 概念&#xff1a; KMP算法是用于解决字符串匹配的问题的算法&#xff0c;也就是有一个文本串和一个模式串&#xff0c;求解这个模式串是否在文本串中出现或者匹配。相对于暴力求解&#xff0c;KMP算法使用了前缀表来进行匹配&#xff0c;充分利用了之…

【Vue3】从零开始编写项目

【Vue3】从零开始编写项目 背景简介开发环境开发步骤及源码总结 背景 随着年龄的增长&#xff0c;很多曾经烂熟于心的技术原理已被岁月摩擦得愈发模糊起来&#xff0c;技术出身的人总是很难放下一些执念&#xff0c;遂将这些知识整理成文&#xff0c;以纪念曾经努力学习奋斗的…

神经网络模型实现(训练、测试)

目录 一、神经网络骨架&#xff1a;二、卷积操作&#xff1a;三、卷积层&#xff1a;四、池化层&#xff1a;五、激活函数&#xff08;以ReLU为例&#xff09;&#xff1a;六、模型搭建&#xff1a;七、损失函数、梯度下降&#xff1a;八、模型保存与加载&#xff1a;九、模型训…

Linux下安装JDK、Tomact、MySQL以及Nginx的超详细步骤

目录 1、为什么安装这些软件 2、安装软件的方式 3、安装JDK 3.1 下载Linux版本的JDK 3.2 将压缩包拖拽到Linux系统下 3.3 解压jdk文件 3.4 修改文件夹名字 3.5 配置环境变量 4、安装Tomcat 4.1 下载Tomcat 4.2 将Tomcat放入Linux系统并解压&#xff0c;步骤如上面的…

MenuToolButton自绘控件,带下拉框的QToolButton,附源码

MenuToolButton自绘控件&#xff0c;带下拉框的QToolButton 效果 下拉样式可自定义 跟随QToolButton的Qt::ToolButtonStyle属性改变图标文字样式 使用示例 正常UI文件创建QToolButton然后提升&#xff0c;或者直接代码创建都可以。 // 创建一个 QList 对象来存储 QPixm…

JDK、JRE、JVM的区别java的基本数据类型

说一说JDK、JRE、JVM的区别在哪&#xff1f; JDK&#xff1a; Java Delopment kit是java工具包&#xff0c;包含了编译器javac&#xff0c;调试器&#xff08;jdb&#xff09;以及其他用于开发和调试java程序的工具。JDK是开发人员在开发java应用程序时候所需要的的基本工具。…