PyTorch学习(2)-猫狗大战数据集分类识别-PyTorch代码实训

news2024/9/20 0:19:13

猫狗大战数据集分类识别-PyTorch代码实训

二分类任务

数据集文件目录结构图

pythonProject/
│
├── cat_recognition.py
│  
└── kagglecatsanddogs_5340/
    └── PetImages/
    	├── Cat/...
    	└── Dog/...

Cat和Dog文件夹中的图片的后缀均为.jpg

代码1(实现二分类问题)

import torch
import torch.nn
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import torchvision.datasets
from torchvision import transforms, models
from PIL import Image
from torch import nn, optim
import matplotlib.pyplot as plt

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, ))
])

root_dir = './kagglecatsanddogs_5340/PetImages'

dataset = torchvision.datasets.ImageFolder(root=root_dir, transform=transform)

n_train = int(0.8 * len(dataset))
n_test = len(dataset) - n_train

train_dataset, test_dataset = random_split(dataset, [n_train, n_test])

# dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)

print("Finish Reading the Dataset")

# MyNet


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(in_features=32*56*56, out_features=512)
        self.relu_fc1 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=512, out_features=2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.relu_fc1(x)
        x = self.fc2(x)
        return x


model = MyNet()


# model = models.resnet50(pretrained=True)
# for param in model.parameters():
#     param.requires_grad = False
# model.fc = nn.Linear(2048, 2)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# train


def train(epoch): # epoch: 方便打印
    running_loss = 0.0
    running_total = 0
    running_correct = 0
    for batch_idx, data in enumerate(train_loader, 0): # 给train_loader元素编号,从0开始
        inputs, targets = data # inputs和targets是“数组”的形式
        optimizer.zero_grad() # 消除优化器中原有的梯度

        outputs = model(inputs)
        loss = criterion(outputs, targets) # 对比输出结果和“答案”

        loss.backward()
        optimizer.step() # 优化网络参数

        running_loss += loss.item() # .item(): 取出tensor中特定位置的具体元素值并返回该值(Tensor to int or float)
        _, predicted = torch.max(outputs.data, dim=1) # 找到每个样本预测概率最高的类别的标签值(即预测结果)
        # dim=0计算tensor中每列的最大值的索引,dim=1表示每行的最大值的索引
        running_total += inputs.shape[0] # .shape[0]: 读取矩阵第一维度的长度
        running_correct += (predicted == targets).sum().item()

        if batch_idx % 300 == 299:
            print('[%d, %5d]: loss: %.3f , acc: %.2f %%'
                  % (epoch + 1, batch_idx + 1, running_loss / 300, 100 * running_correct / running_total))
            running_loss = 0.0
            running_total = 0
            running_correct = 0

# test *
def test(epoch):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = correct / total
    print('[%d / %d]: Accuracy on test set: %.1f %% ' % (epoch + 1, 3, 100 * acc))
    return acc



# main

for epoch in range(3):
    train(epoch)
    acc_test = test(epoch)
    acc_list_test.append(acc_test)

print("----------Finish the model training process.----------")

结果1

(之前调整过epoch,且由于电脑性能限制,完成过多epoch的训练耗时较长,故仅进行了几轮训练)

在这里插入图片描述

代码2(用于理解dataloader内的结构)

# 理解dataloader

from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# 数据预处理
transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor()])

# 加载数据集
dataset = ImageFolder('./kagglecatsanddogs_5340/PetImages', transform=transform)

# 创建 DataLoader
loader = DataLoader(dataset, batch_size=4)

# 通过迭代的方式访问 DataLoader 中的元素
for i, (images, labels) in enumerate(loader):
    # if i == 0:  # 仅显示第一个批次的数据
        print(f"第{i + 1}个批次的图像张量:")
        print(images.shape)  # 显示图像张量的形状
        print("对应的标签:", labels)
    # break


# 获取数据集中的特定样本(假设索引为10)
sample_idx = 10
image, label = dataset[sample_idx]
print(f"索引 {sample_idx} 的图像张量:")
print(image.shape)
print("对应的标签:", label)

结果2

在这里插入图片描述

