卷积神经网络--猫狗系列【VGG16】

news2024/12/23 20:48:08

数据集:【文末】

数据集预处理

定义读取数据辅助类(继承torch.utils.data.Dataset)

import osimport PILimport torchimport torchvisionimport matplotlib.pyplot as pltimport torch.utils.dataimport PIL.Image
# 数据集路径train_path = './train'test_path = './test'device = torch.device("cuda" if torch.cuda.is_available() else "cpu")class MyDataset(torch.utils.data.Dataset):    def __init__(self, data_path: str, train=True, transform=None):        self.data_path = data_path        self.train_flag = train        if transform is None:            self.transform = torchvision.transforms.Compose(                [                    torchvision.transforms.Resize(size=(224, 224)),  # 尺寸规范                    torchvision.transforms.ToTensor(),  # 转化为tensor                    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 归一化                ])        else:            self.transform = transform        self.path_list = os.listdir(data_path)  # 列出所有图片命名    def __getitem__(self, idx: int):        img_path = self.path_list[idx]        if self.train_flag is True:            # 例如 img_path 值 cat.10844.jpg -> label = 0            if img_path.split('.')[0] == 'dog':                label = 1            else:                label = 0        else:            label = int(img_path.split('.')[0])  # 获取test数据的编号        label = torch.tensor(label, dtype=torch.int64)  # 把标签转换成int64        img_path = os.path.join(self.data_path, img_path)  # 合成图片路径        img = PIL.Image.open(img_path)  # 读取图片        img = self.transform(img)  # 把图片转换成tensor        return img, label    def __len__(self) -> int:        return len(self.path_list)  # 返回图片数量train_datas = MyDataset(train_path)test_datas = MyDataset(test_path, train=False)

(原本数据有25000张,由于设备的原因,训练完之后我删掉了很多图片,训练集+测试集只有2000张)

查看读取的数据

# 展示读取的图片数据,因为做了归一化,所有图片显示不正常。Img_PIL_Tensor = train_datas[20][0]new_img_PIL = torchvision.transforms.ToPILImage()(Img_PIL_Tensor).convert('RGB')plt.imshow(new_img_PIL)plt.show(block=True)

训练集和测试集分组,数据分batch

(根据自己的设备来,好的就设32,不好就4吧)

# 70%训练集  30%测试集train_size = int(0.7 * len(train_datas))validate_size = len(train_datas) - train_sizetrain_datas,validate_datas = torch.utils.data.random_split(train_datas,[train_size, validate_size])# 数据分批# batch_size=32 每一个batch大小为32# shuffle=True 打乱分组# pin_memory=True 锁页内存,数据不会因内存不足,交换到虚拟内存中,能加快数据读入到GPU显存中.# num_workers 线程数。num_worker设置越大,加载batch就会很快,训练迭代结束可能下一轮batch已经加载好# win10 设置会多线程可能会出现问题,一般设置0.train_loader = torch.utils.data.DataLoader(train_datas, batch_size=4,                                            shuffle=True, pin_memory=True, num_workers=0)validate_loader = torch.utils.data.DataLoader(validate_datas, batch_size=4,                                            shuffle=True, pin_memory=True, num_workers=0)test_loader = torch.utils.data.DataLoader(test_datas, batch_size=4,                                            shuffle=False, pin_memory=True, num_workers=0)

VGG网络:

def vgg_block(num_convs, in_channels, out_channels):    layers = []    for _ in range(num_convs):        layers.append(torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))        layers.append(torch.nn.ReLU())        in_channels = out_channels    # ceil_mode=False 输入的形状不是kernel_size的倍数,直接不要。    # ceil_mode=True 输入的形状不是kernel_size的倍数,单独计算。    layers.append(torch.nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=False))    return torch.nn.Sequential(*layers)def vgg(conv_arch):    conv_blks = []    # 数据输入是几个通道    in_channels = 3    # 卷积层部分    for (num_convs, out_channels) in conv_arch:        conv_blks.append(vgg_block(num_convs, in_channels, out_channels))        in_channels = out_channels    return torch.nn.Sequential(        *conv_blks, torch.nn.Flatten(),        torch.nn.Linear(out_channels * 7 * 7, 4096), torch.nn.ReLU(), torch.nn.Dropout(0.5),        torch.nn.Linear(4096, 4096), torch.nn.ReLU(), torch.nn.Dropout(0.5),        torch.nn.Linear(4096, 2))

