使用pytorch构建ResNet50模型训练猫狗数据集

news2024/12/23 18:09:06

数据集

1.导包

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm  # 引入tqdm库以显示进度条

2.数据预处理

ResNet50模型适合的图片大小为224x244

# 定义数据转换
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

3.加载数据集和模型构建

# 加载数据集
data_dir = 'catdog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'test']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,
                             shuffle=True, num_workers=4)
               for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes

# 加载ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

4.训练

# 训练次数
num_epochs = 10

# 初始化训练次数计数器
train_count = 0
for epoch in range(num_epochs):  # num_epochs 是你希望训练的轮数
    for phase in ['train', 'test']:
        if phase == 'train':
            model.train()
        else:
            model.eval()

        running_loss = 0.0
        running_corrects = 0

        # 使用tqdm显示进度条
        with tqdm(total=len(dataloaders[phase]), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False) as progress_bar:
            for inputs, labels in dataloaders[phase]:
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double() / dataset_sizes[phase]
                progress_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)
                progress_bar.update(1)

        print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        
     # 更新训练次数计数器
    train_count += 1
    print(f'Training Count: {train_count}')

训练过程

5.预测

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

# 定义模型的类别数量
num_classes = 2

# 加载模型
model = torchvision.models.resnet50(pretrained=False)
# 修改模型的fc层以匹配训练时的结构
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
# 加载保存的权重
model.load_state_dict(torch.load('mg_ResNet50model.pth'))
model.eval()

# 图像预处理
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 测试图片
img_path = 'mao_1.jpg'  # 替换为你的图片路径
img = Image.open(img_path)
img_t = preprocess(img)

# 扩展维度,因为模型需要4维输入(Batch, Channels, Height, Width)
batch_t = torch.unsqueeze(img_t, 0)

# 预测
with torch.no_grad():
    out = model(batch_t)

# 获取最高分数的类别
_, index = torch.max(out, 1)

# 可视化结果
plt.imshow(img)
plt.title(f'Predicted: {index.item()}')
plt.show()

预测效果

0就是猫咪,1就是小狗

全部代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm  # 引入tqdm库以显示进度条

# 定义数据转换
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 加载数据集
data_dir = 'catdog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'test']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,
                             shuffle=True, num_workers=4)
               for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes

# 加载ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练次数
num_epochs = 10

