CNN、数据预处理、模型保存

news2025/1/12 10:44:11

目录

  • CNN
    • 代码
      • 读取数据
      • 搭建CNN
      • 训练网络模型
  • 数据增强
  • 迁移学习
    • 图像识别策略
      • 数据读取
      • 定义数据预处理操作
      • 冻结resnet18的函数
      • 把模型输出层改成自己的
      • 设置哪些层需要训练
      • 设置优化器和损失函数
      • 训练
      • 开始训练
      • 再训练所有层
      • 关机了,再开机,加载训练好的模型

CNN

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

代码

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

读取数据

#定义超参数
input_size=28
num_class=10
num_epochs=3
batch_size=64
#训练集
train_dataset=datasets.MNIST(root='./data',
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True)

test_dataset=datasets.MNIST(root='./data',
                             train=False,
                             transform=transforms.ToTensor())
#构建batch数据
train_loader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True) #num_worker=4 使用4个子线程加载数据
test_loader=torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)
train_data_iter=iter(train_loader)
#获取训练集的第一个批次数据(第一个快递包)
batch_x,batch_y=next(train_data_iter)
print(batch_x.shape,batch_y.shape)

test_data_iter=iter(test_loader)
batch_x_test,batch_y_test=next(test_data_iter)
print(batch_x_test.shape,batch_y_test.shape)

在这里插入图片描述

搭建CNN

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__() #batch_size,1,28,28
        self.conv1=nn.Sequential(
            nn.Conv2d(in_channels=1,out_channels=16,kernel_size=5,stride=1,padding=2), #batch_size,16,28,28
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2), #batch_size,16,14,14
        )
        self.conv2=nn.Sequential(
            nn.Conv2d(16,32,5,1,2), #batch_size,32,14,14
            nn.ReLU(),
            nn.Conv2d(32,32,5,1,2), #batch_size,32,14,14  #输入输出通道不变,让其在隐藏层里面更进一步提取特征
            nn.ReLU(),
            nn.MaxPool2d(2), #batch_size,32,7,7
        )
        self.conv3=nn.Sequential(
            nn.Conv2d(32,64,5,1,2), #batch_size,64,7,7
            nn.ReLU(),
        )
        #batch_size,64*7*7
        self.out=nn.Linear(64*7*7,10)
        
    def forward(self,x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=nn.Flatten(self.conv3(x))
        output=self.out(x)
        return output
def accuracy(prediction,labels):
    pred=torch.argmax(prediction.data,dim=1) #prediction.data中加data是为了防止数据里面单独数据可能会带来梯度信息
    rights=pred.eq(labels.data,view_as(pred)).sum()
    return rights,len(labels) #(batch_size,)/(batch_size,1)

训练网络模型

net=CNN()

criterion=nn.CrossEntropyLoss() #不需要在CNN中将logistic转换为概率,因为pytorch的交叉熵损失函数会自动进行

optimizer=optim.Adam(net.parameters(),lr=0.001)

for epoch in range(num_epochs):
    train_rights=[]
    for batch_idx,(data,target) in enumerate(train_loader):
        net.train() #进入训练状态,也就是所有网络参数都处于可更新状态
        output=net(data) #output只是logits得分
        
        loss=criterion(output,target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        right=accuracy(output,target)
        train_rights.append(right)
        
        if batch_idx %100 ==0:
            net.eval() #进入评估模式,自动关闭求导机制和模型中的BN层drop out层
            val_rights=[]
            for (data,target) in test_loader:
                output=net(data)
                right=accuracy(output,target)
                val_rights.append(right)
                
            train_r=(sum([tup[0] for tup in train_rights]),sum([tup[1] for tup in train_rights]))
            
            val_r=(sum([tup[0] for tup in val_rights]),sum([tup[1] for tup in val_rights]))
            
            print('当前epoch:{} [{}/{} ({:.0f}%)]\t损失:{:.6f}\t训练集准确率:{:.2f}%\t测试集准确率:{:.2f}%'.format(epoch,
                                                                                                   batch_idx*batch_size,
                                                                                                  len(train_loader.dataset),
                                                                                                  100.*batch_idx/len(train_loader),
                                                                                                  loss.data,
                                                                                                  100.*train_r[0].numpy()/train_r[1],
                                                                                                  100.*val_r[0].numpy()/val_r[1]))

在这里插入图片描述

数据增强

比如数据不够,可以对数据进行旋转,翻转等操作来添加数据
在这里插入图片描述

迁移学习

例如使用预训练模型
在这里插入图片描述

图像识别策略

输出为102

数据读取

data_dir = './汪学长的随堂资料/2/flower_data/'
train_dir = data_dir + '/train' # 训练数据的文件路径
valid_dir = data_dir + '/valid' # 验证数据的文件路径

定义数据预处理操作

data_transforms = {
    'train':
        transforms.Compose([
            transforms.Resize([96, 96]),
            transforms.RandomRotation(45), # 随机旋转, -45~45度之间
            transforms.CenterCrop(64), #对中心进行裁剪,变成64*64
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            
            transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1), # 亮度、对比度、饱和度、色调
            transforms.RandomGrayscale(p=0.025), #彩色图变成灰度图
            transforms.ToTensor(), # 0-255 ——> 0-1
            
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) #这组均值和标准差是最适合图片进行使用的,因为是3通道所以有3组
        
        ]),
    'valid':
        transforms.Compose([
            transforms.Resize([64, 64]),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
        ]),

}
image_datasets

