AlexNet实战

news2025/1/16 16:02:19

前言:之前学了挺多卷积神经网络模型,但是都只停留在概念。代码都没自己敲过,肯定不行,而且这代码也很难很多都看不懂。所以想着先从最先较简单的AlexNet开始敲。不过还是好多没搞明白,之后逐一搞清楚。

文章目录

        • AlexNet 实战
            • Alex.net
            • train.py
            • test.py

AlexNet 实战

image-20230701100109370

  • The first convolutional layer filters the 224 × 224 × 3 input image with 96 kernels of size 11 × 11 × 3 with a stride of 4 pixels (this is the distance between the receptive field centers of neighboring neurons in a kernel map).

  • 第一个卷积层用96个大小为11 × 11 × 3的核对224 × 224 × 3输入图像进行过滤,步幅为4像素(这是核图中相邻神经元的感受野中心之间的距离)。

  • The second convolutional layer takes as input the (response-normalized and pooled) output of the first convolutional layer and filters it with 256 kernels of size 5 × 5 × 48.

  • 第二个卷积层将第一个卷积层的(响应归一化和池化)输出作为输入,并用 256 个大小为 5 × 5 × 48 的内核对其进行过滤。

  • The third, fourth, and fifth convolutional layers are connected to one another without any intervening pooling or normalization layers.

  • 第三个、第四个和第五个卷积层相互连接,没有任何中间池化或归一化层。

  • The third convolutional layer has 384 kernels of size 3 × 3 × 256

  • 第三个卷积层有384个大小为3 × 3 × 256的核

  • The fourth convolutional layer has 384 kernels of size 3 × 3 × 192, and the fifth convolutional layer has 256 kernels of size 3 × 3 × 192. The fully connected layers have 4096 neurons each.

  • 第四卷积层有384个大小为3 × 3 × 192的核,第五卷积层有256个大小为3 × 3 × 192的核。全连接层各有4096个神经元。

  • This is what we use throughout our network, with s = 2 and z = 3.

  • 这是我们在整个网络中使用的,s = 2 和 z = 3。(即池化层,stride=2,kernel_size = 3)

Alex.net

搭建AlexNet就按照论文中的一步一步搭

import torch
from torch import nn
import torch.nn.functional as F

class MyAlexNet(nn.Module):
    def __init__(self):
        super(MyAlexNet,self).__init__()
        ## input [3,224,224]
        self.c1 = nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2) ##output [96,55,55]
        self.s1 = nn.MaxPool2d(kernel_size=3, stride=2) ##output [96,27,27]
        self.Relu = nn.ReLU()


        self.c2 = nn.Conv2d(in_channels=96,out_channels=256,kernel_size=5,padding=2) ##output [256,27,27]
        self.s2 = nn.MaxPool2d(kernel_size=3, stride=2) ##output [256,13,13]
        self.c3 = nn.Conv2d(in_channels=256,out_channels=384,kernel_size=3,padding=1) ##output [384,13,13]

        self.c4 = nn.Conv2d(in_channels=384,out_channels=384,kernel_size=3,padding=1) ##output [384,13,13]

        self.c5 = nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,padding=1) ##output [256,13,13]
        self.s5 = nn.MaxPool2d(kernel_size=3, stride=2) ##output [256,6,6]

        self.flatten = nn.Flatten()
        self.f6 = nn.Linear(256*6*6,4096)
        self.f7 = nn.Linear(4096,4096)
        self.f8 = nn.Linear(4096,2) 

    def forward(self,x):
        x = self.Relu(self.c1(x))
        x = self.s1(x)
        x = self.Relu(self.c2(x))
        x = self.s2(x)
        x = self.Relu(self.c3(x))
        x = self.Relu(self.c4(x))
        x = self.Relu(self.c5(x))
        x = self.s5(x)
        x = self.flatten(x)
        x = self.f6(x)
        x = F.dropout(x,p=0.5)
        x = self.f7(x)
        x = F.dropout(x,p=0.5)
        x = self.f8(x)
        return x
if __name__ == '__main__':
    x = torch.rand([1,3,224,224])
    model = MyAlexNet()
    y = model(x)
train.py
import torch
from torch import nn
from net import MyAlexNet
import numpy as np
from torch.optim import lr_scheduler
import os

