pytorch神经网络训练(AlexNet)

news2025/4/18 18:45:23
  • 导包
import os

import torch

import torch.nn as nn

import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

from PIL import Image

from torchvision import models, transforms
  • 定义自定义图像数据集
class CustomImageDataset(Dataset): 

定义一个自定义的图像数据集类,继承自Dataset

def __init__(self, main_dir, transform=None): 

初始化方法,接收主目录和转换方法

        self.main_dir = main_dir 

主目录,包含多个子目录,每个子目录包含同一类别的图像

        self.transform = transform

 图像转换方法,用于对图像进行预处理

        self.files = [] 

存储所有图像文件的路径

        self.labels = [] 

存储所有图像的标签

        self.label_to_index = {} 

创建一个字典,用于将标签映射到索引

        for index, label in enumerate(os.listdir(main_dir)):

 遍历主目录中的所有子目录

 

          self.label_to_index[label] = index 

           label_dir = os.path.join(main_dir, label) 

将标签映射到索引,构建标签子目录的路径

           if os.path.isdir(label_dir): 

               for file in os.listdir(label_dir): 

                    self.files.append(os.path.join(label_dir, file))

                    self.labels.append(label) 

如果是目录,遍历目录中的所有文件,将文件路径添加到列表,将标签添加到列表

def __len__(self):

定义数据集的长度

        return len(self.files) 

返回文件列表的长度

def __getitem__(self, idx): 

定义获取数据集中单个样本的方法

        image = Image.open(self.files[idx]) 

        label = self.labels[idx] 

        if self.transform: 

            image = self.transform(image) 

        return image, self.label_to_index[label] 

打开图像文件,获取图像的标签,如果有转换方法,对图像进行转换,返回图像和对应的标签索引

  • 定义数据转换
transform = transforms.Compose([

    transforms.Resize((227, 227)),  # AlexNet的输入图像大小

    transforms.RandomHorizontalFlip(),  # 随机水平翻转

    transforms.RandomRotation(10),  # 随机旋转

    transforms.ToTensor(),

    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # AlexNet的标准化

])

  • 创建数据集
dataset = CustomImageDataset(main_dir="D:\\图像处理、深度学习\\flowers", transform=transform)
  • 创建数据加载器
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
  • 加载预训练的AlexNet模型
alexnet_model = models.alexnet(pretrained=True)
  • 修改最后几层以适应新的分类任务
num_ftrs = alexnet_model.classifier[6].in_features

alexnet_model.classifier[6] = nn.Linear(num_ftrs, len(dataset.label_to_index))
  • 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(alexnet_model.parameters(), lr=0.0001)
  • 如果有多个GPU,可以使用nn.DataParallel来并行化模型
if torch.cuda.device_count() > 1:

    alexnet_model = nn.DataParallel(alexnet_model)
  • 将模型发送到GPU(如果可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

alexnet_model.to(device)                                                               

  • 模型评估
def evaluate_model(model, data_loader, device):

    model.eval()  # 将模型设置为评估模式

    correct = 0

    total = 0

    with torch.no_grad():  # 在这个块中,所有计算都不会计算梯度

        for images, labels in data_loader:

            images, labels = images.to(device), labels.to(device)

            outputs = model(images)

            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)

            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total

    return accuracy
  • 训练模型
num_epochs = 10

for epoch in range(num_epochs):

    alexnet_model.train()

    running_loss = 0.0

    for images, labels in data_loader:

        images, labels = images.to(device), labels.to(device)

前向传播

        outputs = alexnet_model(images)

        loss = criterion(outputs, labels)

反向传播和优化

        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        running_loss += loss.item()

在每个epoch结束后评估模型

    train_accuracy = evaluate_model(alexnet_model, data_loader, device)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1819600.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据采集项目1-用户行为数据同步

环境准备 linux配置、克隆103和104、编写集群分发脚本、ssh无密码登录配置、jdk安装、数据模拟集群日志数据输出脚本、xcall脚本、安装hadoop、zk安装、kafka安装、flume安装、mysql安装、maxwell安装、datax安装、hive安装 用户行为数据同步-总的数据流程图 第一层flume 数据…

22 CRT工具安装流程

22 CRT工具安装流程 SecureCRT 9.5 说明书 SecureCRT 9.5是一款由VanDyke Software开发的终端仿真程序。它为Windows、Mac和Linux操作系统提供了强大的SSH(Secure Shell)客户端功能。SecureCRT 9.5提供了对Telnet、RLogin、Serial和X.509等协议的支持&…

没那么简单!浅析伦敦金与美元的关系

伦敦金价与美元的关系可以被比喻为跷跷板的两端,它们的价格走势往往呈现出此消彼长的关系:当美元表现强势的时候,伦敦金的价格可能承受到压力;相反,当美元疲软时,黄金往往会成为避险资产,令伦敦…

Flask快速入门(路由、CBV、请求和响应、session)

Flask快速入门(路由、CBV、请求和响应、session) 目录 Flask快速入门(路由、CBV、请求和响应、session)安装创建页面Debug模式快速使用Werkzeug介绍watchdog介绍快速体验 路由系统源码分析手动配置路由动态路由-转换器 Flask的CBV…

你还在手写数据库文档?推荐一款数据库文档生成工具screw