在这里插入图片描述

dataloaders

在这里插入图片描述

dataset_sizes

在这里插入图片描述

model_name = "resnet18" # resnet34, resnet50, 

feature_extract = True #使用训练好的参数

冻结resnet18的函数

def set_parameter_requires_gard(model ,feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
model_ft = models.resnet18() #内置的resnet18
model_ft

改最后一层的,因为默认的是1000输出
在这里插入图片描述

把模型输出层改成自己的

def initialize_model(feature_extract, use_pretrained=True):
    model_ft = models.resnet18(pretrained = use_pretrained)
    
    set_parameter_requires_gard(model_ft, feature_extract)
    
    model_ft.fc = nn.Linear(512, 102)
    
    input_size = 64
    
    return model_ft, input_size
    

设置哪些层需要训练

model_ft, input_size = initialize_model(feature_extract, use_pretrained=True)

device = torch.device("mps") # cuda/cpu

model_ft = model_ft.to(device)

filename = 'best.pt' # .pt .pth

params_to_update = model_ft.parameters()

if feature_extract:
    params_to_update = []
    for name, parm in model_ft.named_parameters():
        if parm.requires_grad == True:
            params_to_update.append(parm)
            print(name)

在这里插入图片描述

model_ft

在这里插入图片描述

设置优化器和损失函数

optimizer_ft = optim.Adam(params_to_update, lr=1e-3)

# 定义学习率调度器
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)

criterion = nn.CrossEntropyLoss()
optimizer_ft.param_groups[0]

训练

def train_model(model, dataloaders, criterion, optimizer, num_epochs=50, filename="best.pt"):
    # 初始化一些变量
    since = time.time() # 记录初始时间
    
    best_acc = 0 # 记录验证集上的最佳精度
    
    model.to(device)
    
    train_acc_history = []
    val_acc_history = []
    train_losses = []
    valid_losses = []
    
    LRS = [optimizer.param_groups[0]['lr']]
    
    best_model_wts = copy.deepcopy(model.state_dict())
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)
        
        # 在每个epoch内,遍历训练和验证两个阶段
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0 # 累积训练过程中的损失
            running_corrects = 0 # 累积训练过程中的正确预测的样本数量
            
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                preds = torch.argmax(outputs, dim=1)
                
                optimizer.zero_grad()
                
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)# 整个epoch的平均损失
            epoch_acc = running_corrects.float() / len(dataloaders[phase].dataset) # 整个epoch的准确率
            
            time_elapsed = time.time() - since
            
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f}; ACC: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            
            
            if phase == "valid" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict()
                }
                torch.save(state, filename)

            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
            
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)
        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRS.append(optimizer.param_groups[0]['lr'])
        print()
        
        scheduler.step() # 调用学习率调度器来进行学习率更新操作
    
    # 已经全部训练完了
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))
    
    model.load_state_dict(best_model_wts)
    
    return model, val_acc_history, train_acc_history, valid_losses, train_losses ,LRS

开始训练


