卷积神经网络--猫狗系列【CNN】

news2024/11/16 3:37:58

数据集,这次这个是分了类的【文末分享】

各12500张:

两点需要注意:

①猫狗分类是彩色图片,所以是3个channel;

②猫狗分类的图片大小不一,但是CNN的输入要求是固定大小,所以要resize。

划分训练集和测试集:

文件夹如下:

这里要注意,文件里面有三个图片是打不开的,这会导致后续运行时报错,打不开的三个图片分别是:猫猫666,猫猫10404,狗狗11702:【要么替换图片要么直接删掉(这个我选择的是替换)】

然后按照8:2来划分:(移动文件的过程,从train随机选取一些图片到test中)

import os,shutildef mymovefile(srcfile,dstfile):    if not os.path.isfile(srcfile):        print("src not exist!")    else:        fpath,fname=os.path.split(dstfile)    #分离文件名和路径        if not os.path.exists(fpath):            os.makedirs(fpath)                #创建路径        shutil.move(srcfile,dstfile)          #移动文件test_rate=0.2#训练集和测试集的比例为8:2。img_num=12500test_num=int(img_num*test_rate)import randomtest_index = random.sample(range(0, img_num), test_num)file_path=r"D:\Users\Twilight\PycharmProjects\CNN\PetImages"tr="train"te="test"cat="Cat"dog="Dog"#将上述index中的文件都移动到/test/Cat/和/test/Dog/下面去。for i in range(len(test_index)):    #移动猫    srcfile=os.path.join(file_path,tr,cat,str(test_index[i])+".jpg")    dstfile=os.path.join(file_path,te,cat,str(test_index[i])+".jpg")    mymovefile(srcfile,dstfile)    #移动狗    srcfile=os.path.join(file_path,tr,dog,str(test_index[i])+".jpg")    dstfile=os.path.join(file_path,te,dog,str(test_index[i])+".jpg")    mymovefile(srcfile,dstfile)

猫狗分类的数据集预处理:

import numpy as npfrom torchvision import transforms,datasets#定义transformstransforms = transforms.Compose([transforms.RandomResizedCrop(150),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],                            std=[0.229, 0.224, 0.225])])
train_data = datasets.ImageFolder(os.path.join(file_path,tr), transforms)test_data=datasets.ImageFolder(os.path.join(file_path,te), transforms)

RandomResizedCrop(150)是用来把图片的每一个channel大小都变成(150,150);

mean=[0.485, 0.456, 0.406],有3个数的原因是猫狗分类是彩色图片,所以有3个channel,所以每一个channel上都有一个平均值,同理,std也是;

上面的data已经把猫狗的图片都囊括了,而且标签已经自动变成了0和1。这就是ImageFolder的威力。

做一下测试:

#测试print(train_data)print(len(train_data))print(len(test_data))

print(train_data[0][0])print(train_data[0][1])

第一张[0]训练图片的具体情况:前面是图片[0],后面是标签[1]。标签0代表猫猫,1代表狗狗:因为在/train的文件夹下Cat在Dog的前面,所以前者是0,后者是1。

​print(train_data[0][0].shape)print(train_data[1][0].shape)

每一张图片都是(3,150,150),3表示3个channel。

结果分析:(都是合理的)

网络架构以及训练结果