VGG神经网络定义和参数初始化

# VGG11,VGG13,VGG16,VGG19 可自行更换。conv_arch = ((2, 64), (2, 128), (3, 256), (3, 512), (3, 512))  # vgg16#conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))  # vgg11#conv_arch = ((2, 64), (2, 128), (2 , 256), (2, 512), (2, 512))  # vgg13#conv_arch = ((2, 64), (2, 128), (4, 256), (4, 512), (4, 512))  # vgg19net = vgg(conv_arch)   # 定义网络net = net.to(device)   # 把网络加载到GPU上# Xavier方法 初始化网络参数,最开始没有初始化一直训练不起来。def init_normal(m):    if type(m) == torch.nn.Linear:        # Xavier初始化        torch.nn.init.xavier_uniform_(m.weight)        torch.nn.init.zeros_(m.bias)    if type(m) == torch.nn.Conv2d:        # Xavier初始化        torch.nn.init.xavier_uniform_(m.weight)        torch.nn.init.zeros_(m.bias)net.apply(init_normal)learn_rate = 1e-5#momentum = 0.9#optimizer = torch.optim.SGD(net.parameters(), learn_rate, momentum = momentum) #定义梯度优化算法optimizer = torch.optim.Adam(net.parameters(), learn_rate) #开始使用SGD没有训练起来,才更换的Adamcost = torch.nn.CrossEntropyLoss(reduction='sum')     # 定义损失函数,返回batch的loss和。print(net)    # 打印模型架构

训练VGG神经网络

epoch = 10  # 迭代10次def train_model(net, train_loader, validate_loader, cost, optimezer):    net.train()  # 训练模式    now_loss = 1e9  # flag 计算当前最优loss    train_ls = []  # 记录在训练集上每个epoch的loss的变化情况    train_acc = []  # 记录在训练集上每个epoch的准确率的变化情况    for i in range(epoch):        loss_epoch = 0.  # 保存当前epoch的loss和        correct_epoch = 0  # 保存当前epoch的正确个数和        for j, (data, label) in enumerate(train_loader):            data, label = data.to(device), label.to(device)            pre = net(data)            # 计算当前batch预测正确个数            correct_epoch += torch.sum(pre.argmax(dim=1).view(-1) == label.view(-1)).item()            loss = cost(pre, label)            loss_epoch += loss.item()            optimezer.zero_grad()            loss.backward()            optimezer.step()            if j % 100 == 0:                print(                    f'batch_loss:{loss.item()}, batch_acc:{torch.sum(pre.argmax(dim=1).view(-1) == label.view(-1)).item() / len(label)}%')        train_ls.append(loss_epoch / train_size)        train_acc.append(correct_epoch / train_size)        # 每一个epoch结束后,在验证集上验证实验结果。        with torch.no_grad():            loss_validate = 0.            correct_validate = 0            for j, (data, label) in enumerate(validate_loader):                data, label = data.to(device), label.to(device)                pre = net(data)                correct_validate += torch.sum(pre.argmax(dim=1).view(-1) == label.view(-1)).item()                loss = cost(pre, label)                loss_validate += loss.item()            # print(f'validate_sum:{loss_validate},  validate_Acc:{correct_validate}')            print(f'validate_Loss:{loss_validate / validate_size},  validate_Acc:{correct_validate / validate_size}%')            # 保存当前最优模型参数            if now_loss > loss_validate:                now_loss = loss_validate                print("保存模型参数。。。。。。。。。。。")                torch.save(net.state_dict(), 'model.params')    # 画图    plt.plot(range(epoch), train_ls, color='b', label='loss')    plt.plot(range(epoch), train_acc, color='g', label='acc')    plt.legend()    plt.show(block=True)  # 显示 labletrain_model(net, train_loader, validate_loader, cost, optimizer)

