PyTorch入门之【AlexNet】

news2025/1/19 11:07:17

参考文献:https://www.bilibili.com/video/BV1DP411C7Bw/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620
AlexNet 是一个经典的卷积神经网络模型,用于图像分类任务。

目录

  • 大纲
  • dataloader
  • model
  • train
  • test

大纲

在这里插入图片描述
各个文件的作用:

  • data就是数据集
  • dataloader.py就是数据集的加载以及实例初始化
  • model.py就是AlexNet模块的定义
  • train.py就是模型的训练
  • test.py就是模型的测试

dataloader

import torch
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np


# define the dataloader
transform = transforms.Compose(
    [transforms.Resize(224),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 16

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


if __name__ == '__main__':
    # get some random training images
    dataiter = iter(train_loader)
    images, labels = next(dataiter)

    # print labels
    print(' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))

    # show images
    img_grid = torchvision.utils.make_grid(images)
    img_grid = img_grid / 2 + 0.5
    npimg = img_grid.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

model

import torch.nn as nn
import torch

class AlexNet(nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 3, stride = 2))
        self.conv_2 = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 3, stride = 2))
        self.conv_3 = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.conv_4 = nn.Sequential(
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.conv_5 = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 3, stride = 2))
        self.fc_1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(9216, 4096),
            nn.ReLU())
        self.fc_2 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU())
        self.fc_3= nn.Sequential(
            nn.Linear(4096, num_classes))
        
    def forward(self, x):
        out = self.conv_1(x)
        out = self.conv_2(out)
        out = self.conv_3(out)
        out = self.conv_4(out)
        out = self.conv_5(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc_1(out)
        out = self.fc_2(out)
        out = self.fc_3(out)
        return out

if __name__ == '__main__':
    model = AlexNet()
    print(model)
    x = torch.randn(1, 3, 224, 224)
    y = model(x)
    print(y.size())

train

import torch
import torch.nn as nn

from dataloader import train_loader, test_loader
from model import AlexNet


# define the hyperparameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 10
num_epochs = 20
learning_rate = 1e-3


# load the model
model = AlexNet(num_classes).to(device)


# loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  


# train the model
total_len = len(train_loader)

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)
        
        # forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch+1, num_epochs, i+1, total_len, loss.item()
            ))
            
    # Validation
    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        model.train()
        print('Accuracy of the network on the {} validation images: {} %'.format(10000, 100 * correct / total))

# save the model checkpoint
torch.save(model.state_dict(), 'alexnet.pth')

test

import torch

from dataloader import test_loader, classes
from model import AlexNet


# load the pretrained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AlexNet().to(device)
model.load_state_dict(torch.load('alexnet.pth', map_location=device))

# test the pretrained model on CIFAR-10 test data
with torch.no_grad():
    model.eval()
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the {} validation images: {} %'.format(10000, 100 * correct / total))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1064446.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【将文本编码为图像灰度级别】以 ASCII 编码并与灰度级别位混合将文本字符串隐藏到图像像素的最低位中,使其不明显研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

HTTP进阶,Cookie,响应的回报结果含义,ajax,form表单,不同状态码代表的结果

目录 一、Cookie 二、响应的回报结果含义 三、实际开发中的选择 一、Cookie Cookie是浏览器本地存储数据的一种机制, 在浏览器访问服务器之间,此时你的浏览器对着个服务器之间是一点也不了解的,你的浏览器上是没有任何和着个服务器相关的数据的。 浏览…

【LinuxC】时间、时区,相关命令、函数

文章目录 一、序1.1 时间和时区1.11 时间1.12 时区 1.2 查看时间时区的命令1.21 Windows1.22 Linux 二、C语言函数2.1 通用2.11 函数简介2.12 数据类型简介 2.2 windows 和 Linux特有函数2.3 C语言示例 一、序 1.1 时间和时区 1.11 时间 时间是一种用来描述物体运动变化的量…

怎么把家里闲置旧苹果手机变成家用安防监控摄像头

环境 监控端:苹果5s iOS10.2.1 观看端:苹果11Pro 网络情况:同一个局域网 问题描述 怎么把家里闲置旧苹果手机变成家用安防监控摄像头 解决方案 1.旧苹果5s手机在App Store下载采集端软件 打开app 授权联网、相机、麦克风等权限&#xf…

用VLD调查VC内存泄漏

一、发现内存泄漏 使用VS2022&#xff0c;发现提示有内存泄漏&#xff0c;检查了所有的new&#xff0c;确认都有相应的delete释放。 Detected memory leaks! Dumping objects -> {1914} normal block at 0x0000021FDFFBD2E0, 48 bytes long.Data: < >…

OpenResty安装-(基于Nginx的高性能Web平台,可在Nginx端编码业务)

文章目录 安装OpenResty1.安装1&#xff09;安装开发库2&#xff09;安装OpenResty仓库3&#xff09;安装OpenResty4&#xff09;安装opm工具5&#xff09;目录结构6&#xff09;配置nginx的环境变量 2.启动和运行3.备注 安装OpenResty 1.安装 首先你的Linux虚拟机必须联网 …

【网络安全 --- 工具安装】Centos 7 详细安装过程及xshell,FTP等工具的安装(提供资源)

VMware虚拟机的安装教程如下&#xff0c;如没有安装&#xff0c;可以参考这篇博客安装&#xff08;提供资源&#xff09; 【网络安全 --- 工具安装】VMware 16.0 详细安装过程&#xff08;提供资源&#xff09;-CSDN博客【网络安全 --- 工具安装】VMware 16.0 详细安装过程&am…