ImageFolder方法会自动读取文件目录中的图片,并将其打上标签:[‘0’, ‘1’]。

问题

loss的含义;loss值的大小怎么看?

损失函数用来优化模型。在此二分类任务中,采用了交叉熵损失函数。损失值越小,表示模型对类别预测的准确性越高。

在训练过程中,目标是通过优化算法使损失值逐渐减小。

较大的损失值可能表示模型尚未充分学习数据的特征,或者模型架构、超参数选择不合适。在训练过程中,需要检查和调整模型,以使损失值逐渐减小。

batch_size对于训练速度和训练模型分类正确率的影响?

增大batch_size可以减少迭代次数。对相同的数据量,处理速度比小的batch_size更快。

但过大的batch_size可能会让内存容量撑不住,同时对参数的修正会变缓。

总结:

  1. batch_size设的大一些,收敛得快,也就是需要训练的次数少,准确率上升的也很稳定,但是实际使用起来精度不高;
  2. batch_size设的小一些,收敛得慢,可能准确率来回震荡,因此需要把基础学习速率降低一些,但是实际使用起来精度较高。

如何设计神经网络的结构?

设计神经网络的思路:

  • 先设计一个过拟合的模型
  • 再消除过拟合带来的问题

对于datasets.[数据集名]的参数transform,.ToTensor()和.Resize()谁先谁后?有影响吗?

推荐的操作顺序是先进行Resize操作,然后再进行ToTensor操作。

如果先ToTensor后Resize,那么最终得到的张量尺寸是转换前的原始大小。因为调整大小操作通常需要基于图像的像素信息来进行,而且模型通常要输入固定大小的张量。

怎样认识CNN提取了猫狗图像中的哪些特征用于分类任务?

一种方法是打印feature map来可视化网络提取的特征。

当batch_size = 16时,对搭建的卷积网络中的几个层进行可视化:

  • conv1

在这里插入图片描述

  • maxpool1

在这里插入图片描述

  • conv2

在这里插入图片描述

归一化的作用?如何归一化?

归一化的作用:

  • 加速训练过程,加速收敛
  • 提高模型的稳定性
  • 改善模型的泛化能力

所以往往需要使用代码求解数据集的均值和标准差,用于归一化时设置参数。

Learning Rate的含义?

本质上是“步长”。

学习率 大学习率 小
学习速度
理想状态下的使用时间点刚开始训练时一定轮数过后(接近结束时)
不足有可能会出现震荡的情况容易过拟合;收敛速度慢

代码中为什么要声明全局变量(global)?

在Python中,如果想要在函数内部修改一个定义在函数外部的变量,需要使用 global 关键字来声明这个变量是全局的。否则,Python会认为是在函数内部创建了一个与全局变量同名的局部变量。

*什么是钩子函数?

钩子函数是一种回调机制,允许程序在执行的特定点插入用户定义的代码。

在PyTorch中,hook方法有四种:
torch.Tensor.register_hook()
torch.nn.Module.register_forward_hook()
torch.nn.Module.register_backward_hook()
torch.nn.Module.register_forward_pre_hook().

使用.register_forward_hook可以导出卷积特征图。

框选识别

这里使用opencv中提供的.CascadeClassifier()方法,引入haarcascade_frontalcatface.xml文件和haarcascade_frontalcatface_extended.xml文件,用于自动框选出图像中被识别出来的猫脸。

其中的.xml文件可以通过此网址下载:https://github.com/opencv/opencv/tree/master/data/haarcascades

import numpy as np
import cv2 

cat_cascade = cv2.CascadeClassifier('haarcascade_frontalcatface.xml')
cat_ext_cascade = cv2.CascadeClassifier('haarcascade_frontalcatface_extended.xml')

SF=1.05  # try different values of scale factor like 1.05, 1.3, etc
N=3 # try different values of minimum neighbours like 3,4,5,6

