Pytorch使用VGG16模型进行预测猫狗二分类

news2025/1/10 10:14:54

目录

1. VGG16

1.1 VGG16 介绍

1.1.1 VGG16 网络的整体结构

 1.2 Pytorch使用VGG16进行猫狗二分类实战

1.2.1 数据集准备

1.2.2 构建VGG网络

1.2.3 训练和评估模型


 

1. VGG16

1.1 VGG16 介绍

深度学习已经在计算机视觉领域取得了巨大的成功,特别是在图像分类任务中。VGG16是深度学习中经典的卷积神经网络(Convolutional Neural Network,CNN)之一,由牛津大学的Karen Simonyan和Andrew Zisserman在2014年提出。VGG16网络以其深度和简洁性而闻名,是图像分类中的重要里程碑。

VGG16是Visual Geometry Group的缩写,它的名字来源于提出该网络的实验室。VGG16的设计目标是通过增加网络深度来提高图像分类的性能,并展示了深度对于图像分类任务的重要性。VGG16的主要特点是将多个小尺寸的卷积核堆叠在一起,从而形成更深的网络。

1.1.1 VGG16 网络的整体结构

VGG16网络由多个卷积层和全连接层组成。它的整体结构相对简单,所有的卷积层都采用小尺寸的卷积核(通常为3x3),步幅为1,填充为1。每个卷积层后面都会跟着一个ReLU激活函数来引入非线性。

VGG16网络主要由三个部分组成:

  1. 输入层:接受图像输入,通常为224x224大小的彩色图像(RGB)。

  2. 卷积层:VGG16包含13个卷积层,其中包括五个卷积块。

  3. 全连接层:在卷积层后面是3个全连接层,用于最终的分类。

VGG16网络结构如下图:

watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDc5MTk2NA==,size_16,color_FFFFFF,t_70#pic_center

1、一张原始图片被resize到(224,224,3)。
2、conv1两次[3,3]卷积网络,输出的特征层为64,输出为(224,224,64),再2X2最大池化,输出net为(112,112,64)。
3、conv2两次[3,3]卷积网络,输出的特征层为128,输出net为(112,112,128),再2X2最大池化,输出net为(56,56,128)。
4、conv3三次[3,3]卷积网络,输出的特征层为256,输出net为(56,56,256),再2X2最大池化,输出net为(28,28,256)。
5、conv4三次[3,3]卷积网络,输出的特征层为512,输出net为(28,28,512),再2X2最大池化,输出net为(14,14,512)。
6、conv5三次[3,3]卷积网络,输出的特征层为512,输出net为(14,14,512),再2X2最大池化,输出net为(7,7,512)。
7、利用卷积的方式模拟全连接层,效果等同,输出net为(1,1,4096)。共进行两次。
8、利用卷积的方式模拟全连接层,效果等同,输出net为(1,1,1000)。
最后输出的就是每个类的预测。

 1.2 Pytorch使用VGG16进行猫狗二分类实战

在这一部分,我们将使用PyTorch来实现VGG16网络,用于猫狗预测的二分类任务。我们将对VGG16的网络结构进行适当的修改,以适应我们的任务。

1.2.1 数据集准备

首先,我们需要准备用于猫狗二分类的数据集。数据集可以从Kaggle上下载,其中包含了大量的猫和狗的图片。在下载数据集后,我们需要将数据集划分为训练集和测试集。训练集文件夹命名为train,其中建立两个文件夹分别为cat和dog,每个文件夹里存放相应类别的图片。测试集命名为test,同理。