😄 19年之后由于某些原因断更了三年,23年重新扬帆起航,推出更多优质博文,希望大家多多支持~ 🌷 古之立大事者,不惟有超世之才,亦必有坚忍不拔之志 🎐 个人CSND主页——Mi…

【调试笔记-20240612-Linux-在 QEMU 中配置 OpenWrt-23.05 支持访问 Windows 宿主机的共享目录】

调试笔记-系列文章目录 调试笔记-20240612-Linux-在 QEMU 中配置 OpenWrt-23.05 支持访问 Windows 宿主机的共享目录 文章目录 调试笔记-系列文章目录调试笔记-20240612-Linux-在 QEMU 中配置 OpenWrt-23.05 支持访问 Windows 宿主机的共享目录 前言一、调试环境操作系统&…

UEditor文件上传超出大小限制修改无效问题

网上说的方法,试过了,不生效 百度ueditor富文本编辑框怎么设置上传图片大小限制_umeditor 控制图片上传不得超过1m-CSDN博客 直接修改此处

[图解]《分析模式》漫谈02-第2章图的多重性错误

1 00:00:01,400 --> 00:00:02,790 今天,我们来看 2 00:00:04,440 --> 00:00:06,190 分析模式的第2章 3 00:00:06,960 --> 00:00:09,820 一个图上面的一些小问题 4 00:00:13,130 --> 00:00:15,320 第2章的图2.4 5 00:00:16,500 --> 00:00:22,190 …

美丽的拉萨,神奇的布达拉宫

原文链接:美丽的拉萨,神奇的布达拉宫 2022年11月30日,可能将成为一个改变人类历史的日子——美国人工智能开发机构OpenAI推出了聊天机器人ChatGPT-3.5,将人工智能的发展推向了一个新的高度。2023年11月7日,OpenAI首届…

Cloudflare 错误 1006、1007、1008 解决方案 | 如何修复

根据不完全统计,使用 Cloudflare 的网站比例已经接近 20%。因此,在日常工作中,比如进行网页抓取时,您可能经常会遇到一些因 Cloudflare 而产生的困难。例如,遇到 Cloudflare 错误 1006、1007 和 1008,这些错…

Windows下基于Frida查看内存基址和修改寄存器

使用Frida能够方便地获取到DLL基址,还能修改寄存器值。首先要通过任务管理器获得进程的PID,然后写Python脚本把Frida附加到这个PID进程,根据IDA分析出来的函数地址,HOOK到目标函数,修改寄存器的值,最终实现…

PHP聚合通多平台支付平台源码

源码介绍 php聚合通多平台支付平台源码,源码搭建了一下,这个源码不复杂,修改一下数据库账号密码然后导入数据库就可以,和网站恢复备份一样简单! 源码截图 源码下载 PHP聚合通多平台支付平台源码

vite配置unocss

在vue3vitetseslintprettierstylelinthuskylint-stagedcommitlintcommitizencz-git介绍了关于vitevue工程化搭建,现在在这个基础上,我们增加一下unocss unocss官方文档 具体开发中使用遇到的问题可以参考不喜欢原子化CSS得我,还是在新项目中使…

C++面向对象:多态性

多态性 1.概念 多态性是面向对象的程序设计的一个重要特征。在面向对象的方法中一般是这样表述多态的:向不同的对象发送同一个信息,不同的对象在接收时会产生不同的行为。也就是说,每个对象用自己的方式去响应共同的消息。 2.典例 下面这…

免费个人站 独立站 wordpress 自建网站

制作免费网站 | 免费网站构建器 | WordPress.com https://bioinformatics7.wordpress.com WordPress.com

第2章 Rust初体验5/8:match表达式和模式匹配:更富表达力:猜骰子冷热游戏

讲动人的故事,写懂人的代码 2.5 故事3: 比较答案与点数之和 贾克强:“同学们,我们开始用三种语言来实现故事3吧!” 2.5.1 Rust版故事3 这个故事实在是轻松容易地实现了。赵可菲照着书,一下子就写好了。 @@ -1,4 +1,5 @@use rand::Rng; +use std::cmp::Ordering;use std…

PgSQL技术内幕 - psql与服务端连接与交互机制

PgSQL技术内幕 - 客户端psql与服务端连接与交互机制 简单来说,PgSQL的psql客户端向服务端发起连接请求,服务端接收到请求后,fork出一个子进程,之后由该子进程和客户端进行交互,处理客户端的SQL等,并将结果返…

2024/06/13--代码随想录算法3/17|01背包问题 二维、01背包问题 一维、416. 分割等和子集

01背包问题 二维 卡码网链接 动态规划5步曲 确定dp数组(dp table)以及下标的含义:dp[i][j] :从下标为[0,i-1]个物品中任取,放进容量为j的背包,价值总和最大为多少。确定递推公式, 有两个方向可…

JVM常用概念之扁平化堆容器

扁平化堆容器是OpenJDK Valhalla 项目提出的,其主要目标为将值对象扁平化到其堆容器中,同时支持这些容器的所有指定行为,从而达到不影响原有功能的情况下,显著减少内存空间的占用(理想条件下可以减少24倍)。…

使用代理IP常见问题有哪些?

代理IP在互联网数据收集和业务开展中发挥着重要作用,它充当用户客户端和网站服务器之间的“屏障”,可以保护用户的真实IP地址,并允许用户通过不同的IP地址进行操作。然而,在使用代理IP的过程中,用户经常会遇到一些问题…