def processImage(image_dir,image_filename):
    # read the image
    img = cv2.imread(image_dir+'/'+image_filename)
    # convery to gray scale
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    
    # this function returns tuple rectangle starting coordinates x,y, width, height
    cats = cat_cascade.detectMultiScale(gray, scaleFactor=SF, minNeighbors=N)
    #print(cats) # one sample value is [[268 147 234 234]]
    cats_ext = cat_ext_cascade.detectMultiScale(gray, scaleFactor=SF, minNeighbors=N)
    #print(cats_ext)
    
    # draw a blue rectangle on the image
    for (x,y,w,h) in cats:
        img = cv2.rectangle(img,(x,y),(x+w,y+h),(255,0,0),2)       
    # draw a green rectangle on the image 
    for (x,y,w,h) in cats_ext:
        img = cv2.rectangle(img,(x,y),(x+w,y+h),(0,255,0),2)
    
    # save the image to a file
    cv2.imwrite('out'+image_filename,img)
    
for idx in range(1,7):
    processImage('cats/',str(idx)+'.jpg')
    
processImage('.','dog.jpg')

运行结果如下:(准确率没有预期中高)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1963045.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MYSQL多表数据去重、合并、取并集等

SQL join 用于把,来自两个或多个表的行结合起来。 下图展示了 LEFT JOIN、RIGHT JOIN、INNER JOIN、OUTER JOIN 相关的 7 种用法。。 1、SELECT <select_list> FROM TableA A LEFT JOIN TableB B ON A.Key = B.Key 2、SELECT <select_list> FROM TableA A LEFT …

在VMware里面安装Linux安装教程