import torch
import torchvision
import torchvision.transforms as transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载数据集
train_dataset = ImageFolder("train", transform=transform)
test_dataset = ImageFolder("test", transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

1.2.2 构建VGG网络

import torch.nn as nn

class VGG16(nn.Module):
    def __init__(self):
        super(VGG16, self).__init__()
        self.features = nn.Sequential(
            # Block 1
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Block 2
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Block 3
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Block 4
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Block 5
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 2)  # 输出层,二分类任务
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # 展开特征图
        x = self.classifier(x)
        return x

# 初始化VGG16模型
vgg16 = VGG16()

在上述代码中,我们定义了一个VGG16类,其中self.features部分包含了5个卷积块,self.classifier部分包含了3个全连接层。

1.2.3 训练和评估模型

import torch.optim as optim


# 定义超参数
batch_size = 32
learning_rate = 0.001
num_epochs = 10


model = VGG16()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item()}")
torch.save(model,'model/vgg16.pth')
# 测试模型
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        print(outputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f"Accuracy on test images: {(correct / total) * 100}%")

在训练模型时,我们使用交叉熵损失函数(CrossEntropyLoss)作为分类任务的损失函数,并采用随机梯度下降(SGD)作为优化器。同时,我们将模型移动到GPU(如果可用)来加速训练过程。

 

 


 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/836101.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

etcd 基础使用

etcd Go 操作 Etcd 参考 go get go.etcd.io/etcd/client/v3民间文档:http://www.topgoer.com/%E6%95%B0%E6%8D%AE%E5%BA%93%E6%93%8D%E4%BD%9C/go%E6%93%8D%E4%BD%9Cetcd/%E6%93%8D%E4%BD%9Cetcd.html 官方文档:https://github.com/etcd-io/etcd/blob…

如何把pdf转成cad版本?这种转换方法非常简单

将PDF转换成CAD格式的优势在于,CAD格式通常是用于工程设计和绘图的标准格式。这种格式的文件可以在计算机上进行编辑和修改,而不需要纸质副本。此外,CAD文件通常可以与其他CAD软件进行交互,从而使得工程设计和绘图过程更加高效和精…

5款无广告的超实用软件,建议收藏!

​ 大家好,我又来了,今天向大家推荐几款软件,它们有个共同的特点,就是无广告、超级实用,大家看完之后,可以自己去搜索下载试用。 1.重复文件清理——Duplicate Cleaner ​ Duplicate Cleaner是一款用于找出硬盘中重复文件并删除的工具。它可以通过内容或文件名查找重复文档、…

面试必考精华版Leetcode104. 二叉树的最大深度

题目: 代码(首刷自解 day23): class Solution { public:int maxDepth(TreeNode* root) {if(rootnullptr) return 0;return max(maxDepth(root->left),maxDepth(root->right))1;} };

安超云参与编制的《上云指导:云基础环境框架(Landing Zone)设计与应用》白皮书正式发布

近日,由中国信息通信研究院牵头,阿里云、中国移动、安超云等单位共同编制的业界首个第三方视角的《上云指导:云基础环境框架(Landing Zone)设计与应用》白皮书正式发布。 白皮书从云基础环境框架 Landing Zone 背景、内…

反诈:吴明军、黄亮领导的WIN生活资金盘,大家警惕防范此类诈骗

消息已经证实!“米粒”无法变现,数以万计的会员深套“315万民商城”,维权艰难,血汗钱无法讨回。 其实这一点笔者并不感到太意外,因为万民商城资金传销盘的定性之前就已经发文揭露过,并反复提醒大家小心警惕…

制造型企业如何实现车间设备生产数据的实时采集?需要5G网络吗?

引言 在制造业数字化转型的浪潮下,实时采集车间设备生产数据变得尤为重要。工业边缘网关HiWoo Box作为一款专为工业应用而设计的智能设备,具备工业级设计和多种联网方式,为制造型企业提供了高性能的车间设备数据实时采集解决方案。本文将重点…

ChatGPT3.5——AI人工智能是个什么玩意?

ChatGPT3.5——AI人工智能 AI人工智能什么是AI?AI有什么过人之处AI有什么缺点 AI的发展AI的发展史中国是如何发展AI的 AI六大要素感知理解推理学习交互 ChatCPT-3.5GPT-3.5的优势在哪里GPT-3.5的风险GPT-4骗人事件 AI人工智能 AI,就像是一位超级聪明的机…

Spring Boot 配置多数据源【最简单的方式】

