CNN入门实战:猫狗分类

news2024/12/22 16:56:08

前言

        CNN(Convolutional Neural Network,卷积神经网络)是一种深度学习模型,特别适用于处理图像数据。它通过多层卷积和池化层来提取图像的特征,并通过全连接层进行分类或回归等任务。CNN在图像识别、目标检测、图像分割等领域取得了很大的成功。

CNN网络结构

        目标分类是指识别图像中的物体,并将其归类到不同的类别中。例如,猫狗分类就是一个目标分类的任务,CNN可以帮助我们构建一个模型来自动识别图像中的猫和狗。

如何入门CNN

要入门CNN,可以先了解深度学习的基本概念和原理,然后学习如何构建和训练CNN模型。可以选择一些经典的教材、在线课程或者教程来学习深度学习和CNN的基础知识。

实战案例分析

以下是一个简单的使用PyTorch构建CNN模型的示例代码:

1、导包

from PIL import Image
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,  random_split, Dataset
from torchvision import datasets, models

import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim 

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

2、设置数据集目录

# Setting the data directories/ paths

data_pth = '/你的数据集目录'
cats_dir = data_pth + '/Cat'
dogs_dir = data_pth + '/Dog'

3、打印图像的数量

print("Total Cats Images:", len(os.listdir(cats_dir)))
print("Total Dogs Images:", len(os.listdir(dogs_dir)))
print("Total Images:", len(os.listdir(cats_dir)) + len(os.listdir(dogs_dir)))

4、查看Cat数据

cat_img = Image.open(cats_dir + '/'  + os.listdir(cats_dir)[0])

print('Shape of cat image:', cat_img.size)
cat_img

5、查看Dog的数据

dog_img = Image.open(dogs_dir + '/'  + os.listdir(dogs_dir)[0])

print('Shape of dog image:', dog_img.size)
dog_img

6、自定义加载数据集方法

class CustomDataset(Dataset):
    def __init__(self, data_path, transform=None):
        # Initialize your dataset here
        self.data = data
        self.transform = transform

    def __len__(self):
        # Return the number of samples in your dataset
        return len(self.data)

    def __getitem__(self, idx):
        # Implement how to get a sample at the given index
        sample = self.data[idx]
        
        try:
            img = Image.open(data_pth + '/Cat/' + sample)
            label = 0
            
        except:
            img = Image.open(data_pth + '/Dog/' + sample)
            label = 1

        # Apply any transformations (e.g., preprocessing)
        
        if self.transform:
            img = self.transform(img)

        return img, label

7、定义数据转换(调整大小、规格化、转换为张量等)

在训练目标分类模型时,我们通常会使用转换数据来对输入数据进行预处理,以便更好地适应模型的训练和提高模型的性能。

使用转换数据的原因包括:

  1. 调整大小:输入数据通常具有不同的尺寸和分辨率,为了确保模型能够处理这些不同尺寸的数据,我们需要将其调整为统一的大小。这样可以确保模型在训练和预测时能够处理相同大小的输入数据。

  2. 规格化:规格化是指将输入数据的数值范围调整到相似的范围,以便更好地适应模型的训练。规格化可以帮助模型更快地收敛,提高模型的稳定性和准确性。

  3. 转换为张量:在深度学习中,输入数据通常需要转换为张量形式,以便与神经网络模型进行计算。因此,我们需要将输入数据转换为张量形式,以便能够输入到模型中进行训练和预测。

总之,转换数据是为了确保模型能够更好地适应输入数据,并提高模型的性能和准确性。通过调整大小、规格化和转换为张量等操作,我们可以更好地准备输入数据,使其更适合用于训练目标分类模型。

# Define the data transformation (resize, normalize, convert to tensor, etc.)
    
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to a fixed size (adjust as needed)
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),           # Convert images to PyTorch tensors
])

data = [i for i in os.listdir(data_pth + '/Cat') if i.endswith('.jpg')] + [i for i in os.listdir(data_pth + '/Dog') if i.endswith('.jpg')]

combined_dataset = CustomDataset(data_path=data, transform=transform)
#dataloader = torch.utils.data.DataLoader(custom_dataset, batch_size=64, shuffle=False)

8、定义拆分比例(例如,80%用于培训,20%用于测试)

# Define the ratio for splitting (e.g., 80% for training, 20% for testing)
train_ratio = 0.8
test_ratio = 1.0 - train_ratio