!(https://gitee.com/code-shuyi/local-images/raw/master/image/202407311020201.png) 5385453)] [外链图片转存中…(img-BiUfrRTp-1722395385453)]

替代TLD5190同步四开关升降压LED专用电源调节器,支持PWM调光功能,具有强制电流调节模式

特征:PC8655替代TLD5190  AEC-Q100合格 −设备环境温度&#xff1a; -40C≤TA≤125C −器件结温&#xff1a; -40C≤TJ≤150C  工作输入电压4.5V至55V&#xff0c;启动电压降至4.5V  在各种条件下效率都很高&#xff0c;高达96%  3%LED电流精度  高侧PMOS调…

C语言中的浮点数存储:深入探讨

案例引入 请看下面一段代码并思考结果&#xff1a; #define _CRT_SECURE_NO_WARNINGS #include <stdio.h> int main() {int n 9;float* pFloat (float*)&n;printf("n的值为&#xff1a;%d\n", n);printf("*pFloat的值为&#xff1a;%f\n", *…

如何实现参加RAG比赛但进不了复赛的总结

今天写这篇文章主要就是总结一下我使用的一些基本方法&#xff0c;虽然肯定比不上前十的大佬们的操作&#xff0c;但对于常规RAG实现来说也是够用的。这次的考题是给了一堆HTML的知识文档&#xff0c;基于这些文档来进行知识问答。这些文档是企业内部的运维相关文档&#xff0c…

点击jmeter.bat一闪而过无法打开的解决方案

重新查看了配置&#xff0c;在系统变量&#xff08;win10以上直接搜索“环境变量”&#xff09;配置了所有的配置&#xff0c;点击jmater.bat一闪而过无法打开&#xff0c;并且在命令行输入jmeter如下的提示&#xff1a; 检查JMETER_HOME在系统变量的配置是否有分号&#xff0c…

windows 上使用纯 nvcc 命令编译 myboyhood/yolo-tensorrt 工程的过程记录

1. 码云仓库链接&#xff1a;https://gitee.com/myboyhood/yolo-tensorrt 2. 参考博客&#xff1a; 1. 用C/C写一个简单的音乐播放器&#xff08;基于windows控制台编程&#xff09;&#xff1a;https://blog.csdn.net/lwx1051046458/article/details/128889992 3. 过程记录&…

Linux中新添加的磁盘信息不显示-主动扫盘(刷新磁盘状态)

在Linux系统中&#xff0c;当你新添加了一个磁盘&#xff08;无论是通过物理添加还是虚拟化环境&#xff09;&#xff0c;你可能需要让系统识别这个新磁盘&#xff0c;并且可能需要更新或“刷新”磁盘的状态。这通常涉及到几个步骤&#xff0c;但没有一个直接的“刷新磁盘状态”…

springboot集成nacos开启权限验证报错:user not found!

按照官网的说明对nacos的application.properties配置做了开启权限配置。 我的配置项&#xff1a; spring:cloud:nacos:discovery: #服务发现配置group: devnamespace: integrated-manage-dev password: integrated_manageusername: integrated_manageserver-addr: lo…

手把手教你实现基于丹摩智算的YoloV8自定义数据集的训练、测试。

摘要 DAMODEL&#xff08;丹摩智算&#xff09;是专为AI打造的智算云&#xff0c;致力于提供丰富的算力资源与基础设施助力AI应用的开发、训练、部署。 官网链接&#xff1a;https://damodel.com/register?source6B008AA9 平台的优势 &#x1f4a1; 超友好&#xff01; …

Java 延迟消息

场景 6S后执行任务 7天后发送订单 从现有时间算延后多少时间开始执行&#xff0c;当然也可以转换为在以后某个时间执行。 Timer类 Java中的Timer类是一个定时器&#xff0c;它可以用来实现延时消息的功能。 import java.util.Timer; import java.util.TimerTask;public c…

uniapp微信小程序本地和真机调试文件图片上传成功但体验版不成功

文章目录 导文是因为要添加服务器域名&#xff01; 导文 uniapp微信小程序本地和真机调试文件图片上传成功但体验版不成功 uniapp微信小程序体验版上传图片不成功 微信小程序本地和真机调试文件图片上传成功但体验版不成功 是因为要添加服务器域名&#xff01; 先看一下 你小程…

android13 第三方桌面不能使用后台历史任务问题 任务键功能失效问题

总纲 android13 rom 开发总纲说明 目录 1.前言 2.复现现象 3.问题分析 4.解决方法 5.编译运行 6.彩蛋 1.前言 随着Android 13操作系统的发布,用户现在可以更加自由地选择和使用第三方Launcher来定制自己的设备。本文将介绍在Android 13上安装和使用第三方Launcher导致…

工信部哪些证书可以考,含金量高吗

随着科技的快速发展和行业的不断变化&#xff0c;市场对人才的需求也在不断更新。技能提升可以帮助个人适应这些变化&#xff0c;满足新的岗位要求。同时学习新技能可以拓宽思维&#xff0c;激发创新意识&#xff0c;帮助我们在工作中找到新的解决方案。 泰迪智能科技专注…

楼宇智能化仿真实训室解决方案

在信息技术的浪潮中&#xff0c;智慧城市作为未来城市发展的新形态&#xff0c;正以前所未有的速度在全球范围内兴起。其中&#xff0c;楼宇智能化作为智慧城市的关键构成&#xff0c;扮演着举足轻重的角色。它不仅提升了建筑的能源效率、安全性与舒适度&#xff0c;还促进了城…

SQL Server 端口配置

目录 默认端口 更改端口 示例&#xff1a;更改 TCP 端口 示例&#xff1a;验证端口设置 远程连接测试 示例&#xff1a;使用 telnet 测试连接 配置防火墙 示例&#xff1a;Windows 防火墙设置 远程连接测试 示例&#xff1a;使用 telnet 测试连接 默认端口 TCP/IP: …

【Github】Github 上commit后 contribution 绿格子不显示 | Github绿格子 | Github贡献度不显示

一、Github 消失的绿点 1、贡献值为什么没了&#xff1f; 2、选择要显示的贡献 如下配置 二、如何解决消失的绿点&#xff1f; 1、添加邮箱 确保邮箱的设置必须选择一个邮箱邮箱 2、git config 添加邮箱 设置邮箱如下&#xff1a; git config --local user.email 316434776…

Tomcat IntelliJ IDEA整合

一、下载及安装Tomcat 下载官网&#xff1a;Apache Tomcat - Welcome! 1.点击红色框中的任意一个版本 2.点击下载 3.解压后放在任意路径&#xff08;我的是放在D盘&#xff09; 4.在bin目录下找到startup.bat&#xff0c;点击启动Tomcat 5.如果双击启动后&#xff0c;终端出…

NC 缺失的第一个正整数

系列文章目录 文章目录 系列文章目录前言 前言 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站&#xff0c;这篇文章男女通用&#xff0c;看懂了就去分享给你的码吧。 描述 给定一个无重…