Druid连接池 Spring Boot 配置多数据源【最简单的方式】 文章目录 Druid连接池 Spring Boot 配置多数据源【最简单的方式】 0.前言1.基础介绍2.步骤2.1. 引入依赖2.2. 配置文件2.3. 核心源码Druid数据源创建器Druid配置项 DruidConfig 3.示例项目3.1. pom3.1.1. 依赖版本定义3.…

【SpringBoot】86、SpringBoot中集成Quartz根据Cron表达式获取接下来5次执行时间

本篇文章根据集成 Quartz 根据 Cron 表达式获取接下来的 5 次执行时间,在配置定时任务时,可以清晰地知道自己的 Cron 表达式是否正确,对于 Quartz 不熟悉的同学可以先看看我之前的文章 【SpringBoot】82、SpringBoot集成Quartz实现动态管理定时任务 【SpringBoot】83、Spri…

马斯克收购AI.com域名巩固xAI公司地位;如何评估大型语言模型的性能

🦉 AI新闻 🚀 AI拍照小程序妙鸭相机上线商业工作站并邀请摄影师进行内测 摘要:AI拍照小程序妙鸭相机将上线面向商业端的工作站,并邀请摄影师进行模板设计的内测。妙鸭相机希望为行业提供更多生态产品,扩大行业规模&a…

peerDependency到底是什么

peerDependency到底是什么 正常开发中,我们经常接触到的是 package.json 中的 dependencies 和 devDependencies, 本文不对上面两个进行细节分析,让我们来看看 peerDependencies 是什么? 在 NPM v7 中,默认安装 peerDependencies…

java+springboot+mysql法律咨询网

项目介绍: 使用javassmmysql开发的法律咨询网,系统包含超级管理员,系统管理员、用户角色,功能如下: 用户:主要是前台功能使用,包括注册、登录;查看法律领域;法律法规&a…

fetch-github-hosts间隔一年大更新v2.6发布,多端支持

前言 fetch-github-hosts是一款同步 github hosts 的工具,用于帮助您解决github时而无法访问的问题。在间隔了一年之久的时间,最近抽空将fetch-github-hosts的依赖及UI进行了一波大更新,同时也增加了一些实用的功能。 主要更新 更新了基础依…

C语言 用数组名作函数参数

当用数组名作函数参数时,如果形参数组中各元素的值发生变化,实参数组元素的值随之变化。 1.数组元素做实参的情况: 如果已经定义一个函数,其原型为 void swap(int x,int y);假设函数的作用是将两个形参(x,y&#xf…

云运维工具

企业通常寻找具有成本效益的方法来优化创收,维护物理基础架构以托管服务器和应用程序以提供服务交付需要巨大的空间和前期资金,最重要的是,物理基础设施会产生额外的运营支出以进行定期维护,这对收入造成了沉重的损失。 云使企业…

降维 — PCA 真的能改善分类结果吗?

一、说明 我遇到了一些关于降维技术的资源。这个主题绝对是最有趣的主题之一,很高兴认为有一些算法能够通过选择仍然代表整个数据集的最重要的特征来减少特征的数量。作者指出的优点之一是这些算法可以改善分类任务的结果。 在这篇文章中,我将使用主成分…

在线免费做分班查询软件:这个制作分班查询系统的平台就可实现

在制作分班查询系统前,作为老师的我们可以先理清制作分班查询系统的意义!这几个要点可能很多老师都没有留意过! 提高工作效率: 学校每年都需要进行学生的分班工作,如果采用传统的手工方式进行分班查询,会…

防火墙规则分析管理

防火墙规则在高效的网络安全管理中起着至关重要的作用,在添加规则之前,确保提议的新规则不会对网络产生负面影响至关重要。 通过防火墙规则影响分析,安全管理员可以详细了解添加新规则的可能影响,防火墙规则影响分析的一个重要方…

假日购物季已经打响?卖家要提前行动起来啦!

去年,假日购物季比以往任何时候都要早开启,这促使消费者对今年的购物季抱有更多的期待。 根据Optimove最新的消费者调查,有一半的消费者(50%)计划在11月之前开始他们的假日购物。 当被问及是什么促使他们提前购买时&…