# Calculate the number of samples for training and testing
num_samples = len(combined_dataset)
num_train_samples = int(train_ratio * num_samples)
num_test_samples = num_samples - num_train_samples

# Use random_split to split the dataset
train_dataset, test_dataset = random_split(combined_dataset, [num_train_samples, num_test_samples])

# Create data loaders for training and testing datasets
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

9、自定义CNN模型

# Define the CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, stride=2),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),

            nn.Conv2d(8, 16, kernel_size=3, stride=2),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),

            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
        )

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(288, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)

        return x

10、GPU是否可用

# check if gpu is available or not

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

11、初始化模型、损失函数和优化器

     训练过程中的关键步骤,其重要性如下:

  1. 初始化模型:模型的初始化是指对模型参数进行初始赋值。正确的初始化可以加速模型的收敛,提高训练的效率和稳定性。如果模型参数的初始值过大或过小,可能会导致梯度爆炸或梯度消失,从而影响模型的训练效果。

  2. 损失函数:损失函数是用来衡量模型预测结果与真实标签之间的差距。选择合适的损失函数可以帮助模型更好地学习数据的特征,并且在训练过程中不断优化模型参数,使得损失函数值逐渐减小。

  3. 优化器:优化器是用来更新模型参数的算法,常见的优化器包括随机梯度下降(SGD)、Adam、RMSprop等。选择合适的优化器可以加速模型的收敛,提高训练的效率和稳定性。不同的优化器有不同的更新规则,可以根据具体的任务和数据特点选择合适的优化器。

因此,初始化模型、损失函数和优化器是CNN训练过程中的关键步骤,它们的选择和设置会直接影响模型的训练效果和性能。

# Initialize the model, loss function, and optimizer
net = CNN().to(device)
criterion = nn.BCELoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

12 、开始训练数据

epochs = 5
net.train()
for epoch in range(epochs):
    running_loss = 0.0
    for idx, (inputs, labels) in tqdm(enumerate(train_loader), total=len(train_loader)):

        inputs = inputs.to(device)
        labels = labels.to(device).to(torch.float32)

        optimizer.zero_grad()

        outputs = net(inputs).reshape(-1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch: {epoch + 1}, Loss: {running_loss}')

print('Training Finished!')

13、验证数据

net.eval()  # Set the model to evaluation mode
correct = 0
total = 0

with torch.no_grad():
    for idx, (inputs, labels) in tqdm(enumerate(test_loader), total=len(test_loader)):
        inputs = inputs.to(device)
        labels = labels.to(device).to(torch.float32)

        outputs = net(inputs).reshape(-1)
        predicted = (outputs > 0.5).float()  # Assuming a binary classification threshold of 0.5

        correct += (predicted == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total if total > 0 else 0.0
print(f'Test Accuracy: {accuracy:.2%}')

14、测试数据

label_names = ['cat', 'dog']
fig, ax = plt.subplots(1, 5, figsize=(15, 5))
outputs = outputs.cpu()
inputs = inputs.cpu()
labels = labels.cpu()

for i in range(5):
    ax[i].imshow(inputs[i].permute(1,2,0))
    ax[i].set_title(f'True: {label_names[labels[i].to(int)]}, Pred: {label_names[torch.where(outputs[i] > 0.5, 1, 0).item()]}')
    ax[i].axis(False)

plt.show()

至此,一个CNN训练模型从搭建到测试的完整实现过程就完成了,我们使用了PyTorch构建了一个简单的CNN模型,并使用猫狗分类的训练数据对模型进行训练。首先准备了训练数据集,然后构建了一个简单的CNN模型,定义了损失函数和优化器,最后进行了模型的训练。

数据集训练完整代码

from PIL import Image
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,  random_split, Dataset
from torchvision import datasets, models

import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim 

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import os
from tqdm import tqdm



data_pth = '/你的数据集目录'
cats_dir = data_pth + '/Cat'
dogs_dir = data_pth + '/Dog'

cat_img = Image.open(cats_dir + '/'  + os.listdir(cats_dir)[0])
dog_img = Image.open(dogs_dir + '/'  + os.listdir(dogs_dir)[0])

#定义加载数据集方法
class CustomDataset(Dataset):
    def __init__(self, data_path, transform=None):
        # Initialize your dataset here
        self.data = data
        self.transform = transform

    def __len__(self):
        # 返回数据集数量
        return len(self.data)

    def __getitem__(self, idx):
        # 获取数据集对应的下标
        sample = self.data[idx]
        
        try:
            img = Image.open(data_pth + '/Cat/' + sample)
            label = 0
            
        except:
            img = Image.open(data_pth + '/Dog/' + sample)
            label = 1

       
        if self.transform:
            img = self.transform(img)

        return img, label

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to a fixed size (adjust as needed)
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),           # Convert images to PyTorch tensors
])

