pytorch神经网络训练(LeNet-5)

news2024/9/19 10:34:40

LeNet-5

  1. 导包
import os

import torch

import torch.nn as nn

import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

from PIL import Image

from torchvision import transforms
  1. 定义自定义图像数据集
class CustomImageDataset(Dataset):

    def __init__(self, main_dir, transform=None):

        self.main_dir = main_dir

        self.transform = transform

        self.files = []

        self.labels = []

        self.label_to_index = {}

定义一个自定义的图像数据集类,继承自Dataset,初始化方法,接收主目录和转换方法,主目录,包含多个子目录,每个子目录包含同一类别的图像,图像转换方法,用于对图像进行预处理,存储所有图像文件的路径,存储所有图像的标签,创建一个字典,用于将标签映射到索引

  for index, label in enumerate(os.listdir(main_dir)):

            self.label_to_index[label] = index

            label_dir = os.path.join(main_dir, label)

            if os.path.isdir(label_dir):

                for file in os.listdir(label_dir):

                    self.files.append(os.path.join(label_dir, file))

                    self.labels.append(label)

遍历主目录中的所有子目录,将标签映射到索引,构建标签子目录的路径,如果是目录,遍历目录中的所有文件,将文件路径添加到列表,将标签添加到列表

def __len__(self):

        return len(self.files)

定义数据集的长度,返回文件列表的长度

def __getitem__(self, idx):

        image = Image.open(self.files[idx])

        label = self.labels[idx]

        if self.transform:

            image = self.transform(image)

        return image, self.label_to_index[label]

定义获取数据集中单个样本的方法,打开图像文件,获取图像的标签,如果有转换方法,对图像进行转换,返回图像和对应的标签索引

  1. 定义数据转换
transform = transforms.Compose([

    transforms.Resize((32, 32)),  # LeNet-5的输入图像大小

    transforms.ToTensor(),

    transforms.Normalize(mean=[0.5], std=[0.5]),  # LeNet-5的标准化

])
  1. 创建数据集
dataset = CustomImageDataset(main_dir="D:\图像处理、深度学习\cat.dog", transform=transform)
  1. 创建数据加载器
train_data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

test_data_loader = DataLoader(dataset, batch_size=32, shuffle=False)
  1. 定义LeNet-5模型
class LeNet5(nn.Module):

    def __init__(self, num_classes):

        super(LeNet5, self).__init__()

        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)

        self.fc1 = nn.Linear(16 * 5 * 5, 120)

        self.fc2 = nn.Linear(120, 84)

        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):

        x = self.pool(F.relu(self.conv1(x)))

        x = self.pool(F.relu(self.conv2(x)))

        x = x.view(-1, 16 * 5 * 5)

        x = F.relu(self.fc1(x))

        x = F.relu(self.fc2(x))

        x = self.fc3(x)

        return x
  1. 实例化模型
num_classes = len(dataset.label_to_index)

lenet_model = LeNet5(num_classes)
  1. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(lenet_model.parameters(), lr=0.001)
  1. 如果有多个GPU,可以使用nn.DataParallel来并行化模型
if torch.cuda.device_count() > 1:

    lenet_model = nn.DataParallel(lenet_model)
  1. 将模型发送到GPU(如果可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    lenet_model.to(device)
  1. 训练模型
num_epochs = 10

    for epoch in range(num_epochs):

        lenet_model.train()

        running_loss = 0.0

        for images, labels in train_data_loader:

            images, labels = images.to(device), labels.to(device)

前向传播

反向传播和优化

            optimizer.zero_grad()

            loss.backward()

            optimizer.step()



            running_loss += loss.item()



        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_data_loader):.4f}')

在每个周期结束时评估模型

 

       lenet_model.eval()

        correct = 0

        total = 0

        with torch.no_grad():

            for images, labels in test_data_loader:

                images, labels = images.to(device), labels.to(device)

                outputs = lenet_model(images)

                _, predicted = torch.max(outputs.data, 1)

                total += labels.size(0)

                correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total

        print(f'Test Accuracy: {accuracy:.2f}%')
  1. 保存训练好的模型
torch.save(lenet_model.state_dict(), "D:\图像处理、深度学习\训练保存\lenet_model.pth")
  1. 导包
import torch

import torch.nn as nn

import torch.optim as optim

from torch.utils.data import DataLoader

from torchvision.datasets import ImageFolder

from torchvision import transforms
  1. 定义LeNet-5模型