资料分享栏目

数据集之猫狗系列(VGG16)

链接:https://pan.baidu.com/s/1MoJPs-BQ6GP1PrXjo-wKsQ

提取码:dgna

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/714911.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

哈希桶的增删查改简单实现

个人简单笔记。 目录 闭散列 开散列 插入 删除 查找 改变 什么是哈希桶呢?这是一个解决哈希数据结构的一种解决方法,在STL中的unorder_map与unorder_set的底层结构就是使用它来实现的。 闭散列 首先我们知道,哈希映射表是依据数组下…

CSS画特殊边框

例如如图所示边框 .card-middle {width: 672px;height: 486px;border: 1px solid #5fadec;border-radius: 5px;position: relative; }.card-middle::before {content: ;position: absolute;top: -4px;left: -4px;width: 680px;height: 448px;border: 25px solid transparent;b…

【Python】PIL.Image转QPixmap后运行异常的个人解决方法

问题场景: PIL.Image图片,直接调用PIL.Image.toqpixmap()转成QPixmap后,不会立即报错,   但后续使用该QPixmap时(包括但不仅限于使用QLabel.setPximap()、QPixmap.save())将立即出现异常 不知道是我关键词不对,还是只…

【数据结构与算法】文学语言助手(C\C++)

实践要求 1. 问题描述 文学研究人员需要统计某篇英文小说中某些形容词的出现次数和位置。试写一个实现这一目标的文字统计系统,称为"文学研究助手"。 2. 基本要求 英文小说存于文本文件中。待统计的词汇集合要一次输入完毕,即统计工作必需在…

linux常用命令介绍 06 篇——Linux查看目录层级结构以及创建不同情况的层级目录

linux常用命令介绍 06 篇——Linux查看目录层级结构以及创建不同情况的层级目录 1. 前言1.1 Linux常用命令其他篇1.2 关于tree简介 2. 安装并使用 tree2.1 安装tree2.1.1 方式1:yum安装2.1.2 方式2:下载安装包安装2.1.2.1 下载安装包2.1.2.2 解压安装2.1…

transformer入坑指南

*免责声明: 1\此方法仅提供参考 2\搬了其他博主的操作方法,以贴上路径. 3* 场景一: Attention is all you need 场景二: VIT 场景三: Swin v1 场景四: Swin v2 场景五: SETR 场景六: TransUNet 场景七: SegFormer 场景八: PVT 场景九: Segmeter … 场景一:Attention…

Spring Boot 中的 Spring Cloud Ribbon:什么是它,原理及如何使用

Spring Boot 中的 Spring Cloud Ribbon:什么是它,原理及如何使用 在分布式系统中,服务之间的通信是非常重要的。在大型的分布式系统中,有许多服务需要相互通信,而这些服务可能会部署在多个服务器上。为了实现服务之间…

超详细Redis入门教程——Redis分布式系统

前言 本文小新为大家带来 Redis分布式系统 相关知识,具体内容包括数据分区算法(包括:顺序分区,哈希分区),系统搭建与运行(包括:系统搭建,系统启动与关闭)&…

把 OpenGrok search 上的Android 开源代码扒下来

1、下载工具 wget (window10版本)以及配置环境变量 工具我会上传到本篇博客的“代码包”区域,可以自行下载! 当然如果可以访问如下链接的话,也可以在这个地址自行下载一个比较新的版本即可!GNU Wget 1.21.…

Web服务器群集:LVS+Keepalived高可用群集

目录 一、理论 1.Keepalived 2.VRRP协议(虚拟路由冗余协议) 3.部署LVSKeepalived 高可用群集 二、实验 1.LVSKeepalived 高可用群集 三、问题 1.备服务器网卡启动报错 四、总结 一、理论 1.Keepalived (1)简介 Keepal…

【动态规划算法】-第一题:1137.第N个斐波那契数

💖作者:小树苗渴望变成参天大树 🎉作者宣言:认真写好每一篇博客 🎊作者gitee:gitee 如 果 你 喜 欢 作 者 的 文 章 ,就 给 作 者 点 点 关 注 吧! 文章目录 前言 前言 各位友友们&#xff0c…

element之el-table合并列功能

目标效果如下&#xff1a; 实现代码如下&#xff1a; html部分&#xff1a; <!--定义表格组件,用组件自带的span-method属性定义合并列的方法--> <el-table :data"tableData" :span-method"spanRow"><el-table-column prop"RegionNa…

在proteus中仿真arduino驱动点阵屏matrix-led

我们都知道&#xff0c;如果我们仅仅在某个时间段点亮一个数码管是没有任何困难的&#xff0c;但如果我们点亮多个数码管就会出现问题&#xff0c;因为多个数码管都使用着同样的端口来控制数码管的各个段的亮灭。所以&#xff0c;就会用上一个很重要的方法&#xff0c;对&#…

使用javaScript脚本生成openFoam网格

简介 OpenFoam的首选网格生成器是blockMesh。blockMesh可以根据blockMeshDict这个字典中的信息生成openFoam网格。但是有时候需要修改网格&#xff0c;而网格中的几何点之间又存在约束关系&#xff0c;如果手动修改blockMeshDict那么工作量将是巨大的&#xff0c;所以有必要使…

有没有免费提取音频的软件,分享几个给大家!

在日常生活中&#xff0c;我们经常遇到需要从视频中提取音频的情况&#xff0c;无论是为了制作音频片段、录制语音笔记还是进行后期编辑。本文将介绍三种免费提取音频的方法&#xff0c;分别是记灵在线工具、PR&#xff08;Adobe Premiere Pro&#xff09;和剪映。通过这些方法…

【Vue3】学习笔记-自定义hook函数

概念 什么是hook? 本质是一个函数&#xff0c;把setup函数中使用的Composition API进行了封装。 类似于vue2.x中的mixin。(但是mixins会组件的配置项覆盖。vue3使用了自定义hooks替代mixnins&#xff0c;hooks本质上是函数&#xff0c;引入调用。) 自定义hook的优势: 复用代…

PPU (power policy unit)

写在前边 最近在做低功耗验证&#xff0c;项目中涉及到PPU这一块儿&#xff0c;在家查了好久资料&#xff0c;发现能找到的有价值的文章真的好少&#xff0c;机缘巧合之下&#xff0c;让我找到下边总结&#xff0c;分享出来&#xff0c;希望对和我有相同境遇的小伙伴带来帮助&a…

每周学点数学 2:概率论基础1

泊松分布、正态分布、二项分布 文章目录 1.概率论学习中的重难点2.主要工具介绍1. Python2. MATLAB3. R4. Octave5. Microsoft Excel6. 统计软件 3.理论内容概览&#xff08;前两点&#xff09;1. 概率2. 概率分布 注&#xff1a;本文适用于在在数学建模的应用中&#xff0c;回…

牛客网基础语法101~110题

牛客网基础语法101~110题&#x1f618;&#x1f618;&#x1f618; &#x1f4ab;前言&#xff1a;今天是咱们第十期刷牛客网上的题目。 &#x1f4ab;目标&#xff1a;对打印图案做到有手就行。 &#x1f4ab;鸡汤&#xff1a;与其花时间应付以后不理想的生活&#xff0c;不如…

学习c++ Part02

学习c Part02 前言1.函数注意点&#xff1a;全局函数&#xff08;默认函数&#xff09;静态函数 2.预处理2.1 变量 3.头文件4.宏函数5.指针5.1 普通变量与指针变量建立关系&#xff1a;5.2 指针初始化5.3 指针变量的注意事项5.3.1 void 不能定义普通变量,void * 可以定义指针变…