# def train_model(model, dataloaders, criterion, optimizer, num_epochs=50, filename="best.pt"):
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses ,LRS = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5)

在这里插入图片描述
在这里插入图片描述

再训练所有层

# 解冻
for param in model_ft.parameters():
    parm.requires_grad = True

optimizer = optim.Adam(model_ft.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1) # 每7个epoch, 学习率衰减1/10
criterion = nn.CrossEntropyLoss()
# 加载之间训练好的权重参数
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

在这里插入图片描述

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses ,LRS = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=3)

在这里插入图片描述

关机了,再开机,加载训练好的模型

model_ft, input_size = initialize_model(feature_extract, use_pretrained=True)

filename = 'best.pt'

# 加载模型
checkpoint = torch.load(filename)
model_ft.load_state_dict(checkpoint['state_dict'])

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/819228.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机网络(2) --- 网络套接字UDP

计算机网络(1) --- 网络介绍_哈里沃克的博客-CSDN博客https://blog.csdn.net/m0_63488627/article/details/131967378?spm1001.2014.3001.5501 目录 1.端口号 2.TCP与UDP协议 1.TCP协议介绍 1.TCP协议 2.UDP协议 3.理解 2.网络字节序 发送逻辑…

Node.js之express框架学习心得

Node.js:颠覆传统的服务器端开发 Node.js是基于Chrome V8引擎构建的JavaScript运行时,它采用了完全不同的开发模型。Node.js使用事件驱动和非阻塞I/O的方式处理请求,通过单线程和异步机制,实现高效的并发处理。这意味着在Node.js中,一个线程可以处理数千个并发连接,大大提…

Debian 12.1 “书虫 “发布,包含 89 个错误修复和 26 个安全更新

导读Debian 项目今天宣布,作为最新 Debian GNU/Linux 12 “书虫 “操作系统系列的首个 ISO 更新,Debian 12.1 正式发布并全面上市。 Debian 12.1 是在 Debian GNU/Linux 12 “书虫 “发布六周后推出的,目的是为那些希望在新硬件上部署操作系统…

从内核源码看 slab 内存池的创建初始化流程

slab cache 机制确实比较复杂,涉及到的场景又很多,大家读到这里,我想肯定会好奇或者怀疑笔者在上篇文章中所论述的那些原理的正确性,毕竟 talk is cheap ,所以为了让大家看着安心,理解起来放心,…

让SpringBoot不需要Controller、Service、DAO、Mapper,卧槽!这款工具绝了!

Dataway介绍 Dataway 是基于 DataQL 服务聚合能力,为应用提供的一个接口配置工具。使得使用者无需开发任何代码就配置一个满足需求的接口。整个接口配置、测试、冒烟、发布。一站式都通过 Dataway 提供的 UI 界面完成。UI 会以 Jar 包方式提供并集成到应用中并和应…

在windows下安装ruby使用gem

在windows下安装ruby使用gem 1.下载安装ruby环境2.使用gem3.gem换源 1.下载安装ruby环境 ruby下载地址 选择合适的版本进行下载和安装: 在安装的时候,请勾选Add Ruby executables to your PATH这个选项,添加环境变量: 安装Ruby成…

vue-print-nb使用(实现分页打印)

参考链接&#xff1a;vue-print-nb - npm (npmjs.com)https://www.npmjs.com/package/vue-print-nb 一、安装 1、Vue2安装 npm install vue-print-nb --save <!-- 全局配置&#xff1a;main.js --> import Print from vue-print-nb // Global instruction Vue.use(P…

解码“平台工程”,VMware 有备而来

随着全球数字化进程加快&#xff0c;企业使用前沿技术加快商业创新&#xff0c;以提高竞争力。其中如何加快开发效率&#xff0c;为客户创造更多价值成为新的关注焦点。 继DevOps后&#xff0c;“平台工程”&#xff08;Platform Engineering&#xff09; 一词引发热议。平台工…

Redis的安装部署以及基本的使用

目录 一、Linux下直接安装Redis &#xff08;1&#xff09;下载Redis安装包 &#xff08;2&#xff09;安装GCC编译环境 &#xff08;3&#xff09;安装Redis &#xff08;4&#xff09;服务启动 &#xff08;5&#xff09;后台启动 二、使用Docker安装部署Redis &…