class LeNet5(nn.Module):

    def __init__(self, num_classes):

        super(LeNet5, self).__init__()

        self.convnet = nn.Sequential(

            nn.Conv2d(3, 6, kernel_size=5),

            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(6, 16, kernel_size=5),

            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2)

        )
     

        self.fc = nn.Sequential(

            nn.Linear(16 * 5 * 5, 120),

            nn.ReLU(),

            nn.Linear(120, 84),

            nn.ReLU(),

            nn.Linear(84, num_classes)

        )

    def forward(self, x):

        x = self.convnet(x)

        x = x.view(x.size(0), -1)  # 展平多维卷积层输出

        x = self.fc(x)

        return x
  1. 定义数据转换
transform = transforms.Compose([

    transforms.Resize((32, 32)),  # LeNet-5的输入图像大小

    transforms.ToTensor(),

])
  1. 假设您的数据集是一个ImageFolder格式,并且路径为 "path_to_your_dataset"
dataset = ImageFolder(root="D:\图像处理、深度学习\cat.dog", transform=transform)
  1. 创建数据加载器
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
  1. 获取数据集中的类别数
num_classes = len(dataset.classes)
  1. 创建LeNet-5模型实例
lenet_model = LeNet5(num_classes=num_classes)
  1. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(lenet_model.parameters(), lr=0.001)
  1. 将模型发送到GPU(如果可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

lenet_model.to(device)

  1. 训练模型
num_epochs = 10

for epoch in range(num_epochs):

    lenet_model.train()

    running_loss = 0.0

    correct = 0

    total = 0

    for images, labels in data_loader:

        images, labels = images.to(device), labels.to(device)

        # 前向传播

        outputs = lenet_model(images)

        loss = criterion(outputs, labels)

        # 反向传播和优化

        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        running_loss += loss.item()

        _, predicted = torch.max(outputs.data, 1)

        total += labels.size(0)

        correct += (predicted == labels).sum().item()

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}, Accuracy: {100 * correct / total:.2f}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1860151.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Git的安装以及使用

一.简单介绍 1.1版本控制 版本控制是指对软件开发过程中各种程序代码,配置文件及说明文档等文件变更管理,是软件配置管理的核心思想之一。 版本控制最重要的内容是追踪文件的变更,它将什么时候,什么人更改了文件的什么内容等信息忠实的记录…

社交小心机:特别的动态给特别的她/他

在社交媒体盛行的今天,微信朋友圈成了我们分享生活点滴的重要平台。 但是,你是否有过这样的烦恼——有些动态只想和特定的人分享,而不是所有人?别担心,今天我就来教大家如何巧妙地设置朋友圈权限,让你的分…

【2024.6.25】今日 IT之家精选新闻

人不走空 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌赋:斯是陋室,惟吾德馨 目录 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌…

C语言 循环语句while 2

