【使用resnet18训练自己的数据集】

news2024/12/29 14:45:52

1.背景及准备

书接上文【以图搜图代码实现】–犬类以图搜图示例 总结了一下可以优化的点,其中提到使用自己的数据集训练网络,而不是单纯使用预训练的模型,这不就来了!!

使用11类犬类微调resnet18网络模型:
1. 数据准备
数据集】11种犬类,共1089张
链接:百度网盘链接
提取码:qlrt
在这里插入图片描述
2. 数据集划分
按照train和val8:2的比例进行划分,划分代码如下:

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec 
@File    :split_data.py
@IDE     :PyCharm 
@Author  :菜菜2024
@Date    :2024/9/30
'''
import os
import shutil
import random

def split_images_into_train_test(source_directory, train_directory, val_directory, train_ratio=0.8):
    """
    将源文件夹下的图片按照指定比例分为训练集和测试集,并分别复制到train和val文件夹下。
    """
    # 确保train和test目录存在,如果不存在则创建
    os.makedirs(train_directory, exist_ok=True)
    os.makedirs(val_directory, exist_ok=True)

    # 获取源文件夹中所有图片文件的列表
    image_files = [f for f in os.listdir(source_directory) if os.path.isfile(os.path.join(source_directory, f))]

    image_files = [f for f in image_files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))]

    # 打乱图片文件列表的顺序
    random.shuffle(image_files)

    total_images = len(image_files)
    train_images_count = int(train_ratio * total_images)

    # 将图片复制到对应的文件夹下
    for i, image_file in enumerate(image_files):
        source_path = os.path.join(source_directory, image_file)
        if i < train_images_count:
            dest_path = os.path.join(train_directory, image_file)
        else:
            dest_path = os.path.join(val_directory, image_file)
        shutil.copy2(source_path, dest_path)  
        print(f"Copied {image_file} to {os.path.dirname(dest_path)}")


if __name__ == '__main__':
    source_directory = "E:\\xxx\\datas\\imgs"
    train_directory = "E:\\xxx\\datas\\pet_dog\\train"
    val_directory = "E:\\xxx\\datas\\pet_dog\\val"

    file_list = os.listdir(source_directory)
    for file in file_list:
        source=os.path.join(source_directory, file)
        val = os.path.join(val_directory, file)
        train = os.path.join(train_directory, file)
        split_images_into_train_test(source, train, val)

最终效果:
在这里插入图片描述
train和val下的目录结果都是如下图所示,只是数量不一样。
在这里插入图片描述

2.代码实现

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec 
@File    :train.py
@IDE     :PyCharm 
@Author  :菜菜2024
@Date    :2024/9/30  
'''
import torch
import os
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import argparse



