pytorch 迁移训练自己的数据集

news2024/11/18 21:43:53

1、pytorch 基础训练

上一节为基础
视频与AI,与进程交互(二) pytorch 极简训练自己的数据集并识别
接着上面一节,我们开始使用迁移学习,训练自己的数据集和保存网络,加载网络并识别。

2、 pytorch加载resnet18

RetNet网络的基础是残差网络,其本架构是ResNet,顾名思义,网络的深度是18层。包括池化,激活,线性,不包括批量归一,池化,那为什么要加载resnet18 ?因为可以使用已经建立的模型,并且微调输出,pytorch里面,在自己收集的图像数据集上训练深度学习模型,可以使用 torchvision 中提供的在 ImageNet 上预训练好的图像分类模型。这样可以节省大量的时间,使用微调模型,最后的resnet18 模型输出层重置,实现了迁移学习。

2.1 数据标准化处理 标准化

定义:数据标准化处理:transforms.Normalize()
数据标准化,一般来说,即均值(mean)为0,标准差(std)为1
简单来说就是将数据按通道进行计算,将每一个通道的数据先计算出其方差与均值,然后再将其每一个通道内的每一个数据减去均值,再除以方差,得到归一化后的结果。标准化处理之后,可以使数据更好的响应激活函数,提高数据的表现力,减少梯度爆炸和梯度消失的出现。

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
#通过设置让内置的cuDNN的auto-tuner自动寻找最适合当前配置的高效算法,来达到优化运行效率的问题
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
cudnn.benchmark = True
plt.ion()   # interactive mode 交互模式
#定义三个全局变量
dataloaders=None
dataset_sizes =None
class_names = None

定义标准化函数,里面的数值为resnet 网络标准化数值,不是随便写的。

#标准化函数
data_transforms = {
      'train': transforms.Compose([
         transforms.RandomResizedCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
       ]),
      'val': transforms.Compose([
         transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),

      ]),
   } 

2.2 训练函数

接下来进行训练函数的编写,参数已经解释

def train_model(model, criterion, optimizer, scheduler, num_epochs=25): 
    """ 训练模型,并返回在最佳模型 
    参数: 
    - model(nn.Module): 要训练的模型 
    - criterion: 损失函数 
    - optimizer(optim.Optimizer): 优化器 
    - scheduler:                  学习率调度器 
    - num_epochs(int):            最大 epoch 数 
    返回: 
    - model(nn.Module):           最佳模型 
    - best_acc(float):            checkpoint最好准确率 
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # 训练集和验证集交替进行前向传播
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # 设置为训练模式,可以更新网络参数
            else:
                model.eval()   # 设置为预估模式,不可更新网络参数

            running_loss = 0.0
            running_corrects = 0

            # 遍历数据集
            for inputs, labels in dataloaders[phase]:
                global device
                inputs = inputs.to(device)
                labels = labels.to(device)

                # 清空梯度,避免累加了上一次的梯度
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    # 正向传播
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # 反向传播且仅在训练阶段进行优化
                    if phase == 'train':
                        loss.backward() # 反向传播
                        optimizer.step()

                # 统计loss、准确率
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # 发现了更优的模型,记录起来
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')

    # 加载训练的最好的模型
    model.load_state_dict(best_model_wts)
    return model

3、数据集的加载和放置

在data目录下放置两个目录,一个是train,一个是val,显然是训练集和验证集
在这里插入图片描述
三个类别,蚂蚁,蜜蜂和工程车,工程车使用上一篇文章里面的内容
在这里插入图片描述

工程车图像如下所示,放好就行,验证集同理在这里插入图片描述
代码清单如下所示:

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
#通过设置让内置的cuDNN的auto-tuner自动寻找最适合当前配置的高效算法,来达到优化运行效率的问题
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
cudnn.benchmark = True
plt.ion()   # interactive mode 交互模式
dataloaders=None
dataset_sizes =None
class_names = None

def imshow(inp, title=None): 
    # 可视化一组 Tensor 的图片
    inp = inp.numpy().transpose((1, 2, 0)) 
    mean = np.array([0.485, 0.456, 0.406]) 
    std = np.array([0.229, 0.224, 0.225]) 
    inp = std * inp + mean 
    inp = np.clip(inp, 0, 1) 
    plt.imshow(inp) 
    if title is not None: 
        plt.title(title) 
    plt.pause(0.001) # 暂停一会儿,为了将图片显示出来

def train_model(model, criterion, optimizer, scheduler, num_epochs=25): 
    """ 训练模型,并返回在验证集上的最佳模型和准确率 
    Args: 
    - model(nn.Module): 要训练的模型 
    - criterion: 损失函数 
    - optimizer(optim.Optimizer): 优化器 
    - scheduler: 学习率调度器 
    - num_epochs(int): 最大 epoch 数 
    Return: 
    - model(nn.Module): 最佳模型 
    - best_acc(float): 最佳准确率 
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # 训练集和验证集交替进行前向传播
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # 设置为训练模式,可以更新网络参数
            else:
                model.eval()   # 设置为预估模式,不可更新网络参数

            running_loss = 0.0
            running_corrects = 0

            # 遍历数据集
            for inputs, labels in dataloaders[phase]:
                global device
                inputs = inputs.to(device)
                labels = labels.to(device)

                # 清空梯度,避免累加了上一次的梯度
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    # 正向传播
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # 反向传播且仅在训练阶段进行优化
                    if phase == 'train':
                        loss.backward() # 反向传播
                        optimizer.step()

                # 统计loss、准确率
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # 发现了更优的模型,记录起来
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')

    # 加载训练的最好的模型
    model.load_state_dict(best_model_wts)
    return model

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            global device
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'predicted: {class_names[preds[j]]}')
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