from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

##解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

Root_train = r"F:/python/AlexNet/data/train"
Root_test = r"F:/python/AlexNet/data/val"

normalize = transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])

train_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomVerticalFlip(),##数据增强
    transforms.ToTensor(),#转换为张量
    normalize
])

val_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),#转换为张量
    normalize
])

train_dataset = ImageFolder(Root_train,transform=train_transform)
val_dataset = ImageFolder(Root_test,transform=val_transform)

train_dataloader = DataLoader(train_dataset,batch_size=32,shuffle=True)
val_dataloader = DataLoader(val_dataset,batch_size=32,shuffle=True) ##分批次,打乱


device = 'cuda' if torch.cuda.is_available else 'cpu'
model = MyAlexNet().to(device) 

##定义损失函数
loss_fn = nn.CrossEntropyLoss()

##定义优化器
optimizer = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9)

##学习率每隔10轮变为原来的0.5
lr_scheduler = lr_scheduler.StepLR(optimizer,step_size=10,gamma=0.5)

##定义训练函数
def train(dataloader,model,loss_fn,optimizer):
    loss, current, n = 0.0, 0.0, 0
    for batch,(x,y) in enumerate(dataloader):
        image ,y = x.to(device), y.to(device)
        output = model(image)
        cur_loss = loss_fn(output,y)
        _, pred = torch.max(output,axis=1)
        cur_acc = torch.sum(y==pred)/output.shape[0]

        ##反向传播
        optimizer.zero_grad()
        cur_loss.backward()
        optimizer.step()
        loss +=cur_loss.item()
        current +=cur_acc.item()
        n += 1
    train_loss = loss / n
    train_acc = current / n
    print("train_loss" + str(train_loss))
    print("train_acc" + str(train_acc))
    return train_loss, train_acc

##定义验证函数
def val(dataloader,model,loss_fn):
    model.eval()
    loss, current, n = 0.0, 0.0, 0
    with torch.no_grad():
        for batch,(x,y) in enumerate(dataloader):
            image ,y = x.to(device), y.to(device)
            output = model(image)
            cur_loss = loss_fn(output,y)
            _, pred = torch.max(output,axis=1)
            cur_acc = torch.sum(y==pred)/output.shape[0]
            loss += cur_loss.item()
            current += cur_acc.item()
            n += 1
    val_loss = loss / n
    val_acc = current / n
    print("val_loss" + str(val_loss))
    print("val_acc" + str(val_acc))
    return val_loss, val_acc

##定义画图函数
def matplot_loss(train_loss,val_loss):
    plt.plot(train_loss,label='train_loss')
    plt.plot(val_loss,label="val_loss")
    plt.legend(loc='best')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.title("训练集和验证集loss值对比")
    plt.show()

def matplot_acc(train_loss,val_loss):
    plt.plot(train_loss,label='train_acc')
    plt.plot(val_loss,label="val_acc")
    plt.legend(loc='best')
    plt.ylabel('acc')
    plt.xlabel('epoch')
    plt.title("训练集和验证集acc值对 比")
    plt.show()


##开始训练
loss_train = []
acc_train = []
loss_val = []
acc_val = []

epoch = 20
min_acc = 0
for t in range(epoch):
    lr_scheduler.step()
    print(f"epoch{t+1}\n-----------")
    train_loss,train_acc = train(train_dataloader,model,loss_fn,optimizer)
    val_loss,val_acc = val(val_dataloader,model,loss_fn)

    loss_train.append(train_loss)
    acc_train.append(train_acc)
    loss_val.append(val_loss)
    acc_val.append(val_acc)

    ##保存最好的模型权重
    if val_acc > min_acc:
        folder='save_model'
        if not os.path.exists(folder):
            os.mkdir('save_model')
        min_acc = val_acc
        print(f"save bset model,第{t+1}轮")  
        torch.save(model.state_dict(),'save_model/best_model.pth')

    ##保存最后一轮的权重文件
    if t == epoch-1:
        torch.save(model.state_dict(),'save_model/last_model.pth')
matplot_loss(loss_train,loss_val)
matplot_acc(acc_train,acc_val)
print("Done")