from torch.utils import databatch_size=32train_loader = data.DataLoader(train_data,batch_size=batch_size,shuffle=True,pin_memory=True)test_loader = data.DataLoader(test_data,batch_size=batch_size)import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optim#架构会有很大的不同,因为28*28-》150*150,变化挺大的,这个步长应该快一点。class CNN(nn.Module):    def __init__(self):        super(CNN,self).__init__()        self.conv1=nn.Conv2d(3,20,5,5)#和MNIST不一样的地方,channel要改成3,步长我这里加快了,不然层数太多。        self.conv2=nn.Conv2d(20,50,4,1)        self.fc1=nn.Linear(50*6*6,200)        self.fc2=nn.Linear(200,2)#这个也不一样,因为是2分类问题。    def forward(self,x):        #x是一个batch_size的数据        #x:3*150*150        x=F.relu(self.conv1(x))        #20*30*30        x=F.max_pool2d(x,2,2)        #20*15*15        x=F.relu(self.conv2(x))        #50*12*12        x=F.max_pool2d(x,2,2)        #50*6*6        x=x.view(-1,50*6*6)        #压扁成了行向量,(1,50*6*6)        x=F.relu(self.fc1(x))        #(1,200)        x=self.fc2(x)        #(1,2)        return F.log_softmax(x,dim=1)lr=1e-4device=torch.device("cuda" if torch.cuda.is_available() else "cpu" )model=CNN().to(device)optimizer=optim.Adam(model.parameters(),lr=lr)def train(model,device,train_loader,optimizer,epoch,losses):    model.train()    for idx,(t_data,t_target) in enumerate(train_loader):        t_data,t_target=t_data.to(device),t_target.to(device)        pred=model(t_data)#batch_size*2        loss=F.nll_loss(pred,t_target)
        #Adam        optimizer.zero_grad()        loss.backward()        optimizer.step()        if idx%10==0:            print("epoch:{},iteration:{},loss:{}".format(epoch,idx,loss.item()))            losses.append(loss.item())def test(model,device,test_loader):    model.eval()    correct=0#预测对了几个。    with torch.no_grad():        for idx,(t_data,t_target) in enumerate(test_loader):            t_data,t_target=t_data.to(device),t_target.to(device)            pred=model(t_data)#batch_size*2            pred_class=pred.argmax(dim=1)#batch_size*2->batch_size*1            correct+=pred_class.eq(t_target.view_as(pred_class)).sum().item()    acc=correct/len(test_data)    # print("accuracy:{},average_loss:{}".format(acc,average_loss))    print("accuracy:{}".format(acc))num_epochs=10losses=[]from time import *begin_time=time()for epoch in range(num_epochs):    train(model,device,train_loader,optimizer,epoch,losses)# test(model,device,test_loader)end_time=time()print(test(model,device,test_loader))

​每一轮的部分截图:

精确度:

【修改网络架构,因为是彩色的,有3个channel,而且训练资源有限,增大了卷积的步长 -->self.conv1=nn.Conv2d(3,20,5,5)的最后一个参数5,即卷一次,移动5个格子,不写默认是1格。这样做,可以快速把图片缩小,原来是150*150的图片,这样可以变成30*30。】

资料分享栏目

数据集之猫狗分类(kaggle+分了类的)

链接:https://pan.baidu.com/s/1NByVZwxUk4nmCCTCetqgCA 

提取码:rmmx

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/711385.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【动态规划】子数组系列(下)

子数组问题 文章目录 【动态规划】子数组系列(下)1. 等差数组划分1.1 题目解析1.2 算法原理1.2.1 状态表示1.2.2 状态转移方程1.2.3 初始化1.2.4 填表顺序1.2.5 返回值 1.3 代码实现 2. 最长湍流子数组2.1 题目解析2.2 算法原理2.2.1 状态表示2.2.2 状态…

初学spring5(五)使用注解开发

学习回顾&#xff1a;初学spring5&#xff08;四&#xff09;自动装配 一、使用注解开发 二、说明 在spring4之后&#xff0c;想要使用注解形式&#xff0c;必须得要引入aop的包 在配置文件当中&#xff0c;还得要引入一个context约束 <beans xmlns"http://www.sprin…

Node.js模块化加载机制

优先从缓存中加载 模块在第一次加载后会被缓存。这也意味着多次调用 require() 不会导致模块的代码被执行多次 注意:不论是内置模块、用户自定义模块、还是第三方模块&#xff0c;它们都会优先从缓存中加载&#xff0c;从而提高模块的加载效率 $就像下方图中测试 内置模块…

【软件测试】MySQL操作数据表常用sql语句(汇总)

目录&#xff1a;导读 前言 一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 数据表有哪些操作…

【JavaEE初阶】HTML

摄影分享~ 文章目录 一.第一个HTML程序1.创建一个HTML文件并运行2.在vscode中创建HTML文件并运行HTML代码的特点 二.HTML中的标签1.注释标签2.标题标签3.段落标签4.换行标签5.格式化标签6.图片标签&#xff1a;img7.超链接标签8.表格标签9.列表标签10.from标签input标签selec…

Tomcat概念及部署

一.Tomcat的概述 1.Tomcat介绍 &#xff08;1&#xff09;免费的、开放源代码的web应用服务器。 &#xff08;2&#xff09;主要处理的是动态页面&#xff08;做一个运行后端的程序&#xff09;可以处理静态页面&#xff0c;处理效果不及apache和nginx。 &#xff08;3&…

类变量和类方法的基本使用

什么是类变量 类变量也叫静态变量/静态属性&#xff0c;是该类的所有对象共享的变量&#xff0c;任何一个该类的对象访问它时&#xff0c;取到的都是相同的值&#xff0c;同样任何一个该类的对象去修改它时&#xff0c;修改的也是同一个变量。 如何定义类变量 定义语法&…

实战线性回归模型

引言 线性回归是一种最简单、也是最常用的预测模型&#xff0c;主要用于处理自变量和因变量之间的线性关系。举个例子&#xff0c;假设你是一名大学生&#xff0c;正在为你的经济学课程做一个研究项目&#xff0c;你想要知道大学生的学习时间和GPA&#xff08;绩点平均分&…