def main():
   # 在训练集上:扩充、归一化
   # 在验证集上:归一化
   data_transforms = {
      'train': transforms.Compose([
         transforms.RandomResizedCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
       ]),
      'val': transforms.Compose([
         transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),

      ]),
   } 
   data_dir = './data'
   image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                                data_transforms[x])
                        for x in ['train', 'val']}
   global dataloaders
   dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                                    shuffle=True, num_workers=4)
                    for x in ['train', 'val']}
   global dataset_sizes
   dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
   
   global class_names
   class_names = image_datasets['train'].classes

   print(class_names)

   # 获取一批训练数据
   inputs, classes = next(iter(dataloaders['train'])) 
   # 批量制作网格
   out = torchvision.utils.make_grid(inputs) 
   imshow(out, title=[class_names[x] for x in classes]) 


   model = models.resnet18(pretrained=True) # 加载预训练模型
   for param in model.parameters():
        param.requires_grad = False
   num_ftrs = model.fc.in_features # 获取低级特征维度 
   model.fc = nn.Linear(num_ftrs, 3) # 替换新的输出层 
   model = model.to(device) 
   # 交叉熵作为损失函数 
   criterion = nn.CrossEntropyLoss() 
   # 所有参数都参加训练 
   optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) 
   # 每过 7 个 epoch 将学习率变为原来的 0.1 
   scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
   model_ft = train_model(model, criterion, optimizer_ft, scheduler, num_epochs=3) # 开始训练 
   visualize_model(model_ft)
   PATH = './test.pth'
   torch.save(model_ft.state_dict(), PATH)



if __name__== "__main__" :
  main()

4、如何调用

我们在前面保存了pth文件,实际上使用了state_dict,他和直接模型保存是有一定的区别的

import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from torchvision import models
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
PATH = './test.pth'
transform = transforms.Compose(
    [transforms.Resize((256, 256)),transforms.ToTensor(),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])




model = models.resnet18(pretrained=True) # 加载预训练模型
num_ftrs = model.fc.in_features # 获取低级特征维度 
model.fc = nn.Linear(num_ftrs, 3) # 替换新的输出层 
print(device)
model = model.to(device) 
model.load_state_dict(torch.load(PATH))
model.eval()

img = Image.open("./ant.jpg") .convert('RGB')


img = transform(img)
img = img.unsqueeze(0)
img = img.to(device)
with torch.no_grad():
    outputs = model(img)
    _, predicted = torch.max(outputs.data, 1)
    print("the test img lable is ",predicted)