高数:第三章:一元函数积分学

文章目录 一、不定积分(一)两个基本概念&#xff1a;原函数、不定积分(二)原函数的存在性&#xff1a;原函数存在定理(三)不定积分的性质(四)基本积分公式(五)三种主要积分法1.凑微分 (第一类换元法)2.换元法 (第二类换元法)①三角代换②根式代换③倒代换 3.分部积分法4.其他技…

工厂管理软件中的PLM管理

第一部分&#xff1a;介绍PLM管理的定义和背景 1.1 定义&#xff1a;PLM管理是指通过工厂管理软件实现对产品生命周期各个阶段的全面管理和协同合作&#xff0c;包括产品设计、工艺规划、生产制造、质量控制、供应链管理等环节。 1.2 背景&#xff1a;随着市场竞争的加剧和消…

希尔排序:优化插入排序的精妙算法

排序算法在计算机科学中扮演着重要的角色&#xff0c;其中希尔排序&#xff08;Shell Sort&#xff09;是一种经典的排序算法。本文将带您深入了解希尔排序&#xff0c;包括其工作原理、性能分析以及如何使用 Java 进行实现。 什么是希尔排序&#xff1f; 希尔排序&#xff0c…

数据结构与算法-(7)---栈的应用-(4)后缀表达式求值

&#x1f308;write in front&#x1f308; &#x1f9f8;大家好&#xff0c;我是Aileen&#x1f9f8;.希望你看完之后&#xff0c;能对你有所帮助&#xff0c;不足请指正&#xff01;共同学习交流. &#x1f194;本文由Aileen_0v0&#x1f9f8; 原创 CSDN首发&#x1f412; 如…

【Java 进阶篇】深入了解JDBCTemplate:简化Java数据库操作

数据库操作是几乎所有现代应用程序的一部分。从存储和检索数据到管理业务逻辑&#xff0c;数据库操作是不可或缺的。在Java应用程序中&#xff0c;JDBCTemplate是一种强大的工具&#xff0c;可帮助开发人员轻松进行数据库操作。本文将深入探讨JDBCTemplate&#xff0c;了解它的…

【MySQL】基本查询 (一)

文章目录 一. 基础查询二. where条件子句三. NULL的比较结束语 操作如下表 //创建表结构 mysql> create table exam_result(-> id int unsigned primary key auto_increment,-> name varchar(20) not null comment 同学姓名,-> chinese float default 0.0 comment…

怒刷LeetCode的第24天(Java版)

目录 第一题 题目来源 题目内容 解决方法 方法一&#xff1a;反向遍历 方法二&#xff1a;字符串操作函数 方法三&#xff1a;正则表达式 第二题 题目来源 题目内容 解决方法 方法一&#xff1a;模拟 方法二&#xff1a;递归 方法三&#xff1a;迭代 方法四&…

VS Code更改软件的语言

刚刚安装好的 vscode 默认是英文&#xff0c;可以安装中文扩展包。如图&#xff1a; 重启即可更换为中文。 如果想切换为英文&#xff0c;可以 Ctrl Shift P&#xff0c;打开命令面板。 输入 Configure DIsplay Language&#xff0c;如图&#xff1a; 可以在中英文之间切换…

【C语言】什么是宏定义?(#define详解)

&#x1f984;个人主页:修修修也 &#x1f38f;所属专栏:C语言 ⚙️操作环境:Visual Studio 2022 ​ 目录 一.什么是宏定义 二.宏定义的组成 第1部分 第2部分 第3部分 三.宏定义的应用 &#x1f38f;类对象宏 &#x1f38f;类函数宏 1.求两个数中的较大值 2.求一个数的…

【短文】在Linux中怎么查看文件信息

2023年10月6日&#xff0c;周五晚上 ls -l filename 通过这条命令可以简略地查看文件信息 stat filename 通过这条命令可以详细地查看文件信息

NFS 原理和配置

NFS 原理介绍 NFS network file system 网络文件系统 基于RPC协议&#xff0c;RPC remote procedure call 。 RPC存在的意义在于解决NFS服务端和客户端通信多端口并且端口不固定的问题。因为NFS的服务端和客户端通信的时候&#xff0c;并不是只有一个端口&#xff…

【云计算网络安全】DDoS 缓解解析:DDoS 攻击缓解策略、选择最佳提供商和关键考虑因素

文章目录 一、前言二、什么是 DDoS 缓解三、DDoS 缓解阶段四、如何选择 DDoS 缓解提供商4.1 网络容量4.2 处理能力4.3 可扩展性4.4 灵活性4.5 可靠性4.6 其他考虑因素4.6.1 定价4.6.2 所专注的方向 文末送书《数据要素安全流通》本书编撰背景本书亮点本书主要内容 一、前言 云…

电脑提示MSVCP100.dll丢失错误怎么解决?分享四个解决方法帮你搞定

在平时我们使用电脑中&#xff0c;经常会遇到各种问题&#xff0c;比如msvcp100.dll文件丢失&#xff0c;那这个msvcp100.dll文件丢失需要怎么修复解决呢&#xff1f;和msvcp100.dll为什么会丢失呢&#xff0c;下面我一点点为大家解答与介绍解决msvcp100.dll丢失问题的方法。 一…