image-20230701193450668

image-20230701193455621

test.py
import torch
from net import MyAlexNet
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torchvision.transforms import ToPILImage
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader



Root_train = r"F:/python/AlexNet/data/train"
Root_test = r"F:/python/AlexNet/data/val"

normalize = transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])

train_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomVerticalFlip(),##数据增强
    transforms.ToTensor(),#转换为张量
    normalize
])

val_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),#转换为张量
    normalize
])

train_dataset = ImageFolder(Root_train,transform=train_transform)
val_dataset = ImageFolder(Root_test,transform=val_transform)

train_dataloader = DataLoader(train_dataset,batch_size=32,shuffle=True)
val_dataloader = DataLoader(val_dataset,batch_size=32,shuffle=True) ##分批次,打乱


device = 'cuda' if torch.cuda.is_available else 'cpu'
model = MyAlexNet().to(device) 

## 加载模型
model.load_state_dict(torch.load("F:/python/AlexNet/save_model/best_model.pth"))

classes = [
    "cat",
    "dog",
]

##把张量转为照片格式
show = ToPILImage()

##进入验证阶段
model.eval()
for i in range(50):
    x, y = val_dataset[i][0],val_dataset[i][1]
    show(x).show()
    x = Variable(torch.unsqueeze(x,dim=0).float(),requires_grad=True).to(device)
    x = torch.tensor(x).to(device)
    with torch.no_grad():
        pred = model(x)
        predicted,actual = classes[torch.argmax(pred[0])],classes[y]
        print(f'predicted:"{predicted}",Actual:"{actual}"')

image-20230701184444246

如果归一化的话,就会出现这种效果,但是如果把normalize = transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])去掉就能显示正常图片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/719875.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spark弹性分布式数据集

1. Spark RDD是什么 RDD(Resilient Distributed Dataset,弹性分布式数据集)是一个不可变的分布式对象集合,是Spark中最基本的数据抽象。在代码中RDD是一个抽象类,代表一个弹性的、不可变、可分区、里面的元素可并行计…

BottomNavigationView(自带角标)无法添加角标的解决问题

android studio的报错提示:java.lang.IllegalArgumentException: The style on this component requires your app theme to be Theme.MaterialComponents (or a descendant). 解决,改主题:

跨境平台做测评、采退、Lu卡、lu货要怎么做安全?

大家好,我是珑哥测评,今天和大家聊聊比较小众的圈子,也就是测评衍生出来的分支,采购和退款。因为最近也有很多客户咨询这个问题,由于沃尔玛风控升级了,很多客户下不成功的问题。 大家都知道无论是做测评还是…

BPM流程引擎适用于哪些类型企业管理系统

看到标题的童鞋们,估计在搜索办公软件系统时都会留意到BPM,那BPM到底是何方神圣?它与管理系统有什么区别呢?今天我们一一解答。 什么是BPM? BPM(即业务流程管理),是企业信息化发展的…

[网络] ifconfig down掉的网口,插上网线网口灯依然亮?

1、软硬件环境 环境1: 硬件: 飞腾E2000Q ARM64 平台 YT8521SH (phy) 软件: linux 4.19.246 环境2: 硬件: NXP T1042 PowerPC 平台 YT8521SH (phy) 软件: linux 4.1.35 备注: 1、环境1的网…

Rdkit|分子可视化

Rdkit|分子可视化 Github: 地址 单个展示 从mol对象到图片:MolToImage(mol, size, kekulize, wedgeBonds, fitImage, …) mol:mol对象 size:图片尺寸,默认(300, 300) kekulize:是否展示kekule形式&#…

Apikit 自学日记:新建 API 请求

进入流程测试用例详情页,点击添加测试步骤,在下拉菜单中选择 添加API请求 ,系统会自动进入API测试步骤编辑页面,接下来你可以编辑API的请求参数、返回结果、校验规则等内容。 设置 API 测试步骤 API测试步骤设置分为以下几个部分…

react笔记_07类组件

目录 复习展开运算符 组件什么叫做组件?分类类组件es6新增构造函数语法类组件渲染类组件的this指向类组件的三大属性state作用语法-初始化数据语法-修改state数据语法-获取state中的数据案例 propspropTypes属性(prop-types库)defaultProps属性 refs[1] 字符串形式的…