应用实例 int main() {char password[20] { 0 };printf("输入密码:>");scanf("%s", password);printf("请确认输入密码(Y/N):>");//清理缓存区int tmp 0;while ((tmp getchar()) ! \n){;}int ch getchar();if (ch Y){pri…

RAG实践 - 搭建本地知识库 - Ollama + AnythingLLM

0,什么是RAG? RAG,即检索增强生成(Retrieval-Augmented Generation),是一种先进的自然语言处理技术架构,旨在克服传统大型语言模型(LLM)在处理开放域问题时的信息容量限…

java 多线程入门

对于 Java 初学者来说,多线程的很多概念听起来就很难理解。比方说: 进程,是对运行时程序的封装,是系统进行资源调度和分配的基本单位,实现了操作系统的并发。线程,是进程的子任务,是 CPU 调度和…

提示缺少Microsoft Visual C++ 2019 Redistributable Package (x64)(下载)

下载地址:这个是官网下载地址:Microsoft Visual C 2019 Redistributable Package (x64) 步骤: 第一步:点开链接,找到下图所示的东西 第二步:点击保存下载 第三步:双击运行安装 第四步&#xf…

让工厂像手机一样更“聪明”

手机,作为我们日常生活中不可或缺的一部分,以其智能、便捷、高效的特点,彻底改变了我们的沟通、娱乐和工作方式。那么,想象一下,如果工厂能像手机一样便捷,那么生产过程中的每一个环节都将变得触手可及。通…

揭秘Redis中的高级数据结构:跳跃表Skiplist

Redis数据结构-跳跃表Skiplist 1. 简介1.1. Redis高性能键值存储数据库1.2. Redis的特点和优势1.3. 跳跃表Skiplist 2. 跳跃表的概念和背景2.1 跳跃表的概念2.2 跳跃表的发展历程和提出背景 3. 跳跃表的基本原理3.1 结构概述3.1.1 跳跃表的结构概述3.1.2 跳跃表的节点结构 3.2 …

C#语言+net技术架构+ VS2019开发的微信公众号预约挂号系统源码 微信就医全流程体验 什么是微信预约挂号系统?

C#语言net技术架构 VS2019开发的微信公众号预约挂号系统源码 微信就医全流程体验 什么是微信预约挂号系统? 微信预约挂号系统是一种基于互联网的预约挂号平台,通过与医院信息系统的对接,实现了患者通过手机微信轻松预约挂号的功能。预约挂号系…

【AI大模型】Transformers大模型库(十一):Trainer训练类

目录 一、引言 二、Trainer训练类 2.1 概述 2.2 使用示例 三、总结 一、引言 这里的Transformers指的是huggingface开发的大模型库,为huggingface上数以万计的预训练大模型提供预测、训练等服务。 🤗 Transformers 提供了数以千计的预训练模型&am…

基于FreeRTOS+STM32CubeMX+LCD1602+MCP4152(SPI接口)的数字电位器Proteus仿真

一、仿真原理图: 二、仿真效果: 三、软件部分: 1)、时钟配置初始化: void SystemClock_Config(void) { RCC_OscInitTypeDef RCC_OscInitStruct = {0}; RCC_ClkInitTypeDef RCC_ClkInitStruct = {0}; /** Initializes the CPU, AHB and APB busses clocks */ RCC…

同城购物优惠联盟返现系统小程序源码

:省钱购物新体验 🎉一、同城优惠,一网打尽 在繁华的都市生活中,你是否总是为寻找各种优惠而费尽心思?现在,有了“同城优惠联盟返现小程序”,你可以轻松掌握同城各类优惠信息。无论是餐饮、购物…

解题思路:LeetCode 第 209 题 “Minimum Size Subarray Sum“

解题思路:LeetCode 第 209 题 “Minimum Size Subarray Sum” 在这篇博文中,我们将探讨如何使用 Swift 解决 LeetCode 第 209 题 “Minimum Size Subarray Sum”。我们会讨论两种方法:暴力法和滑动窗口法,并对这两种方法的时间复…

Arduino - 串行绘图仪

Arduino - Serial Plotter Arduino - 串行绘图仪 In this tutorial, we will learn how to use the Serial Plotter on Arduino IDE, how to plot the multiple graphs. 在本教程中,我们将学习如何在Arduino IDE上使用串行绘图仪,如何绘制多个图形。 A…

【软件工程】【22.04】p2

关键字: 软件开发分本质及涉及问题、需求规约与项目需求不同、用况图概念包含模型元素及其关系、创建系统的用况模型RUP进行活动、软件生存周期&软件生存周期模型&软件项目过程管理关系、CMMI基本思想 模块结构图:作用域、控制域;语…

vue2 antd 开关和首页门户样式,表格合计

1.首页门户样式 如图 1.关于圆圈颜色随机设置 <a-col :span"6" v-for"(item, index) in menuList" :key"index"><divclass"circle":style"{ borderColor: randomBorderColor() }"click"toMeRouter(item)&qu…

版本控制工具-git分支管理

目录 前言一、git分支管理基本命令1.1 基本命令2.1 实例 二、git分支合并冲突解决三、git merge命令与git rebase命令对比 前言 本篇文章介绍git分支管理的基本命令&#xff0c;并说明如何解决git分支合并冲突&#xff0c;最后说明git merge命令与git rebase命令的区别。 一、…

Python重拾

1.Python标识符规则 字母&#xff0c;下划线&#xff0c;数字&#xff1b;数字不开头&#xff1b;大小写区分&#xff1b;不能用保留字&#xff08;关键字&#xff09; 2.保留字有哪些 import keyword print(keyword.kwlist)[False, None, True, and,as, assert, async, await…

【AI兼职副业必看,行业分析+注意事项+具体应用,想要做点副业的小白必看!】

前言 随着AI技术的日新月异&#xff0c;它已悄然渗透到我们生活的每一个角落&#xff0c;成为了我们日常生活和工作中的得力助手。在当前经济下行的环境下&#xff0c;AI技术更是成为了提升工作效率、拓展业务领域的关键。对于我们普通人而言&#xff0c;有效利用AI工具&#…