React 搭建DvaJS开发环境

那么 后面我们就开始将DvaJS了 他是一个特别优秀的React轻量级应用框架 他的使用了非常大 很多公司也都有在应用 他是 redux 和 redux-saga 的解决方案 可以简化操作 还内置了react-router 路由 和 fetch 网络请求 首先 它的学习并不困难 因为 Api本身其实比较少 他对于redux…

如何优化Apple搜索广告

Apple Search Ads是促进应用发展的工具&#xff0c;如果我们已经投放了广告&#xff0c;那么就要观察投放效果&#xff0c;及时识别广告薄弱的地方&#xff0c;可以给我们更多的机会去优化它并扩大其投放效果。 再开始投放广告之前&#xff0c;要确保我们的应用商店列表已经优…

vue_前后端项目分离操作-查询操作

前后端项目分离操作 使用搭建好的vue项目和ssm项目 功能需求分析 后端 查询 持久层 ​ 发送两条sql查询总条数和结果集(limit容易写死) ​ 使用分页插件pageHelper解决分页的功能 ​ 在pom.xml中添加依赖 <!--pagehelper--><dependency><groupId>com…

【C语言const关键字】

C语言const关键字 C语言之const关键字1、什么是const?2、const的用法2.1、const作常量的修饰符例程12.2、const修饰函数的参数例程2 3、const与指针变量的搭配3.1、指针与const的应用例程3.2、指针与const的应用延申二级指针 4、结束语 C语言之const关键字 前言&#xff1a; …

iOS iPadOS safari 独立Web应用屏幕旋转的时候 window.innerHeight 数值不对。

iOS iPadOS safari 独立Web应用屏幕旋转的时候 window.innerHeight 数值不对 一、问题描述 我有一个日记应用&#xff0c;是可以作为独立 Web 应用运行的那种&#xff0c;但在旋转屏幕的时候获取到的 window.innerHeight 和 window.innerWidth 就不对了&#xff0c;不是屏幕的…

无法安装此app,因为无法验证其完整性 ,解决方案

最近有很多兄弟萌跟我反应“无法安装此app,因为无法验证其完整性 ”&#xff0c;看来这个问题无法避免了&#xff0c;今天统一回复下&#xff0c;出现提示主要有以下几种可能 1.安装包不完整 首先申请我所有分享的破解软件全部都有自己校验过&#xff0c;一般不会存在问题出非你…

【视频观看记录】Bubbliiiing的Pytorch 搭建自己的Unet语义分割平台(Bubbliiiing 深度学习 教程)

来源 b站 地址 什么是语义分割 语义分割&#xff1a;对图像每个像素点进行分类 常见神经网络处理过程&#xff1a;Encoder提取特征&#xff0c;接着Docoder恢复成原图大小的图片 UNet整体结构 分为三个部分 主干特征提取部分&#xff1a; 卷积和最大池化的堆叠获得五个初…

win10安装pytorch GPU

我记得以前安装过深度学习库GPU版本&#xff0c; 需要安装cuda什么的&#xff0c;翻了下还真写过一篇win10安装tensorflow的文章&#xff0c;但是流程不止不详细&#xff0c;还不清晰。这次就再记录一遍 这次安装的是pytorch&#xff0c;这么多年似乎pytorch要逐渐统一深度学习…

JavaEE语法第二章之多线程(初阶二)

目录 一、线程常用方法 1.1启动一个线程-start() 1.2中断一个线程 1.2.1使用自定义的变量来作为标志位. 1.2.2使用 Thread.interrupted() 或者 Thread.currentThread().isInterrupted() 代替自定义标志位. 1.2.3观察标志位是否清除 1.3等待一个线程-join() 1.4获取当前…

Typora文本的使用

1. 如何创建目录&#xff1f; 输入几个#&#xff0c;再加空格&#xff0c;写入文字回车后就是几级标题&#xff1b; 2. 如何输入代码块&#xff1f; 英文状态下&#xff0c;输入三个反引号&#xff0c;然后回车即可&#xff1b; 3. 如何输入竖线和小圆点&#xff1f; 4. 如何…

SSH远程直连Docker容器

文章目录 1. 下载docker镜像2. 安装ssh服务3. 本地局域网测试4. 安装cpolar5. 配置公网访问地址6. SSH公网远程连接测试7.固定连接公网地址8. SSH固定地址连接测试8. SSH固定地址连接测试 转载自cpolar极点云文章&#xff1a;SSH远程直连Docker容器 在某些特殊需求下,我们想ssh…