def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=5):
    """
    参数:
        model: torch.nn.Module - 要训练的模型实例。
        dataloaders: dict - 包含训练集和验证集的数据加载器,例如{'train': train_loader, 'val': val_loader}。
        criterion: nn.Module - 用于计算损失的函数。
        optimizer: torch.optim.Optimizer - 用于更新模型参数的优化器。
        scheduler: torch.optim.lr_scheduler._LRScheduler - 学习率调度器。
        num_epochs: int - 训练的总轮数。
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # 设置模型为训练模式
            else:
                model.eval()   # 设置模型为评估模式

            running_loss = 0.0
            running_corrects = 0

            i = 0
            for inputs, labels in dataloaders[phase]:
                i+=1
                inputs, labels = inputs.to(device), labels.to(device)

                # 前向传播
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    if i%10==0:
                        print(f"{phase} Loss: {loss:.4f}")

                    _, preds = torch.max(outputs, 1)

                    # 反向传播与优化(仅在训练阶段)
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

    print("Training complete.")

    # 训练完成后,保存模型状态字典(包含权重)
    torch.save(model.state_dict(), './weights/resnet18_dog.pth')



def main():
    # 创建参数解析器
    parser = argparse.ArgumentParser(description='使用自己的数据集训练resnet18')

    # 添加参数
    parser.add_argument('--data_dir', type=str, default="E:\HWR_files\datas\pet_dog",
                        help='Path to the dataset directory')
    parser.add_argument('--batch_size', type=int, default=16, help='Input batch size for training (default: 16)')
    parser.add_argument('--num_workers', type=int, default=2, help='Number of workers for data loading (default: 2)')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
    parser.add_argument('--num_epochs', type=int, default=5, help='Number of epochs to train (default: 25)')

    args = parser.parse_args()

    # 数据预处理和加载
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    image_datasets = {x: datasets.ImageFolder(os.path.join(args.data_dir, x), data_transforms[x]) for x in
                      ['train', 'val']}
    dataloaders_dict = {
        x: DataLoader(image_datasets[x], batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) for x
        in ['train', 'val']}

    # 使用ResNet18模型
    model = models.resnet18(pretrained=True)

    # 加载之前保存的权重
    # model.load_state_dict(torch.load('./weights/resnet18_dog.pth'))


    num_features = model.fc.in_features
    model.fc = torch.nn.Linear(num_features, len(image_datasets['train'].classes))  # 修改最后一层全连接层

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # 定义损失函数和优化器
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # 训练模型
    train_model(model, dataloaders_dict, criterion, optimizer, scheduler, args.num_epochs)


if __name__ == '__main__':
    main()

在这里插入图片描述

3.代码测试

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec 
@File    :test.py
@IDE     :PyCharm 
@Author  :菜菜2024
@Date    :2024/9/30 
'''
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

def test_model(weights_path, val_root, batch_size=4):
    """
    使用验证集测试模型性能。
    参数:
    - weights_path: str, 训练好的模型权重文件路径
    - val_root: str, 验证数据集的根目录
    - batch_size: int, 数据加载时的批次大小
    """
    # 设定数据预处理
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # 加载验证集
    val_dataset = datasets.ImageFolder(root=val_root, transform=transform)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # 初始化模型并加载权重
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 修改最后一层全连接层确保num_classes与实际类别数匹配
    model = models.resnet18()
    num_features = model.fc.in_features
    model.fc = torch.nn.Linear(num_features, len(val_dataset.classes))

    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)
    model.eval()  # 设置模型为评估模式

    # 测试循环
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # 计算准确率并打印结果
    accuracy = 100 * correct / total
    print(f'Accuracy on validation set: {accuracy}%')

if __name__ == '__main__':

    weights_path = './weights/resnet18_dog.pth'
    val_root = "E:\\xxx\\datas\\pet_dog\\val"

    test_model(weights_path, val_root)

结果图:
在这里插入图片描述
4.效果对比

书接上篇的图像检索:【以图搜图代码实现】–犬类以图搜图示例
来看看有没有准一点的

使用预训练的resnet18:
在这里插入图片描述
离谱了,匹配的前三个都是吉娃娃

看看使用微调之后的resnet18:
对应在上一篇种,模型加载和最后一层的输出个数变成类别数,这里是11。

在这里插入图片描述
哇哇哇!效果显著呀!!!

可以可以,下次尝试使用faiss喽

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2180015.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何实现一个优秀的散列表!

文章内容收录到个人网站&#xff0c;方便阅读&#xff1a;http://hardyfish.top/ 文章内容收录到个人网站&#xff0c;方便阅读&#xff1a;http://hardyfish.top/ 文章内容收录到个人网站&#xff0c;方便阅读&#xff1a;http://hardyfish.top/ 前言 假设现在有一篇很长的…

python-pptx 中 placeholder 和 shape 有什么区别?

在 python-pptx 库中&#xff0c;placeholder 和 shape 是两个核心概念。虽然它们看起来相似&#xff0c;但在功能和作用上存在显著的区别。为了更好地理解这两个概念&#xff0c;我们可以通过它们的定义、使用场景以及实际代码示例来剖析其差异。 Python-pptx 的官网链接&…

08_OpenCV文字图片绘制

import cv2 import numpy as npimg cv2.imread(image0.jpg,1) font cv2.FONT_HERSHEY_SIMPLEXcv2.rectangle(img,(500,400),(200,100),(0,255,0),20) # 1 dst 2 文字内容 3 坐标 4 5 字体大小 6 color 7 粗细 8 line type cv2.putText(img,flower,(200,50),font,1,(0,0,250)…

Kubernetes从零到精通(17-扩展-Operator模式)

目录 一、简介 二、核心概念 三、工作原理 四、Operator Framework SDK示例 1.准备工作 2. 定义MySQLCluster CRD 3. 自定义资源实例 4. 编写控制器逻辑 5. 部署Operator 6. 验证 一、简介 Kubernetes中的Operator模式是一种用于简化和自动化管理复杂应用程序(尤其是…

【最新华为OD机试E卷-支持在线评测】简单的自动曝光(100分)多语言题解-(Python/C/JavaScript/Java/Cpp)

🍭 大家好这里是春秋招笔试突围 ,一枚热爱算法的程序员 💻 ACM金牌🏅️团队 | 大厂实习经历 | 多年算法竞赛经历 ✨ 本系列打算持续跟新华为OD-E/D卷的多语言AC题解 🧩 大部分包含 Python / C / Javascript / Java / Cpp 多语言代码 👏 感谢大家的订阅➕ 和 喜欢�…

数字电路与逻辑设计-移位寄存器逻辑功能测试和应用

一、实验目的 l&#xff0e;验证移位寄存器的逻辑功能&#xff1b; 2&#xff0e;掌握集成电路4位双向移位寄存器的使用方法&#xff1b; 3&#xff0e;学会应用移位寄存器实现数据的串行、并行转换和构成环形计数器。 二、实验原理 l&#xff0e;移位寄存器的特点 寄存器…

深入掌握 Protobuf 与 RPC 的高效结合:实现C++工程中的高效通信

目录 一、Protobuf与RPC框架的通信流程概述二、Protobuf与RPC在C中的实际应用2.1 定义 .proto 文件2.2 编译 .proto 文件生成C代码2.3 实现服务器端逻辑2.4 实现客户端逻辑2.5 使用CMake构建工程2.6 编译与运行2.7 关键组件解析2.8 序列化与反序列化的实现 三、关键实现与解析四…

想不到!手机壁纸变现项目,有人 3 个月怒赚 180000+(附教程)

同学们&#xff01;今天无意间发现了一个超级有潜力的变现数据账号。这个账号专注于制作 3D 立体膨胀壁纸&#xff0c;我实在是忍不住要和大家分享。 这个账号的笔记内容非常简洁&#xff0c;主要就是展示壁纸作品。然而&#xff0c;就是这样简单的内容&#xff0c;却在短短 8…

介绍我经常使用的两款轻便易用的 JSON 工具

第一款是 Chrome Extension&#xff0c;名叫 JSON Viewer Pro&#xff0c;可以在 Chrome 应用商店下载&#xff1a; 点击右上角的 JSON Input&#xff0c;然后可以直接把 JSON 字符串内容粘贴进去&#xff0c;也直接直接加载本地 JSON 文件。 可以在树形显示和图形显示两种模式…

淘宝自动下单退货RPA自动化脚本(已运行两个月)

使用AdsPower Browser写的两个自动化脚本&#xff0c;一个下单一个退货&#xff0c;我也不知道他为什么要做这个自动化脚本&#xff0c;运行2个月时间&#xff0c;还蛮稳定&#xff0c;可以多窗口并发运行! 下单指定淘宝商品连接&#xff0c;执行下单RPA脚本实现自动操作。 退…

模糊测试SFuzz亮相第32届中国国际信息通信展览会

9月25日&#xff0c;被誉为“中国ICT市场的创新基地和风向标”的第32届中国国际信息通信展在北京盛大开幕&#xff0c;本次展会将在为期三天的时间内&#xff0c;为信息通信领域创新成果、尖端技术和产品提供国家级交流平台。开源网安携模糊测试产品及相关解决方案精彩亮相&…

Flux最新ControlNet 高清修复模型测评,效果好速度快!

上一篇介绍了Jasper AI 发布了三个模型中的法线贴图&#xff0c;没看过的可以看一下哈&#xff1a; Flux目前最快ControlNet模型现身&#xff01;法线贴图详细测评 (chinaz.com) 这次再介绍一下另一个模型&#xff1a;升频器&#xff0c;可以有比较好的模糊修复效果&#xff…

一条命令Docker安装常用桌面linux系统含一些系统和应用

分类 一. opens use 15.5 desktop https://hub.docker.com/r/kasmweb/opensuse-15-desktop 这是我最近用的一个,稳定性和性能好过ubuntu,兼容性稍微差,部分依赖无法安装,部分软件运行不起来,界面比ubuntu的要好看.风格是win10的.提供一个开源的webVNC, 可选,但是桌面必定要用…

AIGC专栏16——CogVideoX-Fun V1.1版本详解 支持图文生视频与更大的动态性 为文生视频添加控制

AIGC专栏16——CogVideoX-Fun V1.1版本详解 支持图&文生视频与更大的动态性 为文生视频添加控制 学习前言相关地址汇总源码下载地址HF测试链接 CogVideoX-Fun V1.1详解技术储备Diffusion Transformer (DiT)Stable Diffusion 3EasyAnimate-I2V 算法细节V1.1特点参考图片添加…

20.1 分析pull模型在k8s中的应用,对比push模型

本节重点介绍 : push模型和pull模型监控系统对比为什么在k8s中只能用pull模型的k8s中主要组件的暴露地址说明 push模型和pull模型监控系统 对比下两种系统采用的不同采集模型&#xff0c;即push型采集和pull型采集。不同的模型在性能的考虑上是截然不同的。下面表格简单的说…

二、Spring Boot集成Spring Security之实现原理

Spring Boot集成Spring Security之实现原理 一、Spring Security实现原理概要介绍二、使用WebSecurityConfiguration向Spring容器中注册FilterChainProxy类型的对象springSecurityFilterChain1、未配置securityFilterChain过滤器链时使用默认配置用于生成默认securityFilterCha…

Java SE 总结

Java SE&#xff08;Standard Edition&#xff09;是Java编程语言的标准版本&#xff0c;提供了基础的编程环境和API&#xff0c;适用于开发和运行Java应用程序。下面是Java SE的几个重要方面的知识回顾与总结。 1. Java环境基础 具体可参考这里对三者的介绍 传送门 1.1 JVM…

后端-对表格数据进行添加、删除和修改

一、添加 要求&#xff1a; 按下添加按钮出现一个板块输入添加的数据信息&#xff0c;点击板块的添加按钮&#xff0c;添加&#xff1b;点击取消&#xff0c;板块消失。 实现&#xff1a; 1.首先&#xff0c;设计页面输入框格式&#xff0c;表格首行 2.从数据库里调数据 3.添加…

SpringBoot助力墙绘艺术市场创新

3 系统分析 当用户确定开发一款程序时&#xff0c;是需要遵循下面的顺序进行工作&#xff0c;概括为&#xff1a;系统分析–>系统设计–>系统开发–>系统测试&#xff0c;无论这个过程是否有变更或者迭代&#xff0c;都是按照这样的顺序开展工作的。系统分析就是分析系…

数字化智能工厂应用场景

数字化智能工厂的应用场景广泛&#xff0c;涵盖了多个行业和领域。以下是一些主要的应用场景&#xff1a; 一、制造业 汽车制造&#xff1a;数字化智能工厂在汽车制造业中得到了广泛应用。通过自动化生产线、机器人、物联网和人工智能等技术&#xff0c;汽车制造商能够实现高…