里面unsqueeze函数是必须的,注意这一点,加载图像时,图像通常有3个维度,宽度、高度、和颜色通道数。对于黑白图像,颜色通道数为1,对于彩色图像,有3个颜色通道(红色、绿色和蓝色,RGB)。因此,加载图像并将其存储为张量,维度的顺序是[通道, 高度,宽度],对于二维卷积神经网络来说,三维数据量不能对应。在深度卷积网络中,数据是成批处理的。卷积神经网络不是一次只处理一个图像,而是同时并行处理N个图像。我们称这组图像为批处理。因此,不是维度[C, H, W],而是[N, C, H, W]。,如果一次只处理一个图像,仍然需要将其放入批处理表单中,以便模型接受。如,有一个形状为[3, 255, 255]的图像,则需要将其转换为[1, 3, 255, 255]。这就是unsqueeze(0)函数做的事情。

训练的时候使用的cuda, 在识别的时候可以使用cuda,也可以使用非cuda,也就是"cpu",也是可以的。调用结果如下所示:
python test.py 里面是ant.jpg
在这里插入图片描述

在这里插入图片描述
放一个工程车进去,里面是
在这里插入图片描述

在这里插入图片描述
看见了第三类也就是tensor[2] ,出来了,也就是 tensor[0] 是ant, tensor[1] 是bee,tensor[2] 是工程车

我们使用迁移学习完成了训练和识别,但是这里有一个局限,这是单主要物体识别,没有多分类识别和目标检测,下一篇我们将使用多分类和目标检测来检测物体。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/706755.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vscode platformio Arduino开发STM32,点灯+串口调试