data = [i for i in os.listdir(data_pth + '/Cat') if i.endswith('.jpg')] + [i for i in os.listdir(data_pth + '/Dog') if i.endswith('.jpg')]

combined_dataset = CustomDataset(data_path=data, transform=transform)


#划分数据集
train_ratio = 0.8
test_ratio = 1.0 - train_ratio

# Calculate the number of samples for training and testing
num_samples = len(combined_dataset)
num_train_samples = int(train_ratio * num_samples)
num_test_samples = num_samples - num_train_samples

# Use random_split to split the dataset
train_dataset, test_dataset = random_split(combined_dataset, [num_train_samples, num_test_samples])

# Create data loaders for training and testing datasets
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


#自定义CNN模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, stride=2),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),

            nn.Conv2d(8, 16, kernel_size=3, stride=2),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),

            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
        )

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(288, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)

        return x

#GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#初始化模型、损失函数和优化器
net = CNN().to(device)
criterion = nn.BCELoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

#开始训练
epochs = 5
net.train()
for epoch in range(epochs):
    running_loss = 0.0
    for idx, (inputs, labels) in tqdm(enumerate(train_loader), total=len(train_loader)):

        inputs = inputs.to(device)
        labels = labels.to(device).to(torch.float32)

        optimizer.zero_grad()

        outputs = net(inputs).reshape(-1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch: {epoch + 1}, Loss: {running_loss}')

print('Training Finished!')

数据集下载

百度网盘:https://pan.baidu.com/s/1CjTNLGvBBDxmKEADN3SNWw?pwd=o37e 
提取码:o37e

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1200005.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ubuntu下tensorrt环境配置

文章目录 一、Ubuntu18.04环境配置1.1 安装工具链和opencv1.2 安装Nvidia相关库1.2.1 安装Nvidia显卡驱动1.2.2 安装 cuda11.31.2.3 安装 cudnn8.21.2.4 下载 tensorrt8.4.2.4 二、编写CMakeLists.txt三、TensorRT系列教程 一、Ubuntu18.04环境配置 教程同样适用与ubuntu22.04…

springcloud电影购票选座网站系统源码

开发技术: jdk1.8,mysql5.7,idea springcloud springboot mybatis vue elementui 功能介绍: 用户端: 登录注册 首页显示搜索电影,轮播图,电影分类,最近上架电影(可…

Linux 内核启动流程

目录 链接脚本vmlinux.ldsLinux 内核启动流程分析Linux 内核入口stext__mmap_switched 函数start_kernel 函数rest_init 函数init 进程 看完Linux 内核的顶层 Makefile 以后再来看 Linux 内核的大致启动流程,Linux 内核的启动流程要比uboot 复杂的多,涉及…

第六章 DNS域名解析服务器

1、DNS简介 DNS(Domain Name System)是互联网上的一项服务,它作为将域名和IP地址相互映射的一个分布式数据库,能够使人更方便的访问互联网。 DNS系统使用的是网络的查询,那么自然需要有监听的port。DNS使用的是53端口…

思科9300交换机使用USB进行升级ISO

一、下载ISO 一、网址 Software Download - Cisco Systems 二、找到型号 四、选择XE 软件 五、进行下载 二、COPY 进 U盘 一、、请注意!如果你的U盘不是Fat32文件格式则交换机读取不了,请先格式化再复制文件。 二、下载后将 bin文件复制到U盘。 1.扩展…

js删除json数据中指定元素

delete 删除数组方法: function removeJSONRows() {var tab {"dataRows": [{"id": 1,"name": "使用部门"},{"id": 2,"name": "车辆走行路线"},{"id": 3,"name": &quo…

【Redis】String字符串类型

上一篇:Redis-key的使用 https://blog.csdn.net/m0_67930426/article/details/134361821?spm1001 .2014.3001.5501 目录 appen (附加) strlen(获取字符串的长度) incr decr getRange(获取字符串) setRange(替…

C语言--求一个 3 X 3 的整型矩阵对角线元素之和

一.题目描述 求一个 3 X 3 的整型矩阵对角线元素之和 二.代码实现 #define _CRT_SECURE_NO_WARNINGS #include<stdio.h> int main() {int arr[3][3] { 0 };for (int i 0;i < 3;i){for (int j 0;j < 3;j){ printf("请输入数字&#xff1a;");scanf(&…

卸载本地开发环境,拥抱容器化开发

以前在公司的时候&#xff0c;使用同事准备的容器化环境&#xff0c;直接在 Docker 内进行开发&#xff0c;爽歪歪呀。也是在那时了解了容器化开发的知识&#xff0c;可惜了&#xff0c;现在用不到那种环境了。所以打算自己在本地也整一个个人的开发环境&#xff0c;不过因为我…

SMART PLC MODBUSTCP速度测试

SMART PLC MODBUSTCP通信详细介绍请参看下面文章链接: S7-200SMART PLC ModbusTCP通信(多服务器多从站轮询)_matlab sumilink 多个modbustcp读写_RXXW_Dor的博客-CSDN博客文章浏览阅读6.4k次,点赞5次,收藏10次。MBUS_CLIENT作为MODBUS TCP客户端通过S7-200 SMART CPU上的…

【python】sys-psth和模块搜索路径

我们在导入一个模块的时候&#xff0c;比如说&#xff1a; import math它必然是有搜索路径的&#xff0c;那到底是在哪个目录下面找呢&#xff1f;Python解释器去哪里找这个文件呢&#xff1f;只有找到这个文件才能读取、装载运行该模块文件。 它一般按照如下路径寻找模块文件…

经典OJ题:重排链表

题目&#xff1a; 给定一个链表&#xff0c;在进行重排前&#xff1a; 进行重排链表后&#xff1a; 如上图所示&#xff0c;所谓的重拍链表&#xff0c;就是将第一个节点连接第倒数第一个节点&#xff0c;第二个节点连接倒数第二个节点&#xff0c;以此类推&#xff0c;最后在连…

贝锐蒲公英X1解决远程访问NAS难题

由于经常在外出差和旅游&#xff0c;需要实现即使在外地也能远程登录回去家里的NAS去处理事情或传输文件&#xff0c;因此解决方案之一是搭建一个安全简易的个人私有云。 实施难度 &#xff08;1&#xff09;家庭网络无公网IP&#xff0c;且公网IP价格昂贵&#xff08;2&…

今起不再“没完没了的接龙斗嘴”

今天本“人民体验官”推广人民日报官方微博&#xff08;转央视网&#xff09;的文化产品《数字减负不能比减脂还难》。 截图&#xff1a;来源“人民体验官”推广平台 在时下的一些网络自媒体平台之上&#xff0c;的确存在“越拉越多的群&#xff0c;没完没了的接龙&#xff0c…

在使用Vuex时,5个方法让你保证数据的更新及时性

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

hosts文件修改完成之后无法保存的解决方法

系列文章目录 centos7配置静态网络常见问题归纳_张小鱼༒的博客-CSDN博客 目录 系列文章目录 前言 一、hosts文件为何不能保存的原因 二、Hosts文件无法保存解决方法 1.需要用到hosts的地方 2.具体的操作步骤 总结 前言 Hosts文件是系统中的重要文件&#xff0c;它能屏…

Spring面试题:(五)Spring注解开发@Component,@Autowired,@Bean,@Configuration

Bean基本注解 spring提供注解的版本 Component注解替代bean标签 bean其它属性的相关注解&#xff1a; scope 替代scopelazy 替代lazy-initPostConstruct 替代init-methodPreDestroy 替代destroy-method 使用Component注解的前提是开启注解扫描 衍生注解Repository,Servi…

博客积分上一万一千了

博客积分上一万一千了 充满自信&#xff0c;继续前进。

GCC工具详解【Linux知识贩卖机】

很多人在喧嚣声中登场&#xff0c;也有少数人在静默中退出。 --单独中的洞见2 文章目录 简介程序到可执行文件链接动态链接和静态链接动态库和静态库动态库和静态库的打包打包静态库打包动态库选项 -static 总结 简介 GCC&#xff08;GNU Compiler Collection&#xff09; 是一…

移动硬盘和u盘的区别哪个好 移动硬盘和u盘有啥区别

在数字时代的今天&#xff0c;数据存储已经成为我们生活中的重要一环。当我们需要移动、备份或传输大量数据时&#xff0c;常常会不知道是选择移动硬盘还是U盘。其实&#xff0c;对于许多人来说&#xff0c;移动硬盘和U盘之间的区别并不清晰。下面我们就来看移动硬盘和u盘的区别…