搞定HashMap

搞定HashMap 1.Map是个啥? HashMap隶属于Java中集合这一块,我们知道集合这块有list,set和map,这里的HashMap就是Map的实现类,那么在Map这个大家族中还有哪些重要角色呢? 上图展示了Map的家族,…

Cernox 温度传感器碳陶瓷基体结构

Cernox 温度传感器具有高灵敏度、稳定性好、遵循单一电阻与温度曲线,磁场性能优良和耐辐射等特性。适用于低温系统中1.5-375K范围内的测量。传感器在及其严格的质量控制下制造,并在强磁场、中子伽马辐射、热循环和机械耐久条件下证明长期稳定性。与其他可…

基于springboot+mysql+jsp高校社团管理系统

基于springbootmysqljsp高校社团管理系统 一、系统介绍二、所用技术三、功能展示三、其它系统四、获取源码 一、系统介绍 管理员:登录注册、个人中心(个人信息、密码修改、注销)、近期活动(所有活动、文体类活动、学术类活动、公…

微服务学习1——微服务环境搭建

微服务学习1——微服务环境搭建 (参考黑马程序员项目) 个人仓库地址:https://gitee.com/jkangle/springboot-exercise.git 微服务就是将单体应用进一步拆分,拆成更小的服务,拆完之后怎么调用,主流的技术有…

【分布式】zabbix 6.0部署讲解

目录 一、 序章二、zabbix概念2.1 zabbix是什么?2.2 zabbix 监控原理2.3 zabbix 6.0 新特性2.4 zabbix 6.0 功能组件 三、zabbix 6.0 部署部署服务端3.1 部署 Nginx PHP 环境并测试3.1.1 安装nginx3.1.2 安装php3.1.3 修改 Nginx 配置3.1.4 修改 php 配置3.1.5 创建…

Python +selenium 自动化之元素定位

selenium之八大元素定位: 1、通过ID的方式定位 id是页面的唯一标识 例如:找到百度的搜索输入框 driver.findElement(By.id("kw")) 2、通过tagName的方式定位 用标签名定位非常少 ---一般会重复 driver.findElements(By.tagName(&qu…

Vue2 Diff 算法简易版

背景 最近复习的过程中,准备对比一下Vue2和Vue3的diff算法区别,好知道两者直接的差异和优缺点。刚好看了网上的文章,但是对方写的代码不太正确,所以特意记录一下我的学习过程~ 双端比较法 Vue2采用的双端比较法,即新…

MBD开发 STM32 Timer

开两个定时器 一快一慢 两个中断都要使能 没有自动更新,切换下timerx就好了,但是触发UP要手动勾选

剑指offer27.二叉树的镜像

这道题很简单,写了十多分钟就写出来了,一看题目就知道这道题肯定要用递归。先交换左孩子和右孩子,再用递归交换左孩子的左孩子和右孩子,交换右孩子的左孩子和右孩子,其中做一下空判断就行。以下是我的代码:…

爬虫入门指南(8): 编写天气数据爬虫程序,实现可视化分析

文章目录 前言准备工作爬取天气数据可视化分析完整代码解释说明 运行效果完结 前言 天气变化是生活中一个重要的因素,了解天气状况可以帮助我们合理安排活动和做出决策。本文介绍了如何使用Python编写一个简单的天气数据爬虫程序,通过爬取指定网站上的天…

Pandas+Pyecharts | 双十一美妆销售数据分析可视化

文章目录 🏳️‍🌈 1. 导入模块🏳️‍🌈 2. Pandas数据处理2.1 读取数据2.2 数据信息2.3 筛选有销量的数据 🏳️‍🌈 3. Pyecharts数据可视化3.1 双十一前后几天美妆订单数量3.2 双十一前后几天美妆销量3.3…

【Linux】线程终结篇:线程池以及线程池的实现

linux线程完结 文章目录 前言一、线程池的实现二、了解性知识 1.其他常见的各种锁2.读者写者问题总结 前言 什么是线程池呢? 线程池一种线程使用模式。线程过多会带来调度开销,进而影响缓存局部性和整体性能。而线程池维护着多个线程,等待着…