1.工具 USB-TTL(非常便宜,几块钱)STM32F103C8T6(几块钱) 2.引脚连线 USB-TTLSTM32TXPA10RXPA9VCC3.3VGNDGND 注意事项: 跳线帽位置:BOOT0接高电平(1),BOOT1接低电平(0)每次上传程序前需要按一下复位键(之后,跳线帽…

互联网编程之基于 TCP 的单线程文件收发程序(CS架构)

目录 需求 服务端实现 客户端实现 测试 需求 可试着根据java编程课所学到的java socket编程技术,尝试编写一个基于 TCP 的单线程文件收发程序,需满足: 服务端程序预先定义好需要发送的文件并等待客户端的连接。 客户端连接成功后&…

ranger配置hive出錯:Unable to connect repository with given config for hive

ranger配置hive出錯:Unable to connect repository with given config for hive 我一開始我以為是我重啟了ranger-admin導致ranger有點問題,後面排查之後發現是我之前把hiveserver2關閉了,所以只需要重新開啟hiveserver2即可

开源 sysgrok — 用于分析、理解和优化系统的人工智能助手

作者:Sean Heelan 在这篇文章中,我将介绍 sysgrok,这是一个研究原型,我们正在研究大型语言模型 (LLM)(例如 OpenAI 的 GPT 模型)如何应用于性能优化、根本原因分析和系统工程领域的问题。 你可以在 GitHub …

一个真实的社会工程学攻击

社会工程学实例 不同于以往通过心理诱骗暗示或欺诈手段社会工程学举例,本次为大家介绍一种特殊的结合刑侦推理及利用技术手段实现的社会工程学实例,可以把它归类为特殊层面的信息收集手段——通过照片确定发拍照人所在的位置,这种社工手段严格…

ExtJS4 相关

2. 程序架构 2.1 目录结构 推荐下面这种目录结构(非强制,如果你足够懂和自信) - appname- app- namespace- Class1.js- Class2.js- ...- extjs- resources- css- images- ...- app.js- index.html appname 包含所有程序代码,是根目录 app 包含所有类&…

MySQL基础篇(day03,复习自用)

MySQL第三天 排序与分页内容练习 多表查询内容练习 排序与分页 内容 #第五章 排序与分页#1.排序 #如果没有使用排序操作,默认情况下查询返回的数据是按照添加数据的顺序显示的。 SELECT * FROM employees;#1.1基本使用 #使用 ORDER BY 对查询到的数据进行排序操作 …

JSON轻量级数据交换格式

文章目录 前言前后端数据交换格式比较JavaScript自定义对象 前言 JSON全名 :JavaScriptObjectNoation JSON在后端和前端中起着非常重要的作用。 在前端,JSON被用作一种数据格式,用于从服务器获取数据并将其展示在网页上; 前端开…

[SpringBoot]单点登录

关于单点登录 单点登录的基本实现思想: 当客户端提交登录请求时,服务器端在验证登录成功后,将生成此用户对应的JWT数据,并响应到客户端 客户端在后续的访问中,将自行携带JWT数据发起请求,通常&#xff0c…

LeetCode 2501 数组中最长的方波 Java

方法一,哈希表枚举 构造哈希集合,记录出现过的数字枚举遍历 import java.util.HashSet; import java.util.Set;class Solution {public int longestSquareStreak(int[] nums) {//构造哈希表集合,记录出现过的数字,转long型&…

kafka入门,Leader 和 Follower 故障处理细节(十四)

Leader 和 Follower 故障处理细节 LEO:每个腹部最后一个offset,leo其实就是最新的offset1 HW:所有副本中最小的LEO Follower故障处理细节 (1)Follower发生故障会被临时提出ISR (2) 这个期间leader和Follower积蓄接收数据 (3) 待该Follower恢复后&#…

嵌入式基础知识

1.嵌入式系统的定义 以应用为中心,以计算技术为基础软硬件可裁剪,适应系统对功能、可靠性、成本、体积、功耗严格要求的专用计算机技术。 主要由嵌入式微控制器、外围硬件设备、嵌入式操作系统以及用户应用软件等部分分组成。 具有“嵌入式” 、“专用…

CSS 内容盒子与边框盒子

这章比较重要,会不断更新❗ 文章目录 👇内容盒子开发者工具的使用border 边框padding 内边距margin 外边距盒子整体尺寸元素默认样式与CSS重置元素分类块级标记行级标记行内块标记 display样式内容溢出裁剪掉溢出部分滚动条 圆角边框 border-radius ✌边…

【MySQL数据库】MMM高可用架构

目录 一 、MMM简介1.1MMM(Master-Master replication manager for MvSQL,MySQL主主复制管理器)1.2关于 MMM 高可用架构的说明如下 二、搭建mysql MMM架构2.1实验环境2.2搭建多主多从2.3安装配置 MySQL-MMM 一 、MMM简介 1.1MMM(M…

【Linux系统编程】粘滞位详解

文章目录 1. 背景2. 准备问题引出 3. 粘滞位4. 思考粘滞位的存在是否是必要的谁可以删除 上一篇文章我们学习了Linux权限相关的内容,这篇文章,我们再来学习一个知识点——粘滞位。 1. 背景 那为了让大家更容易理解粘滞位的概念,首先我们要来…

【国产虚拟仪器】基于ZYNQ7045+V7 FPGA的多通道数据同步采集设计方案(一)

多通道数据采集设备在当前信息数字化的时代应用广泛,各种被测量的信息 如光线、温度、压力、湿度、位置等,都需要经过多通道信号采集系统的采样和 处理,才能被我们进一步分析利用[37]。在一些对采集速率要求较高的军事、航天、 航空、工业制造…

uniapp radio如何实现取消选中

uniapp 内置radio组件明确表示&#xff0c;不能取消选中&#xff0c;那如果要实现取消选中呢&#xff1f; 只要在外层加上label或者其他标签包裹&#xff0c;或者直接加入click事件然后加入事件控制radio的值改变即可 <label class"radio" click"changeAll&…

记录--Threejs-着色器实现一个水波纹

这里给大家分享我在网上总结出来的一些知识&#xff0c;希望对大家有所帮助 hree.js 是一个基于 WebGL 的 JavaScript 3D 库&#xff0c;用于创建和渲染 3D 图形场景。 一、 图像渲染过程 1、webGL webGL: WebGL 是一种基于 JavaScript API 的图形库&#xff0c;它允许在浏览器…

【STM32智能车】电机控制

【STM32智能车】电机控制 PWMPWM基本用法&#xff1a; 电机驱动基本控制基本状态 欢迎收看由咸鱼菌工作室出品的STM32系列教程。本篇内容主要电机控制 PWM 我们要控制电机&#xff0c;就要先了解一下PWM。 PWM(Pulse Width Modulation)控制——脉冲宽度调制技术&#xff0c;通…

全开源的 agv 小车

最近github 有 一个开源的stm32 大荷载小车&#xff0c;从pcb 到代码到cad 文件全部开源 代码在 git 小车的实物图 小车的cad 模型 小车的 abaqus 力学分析模型 pcb 图 控制板接线图 全开源小车 代码在 git