火车头采集器免费版【php源码】

大家好&#xff0c;小编来为大家解答以下问题&#xff0c;python turtle circle 画半圆圆心在哪&#xff0c;python中用turtle画一个圆形&#xff0c;现在让我们一起来看看吧&#xff01; 1、t.circle(100,180)的意思&#xff1f; t.circle(100, 180)是Python中turtle库中的一…

18- C++ 强制类型转换-6 (C++)

第八章 强制类型转换 c提供了 隐式类型转换&#xff0c;所谓隐式类型转换&#xff0c;是指不需要用户干预&#xff0c;编译器默认进行的类型转换行为&#xff08;很多时候用户可能都不知道到底进行了哪些转换&#xff09;。例如&#xff1a; int nValue 8; double dValue 10…

评估修改后的YOLOv8模型的参数量和速度

YOLOv8公布了自己每个模型的速度和参数量 那么如果我们自己对YOLOv8做了一些修改&#xff0c;又怎么样自己写代码统计一下修改后的模型的参数量和速度呢&#xff1f; 其实评估这些东西&#xff0c;大多数情况下不需要我们从头自己写一个函数来评估 一般来说&#xff0c;只要…

【云存储】使用OSS快速搭建个人网盘教程(阿里云)

使用OSS快速搭建个人网盘 一、基础概要1. 主要的存储类型1.1 块存储1.2 文件存储1.3 对象存储 2. 对象存储OSS2.1 存储空间2.2 地域2.3 对象2.4 读写权限2.5 访问域名&#xff08;Endpoint&#xff09;2.6 访问密钥2.7 常用功能&#xff08;1&#xff09;创建存储空间&#xff…

HCIP-datacom-831题库

考取HCIP数通证书可以胜任中到大型企业网络工程师岗位&#xff0c;需要掌握中到大型网络的特点和通用技术&#xff0c;具备使用华为数通设备进行中到大型企业网络的规划设计、部署运维、故障定位的能力&#xff0c;并能针对网络应用设计出较高安全性、可用性和可靠性的解决方案…

RedisJava的Java客户端

目录 1.Jedis的使用 前置工作-ssh进行端口转发 JedisAPI的使用 Jedis连接池 2.SpringDataRedis的使用 1.创建项目 2.配置文件 3.注入RedisTemplate对象 4.编写代码 3.SpringRedisTemplate 哈希结构用法 ​总结 1.Jedis的使用 Jedis&#xff1a;以Redis命令作为方法…

途乐证券:沪指强势拉升涨0.63%,券商等板块走强,传媒板块活跃

31日早盘&#xff0c;两市股指全线走高&#xff0c;沪指一度涨超1%收复3300点&#xff0c;上证50指数盘中涨逾2%&#xff1b;随后涨幅有所收窄&#xff1b;两市成交额显着放大&#xff0c;北向资金净买入超90亿元。 到午间收盘&#xff0c;沪指涨0.63%报3296.58点&#xff0c;深…

Python多线程与GIL锁

Python多线程与GIL锁 python多线程 Python的多线程编程可以在单个进程内创建多个线程来同时执行多个任务&#xff0c;从而提高程序的效率和性能。Python的多线程实现依赖于操作系统的线程调度器&#xff0c;并且受到全局解释器锁&#xff08;GIL&#xff09;的限制&#xff0c…

如何在 Ubuntu 22.04 下编译 StoneDB for MySQL 8.0 | StoneDB 使用教程 #1

作者&#xff1a;双飞&#xff08;花名&#xff1a;小鱼&#xff09; 杭州电子科技大学在读硕士 StoneDB 内核研发实习生 ❝ 大家好&#xff0c;我是 StoneDB 的实习生小鱼&#xff0c;目前正在做 StoneDB 8.0 内核升级相关的一些事情。刚开始接触数据库开发没多久&#xff0c…

第55步 深度学习图像识别:CNN特征层和卷积核可视化(TensorFlow)

基于WIN10的64位系统演示 一、写在前面 &#xff08;1&#xff09;CNN可视化 在理解和解释卷积神经网络&#xff08;CNN&#xff09;的行为方面&#xff0c;可视化工具起着重要的作用。以下是一些可以用于可视化的内容&#xff1a; &#xff08;a&#xff09;激活映射&…