# 初始化训练次数计数器
train_count = 0
for epoch in range(num_epochs):  # num_epochs 是你希望训练的轮数
    for phase in ['train', 'test']:
        if phase == 'train':
            model.train()
        else:
            model.eval()

        running_loss = 0.0
        running_corrects = 0

        # 使用tqdm显示进度条
        with tqdm(total=len(dataloaders[phase]), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False) as progress_bar:
            for inputs, labels in dataloaders[phase]:
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double() / dataset_sizes[phase]
                progress_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)
                progress_bar.update(1)

        print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        
     # 更新训练次数计数器
    train_count += 1
    print(f'Training Count: {train_count}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1717659.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

iOS ------ 多线程 GCD

一,GCD简介 GCD是Apple开发的一个多线程的较新的解决方案。它主要用于优化应用程序以支持多核处理器以及其他对称处理系统。它是一个在线程池模式的基础上执行的并发任务。 为什么要使用GCD? GCD!可用于多核的并行运算GCD会自动利用更多的…

dockers安装mysql

1.dockerhub上搜索自己需要安装得镜像版本 dockerhub网址:https://hub-stage.docker.com docker pull mysql:5.7 #下载自己需要得版本2.启动容器实例,并且挂载容器数据卷 docker run -d -p 3306:3306 --privilegedtrue \ -v /home/mysql/log:/var/log/…

模拟 CMOS 逆变器的开关功耗

我们不会进一步讨论静态功耗。相反,本文和下一篇文章将介绍 SPICE 仿真,以帮助您更全面地了解逆变器的不同类型的动态功耗。本文重点讨论开关功率——输出电压变化时电容充电和放电所消耗的功率。 LTspice 逆变器实施 图 1 显示了我们将使用的基本 LTsp…

小白跟做江科大32单片机之OLED驱动

原理部分 代码测试 1.江科大老师提供的以下代码文件放入工程中,进行测试 2.正常显示即可

一点连接千家银行,YonSuite让“企业资金”实时在线

用友YonSuite作为全场景SaaS应用服务,是成长型企业实现数智转型的不二选择。多年来,凭借不断的技术革新,通过与千家银行的一站式连接,实现了企业资金的实时在线管理,为成长型企业带来了极大的便利和效益。这一举措不仅…

618适合入手哪些数码好物?实用数码好物清单分享,错过拍烂大腿!

在一年一度的618购物狂欢节里,许多数码爱好者们都在这次盛大的购物盛宴中觅得心仪的数码好物,数码产品不仅改变了我们的生活方式,更让我们享受到了前所未有的便捷和乐趣,那么在这个618,哪些数码好物值得我们入手呢&…

在线IP检测如何做?代理IP需要检查什么?

当我们的数字足迹无处不在,隐私保护显得愈发重要。而代理IP就像是我们的隐身斗篷,让我们在各项网络业务中更加顺畅。 我们常常看到别人购买了代理IP服务后,通在线检测网站检查IP,相当于一个”售前检验““售后质检”的作用。但是…

Golang:使用Base64Captcha生成数字字母验证码实现安全校验

Base64Captcha可以在服务端生成验证码,以base64的格式返回 为了能看到生成的base64验证码图片,我们借助gin go get -u github.com/mojocn/base64Captcha go get -u github.com/gin-gonic/gin文档的示例看起来很复杂,下面,通过简…

豆包浏览器插件会造成code标签内容无法正常显示

启用状态:页面的代码会显示不正常 禁用后,正常显示 害得我重置浏览器设置,一个个测试

leetcode刷题记录28-427. 建立四叉树

问题描述 给你一个 n * n 矩阵 grid ,矩阵由若干 0 和 1 组成。请你用四叉树表示该矩阵 grid 。 你需要返回能表示矩阵 grid 的 四叉树 的根结点。 四叉树数据结构中,每个内部节点只有四个子节点。此外,每个节点都有两个属性: val…

JVM学习-字节码指令集(三)

代码下载 操作数栈管理指令 如同操作一个普通数据结构中的堆栈那样,JVM提供的操作数栈管理指令,可以用于直接操作数栈的指令 将一个或两个元素从栈顶弹出,并且直接废弃:pop,pop2复制栈顶一个或两个数值并将复制值成双份的复制值…

【全开源】餐饮点餐系统源码(ThinkPHP+FastAdmin+UniApp)

开启智能餐饮新时代的钥匙 基于ThinkPHPFastAdminUniApp开发的餐饮点餐系统,主要应用于餐饮,例如早餐、面馆、快餐、零食小吃等快捷扫码点餐需求,标准版本仅支持先付款后就餐模式,高级版本支持先付后就餐和先就餐后付费两种模式。…

昆虫记思维导图,超详细解读

《昆虫记》是法国杰出昆虫学家、文学家法布尔的传世佳作,它不仅是一部研究昆虫的科学巨著,同时也是一部脍炙人口的文学经典。在这部作品中,法布尔以其独特的视角和细腻的笔触,为我们揭示了一个神秘而精彩的昆虫世界。那么&#xf…

重生之 SpringBoot3 入门保姆级学习(11、日志的进阶使用)

重生之 SpringBoot3 入门保姆级学习(11、日志的进阶使用) 3.2.4 文件输出3.2.5 日志文档的归档与切割 3.2.4 文件输出 配置 application.properties # 日志文件名 如果不写路径默认就是在项目根路径建立 demo.log 文件 推荐写法 D:\\demo.log 路径 文…

为什么要使用动态代理IP?

一、什么是动态代理IP? 动态代理IP是指利用代理服务器来转发网络请求,并通过不断更新IP地址来保护访问者的原始IP,从而达到匿名访问、保护隐私和提高访问安全性的目的。动态代理IP在多个领域中都有广泛的应用,能够帮助用户…

面试题:计算机网络中的七四五是什么?

面试题:计算机网络中的七四五是什么? 计算机网络中说的七四五是指:OSI 七层模型、TCP/IP 四层模型、OSI 与 TCP/IP 的综合五层模型 OSI 七层模型 OSI 将计算机网络分为了七层,每一层抽象底层的内容,并遵守一定的规则…

【错题集-编程题】过桥(BFS)

牛客对应题目链接&#xff1a;过桥 (nowcoder.com) 一、分析题目 类似层序遍历的思想。 二、代码 //值得学习的代码 #include <iostream>using namespace std;const int N 2010;int n; int arr[N];int bfs() {int left 1, right 1;int ret 0;while(left < right)…

PMP认证与NPDP认证哪个含金量高?

PMP和NPDP&#xff0c;哪个含金量更高呢&#xff1f; PMP可以全面提升你的职业发展&#xff0c;无论你是技术人员还是项目管理人员&#xff0c;都能帮助你打破思维定式&#xff0c;拓宽视野&#xff0c;并提升管理水平和领导能力。 NPDP不仅帮助个人了解新产品开发流程和原理…

分布式锁的原理和实现(Go)

文章目录 思维导图为什么需要分布式锁&#xff1f;go语言分布式锁的实现Redis自己的实现单元测试 红锁是什么别人的带红锁的实现 etcdzk的实现 面试问题什么是分布式锁&#xff1f;你用过分布式锁吗&#xff1f;你使用的分布式锁性能如何&#xff0c;可以优化吗&#xff1f;怎么…

为什么说OV SSL比DV SSL好

OV SSL证书和DV SSL证书是两种常见的SSL证书类型&#xff0c;它们在验证深度、安全性和可见性等方面存在差异。下面是具体分析&#xff1a; 验证深度 DV SSL&#xff1a;只进行域名所有权的验证。 OV SSL&#xff1a;除了验证域名所有权&#xff0c;还需要验证企业信息。 安…