AlexNet——训练花数据集

news2024/11/17 16:28:39

目录

一、网络结构

二、创新点分析

三、知识点

1. nn.ReLU(inplace) 

2. os.getcwd与os.path.abspath 

3. 使用torchvision下的datasets包 

4. items()与dict()用法 

5. json文件  

6. tqdm

7. net.train()与net.val()

四、代码


AlexNet是由Alex Krizhevsky、Ilya Sutskever和Geoffrey Hinton在2012年ImageNet图像分类竞赛中提出的一种经典的卷积神经网络。AlexNet使用了Dropout层,减少过拟合现象的发生。

一、网络结构

二、数据集 

文件存放:

dataset->flower_data->flower_photos

再使用split_data.py 将数据集根据比例划分成训练集和预测集

详细请查看b站up主霹雳吧啦Wz:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/pytorch_classification

三、创新点分析

1. deeper网络结构

    通过增加网络深度,AlexNet可以更好的学习数据集的特征,并提高分类的准确率。

2. 使用ReLU激活函数,克服梯度消失以及求梯度复杂的问题。

3. 使用LRN局部响应归一化

    LRN是在卷积与池化层间添加归一化操作。卷积过程中,每个卷积核都对应一个feature map,LRN对这些feature map进行归一化操作。即,对每个特征图的每个位置,计算该位置周围的像素平方和,然后将当前位置像素值除以这个和。LRN可抑制邻近神经元的响应,在一定程度上能够避免过拟合,提高网络泛化能力。

4. 使用Dropout层

Dropout层:在训练过程中随机删除一定比例的神经元,以减少过拟合。Dropout一般放在全连接层与全连接层之间。

四、知识点

1. nn.ReLU(inplace) 默认参数为:inplace=False

inplace=False:不会修改输入对象的值,而是返回一个新创建的对象,即打印出的对象存储地址不同。(值传递)

inplace=True:会修改输入对象的值,即打印的对象存储地址相同,可以节省申请与释放内存的空间与时间。(地址传递)

import torch
import numpy as np
import torch.nn as nn

# id()方法返回对象的内存地址
relu1 = nn.ReLU(inplace=False)
relu2 = nn.ReLU(inplace=True)
data = np.random.randn(2, 4)
input = torch.from_numpy(data)  # 转换成tensor类型
print("input address:", id(input))
output1 = relu1(input)
print("replace=False -- output address:", id(output1))
output2 = relu2(input)
print("replace=True -- output address:", id(output2))
# input address: 1669839583200
# replace=False -- output address: 1669817512352
# replace=True -- output address: 1669839583200

2. os.getcwd与os.path.abspath 

os.getcwd():获取当前工作目录

os.path.abspath('xxx.py'):获取文件当前的完整路径

import os

print(os.getcwd())  # D:\Code
print(os.path.abspath('test.py'))  # D:\Code\test.py

3. 使用torchvision下的datasets包 

train_dataset=datasets.ImageFolder(root=os.path.join(image_path,'train'),transform=data_transform['train'])

可以得出这些信息: 

4. items()与dict()用法 

items():把字典中的每对key和value组成一个元组,并将这些元组放在列表中返回。

obj = {
    'dog': 0,
    'cat': 1,
    'fish': 2
}
print(obj)  # {'dog': 0, 'cat': 1, 'fish': 2}
print(obj.items())  # dict_items([('dog', 0), ('cat', 1), ('fish', 2)])
print(dict((v, k) for k, v in obj.items()))  # {0: 'dog', 1: 'cat', 2: 'fish'}

5. json文件  

(1)json.dumps:将Python对象编码成JSON字符串

(2)json.loads:将已编码的JSON字符串编码为Python对象

import json

data = [1, 2, 3]
data_json = json.dumps(data)  # <class 'str'>
data = json.loads(data_json)
print(type(data))  # <class 'list'>

6. tqdm

train_bar = tqdm(train_loader, file=sys.stdout)
使用tqdm函数,对train_loader进行迭代,将进度条输出到标准输出流sys.stdout中。可以方便用户查看训练进度。

from tqdm import tqdm
import time

for i in tqdm(range(10)):
    time.sleep(0.1)

7. net.train()与net.val()

net.train():启用BatchNormalization和Dropout

net.eval|():不启用BatchNormalization和Dropout

五、代码

model.py

import torch
import torch.nn as nn

class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, padding=2, stride=4),  # input[3,224,224] output[96,55,55]
            nn.ReLU(inplace=True),  # inplace=True 址传递
            nn.MaxPool2d(kernel_size=3, stride=2),  # output[96,27,27]
            nn.Conv2d(96, 256, kernel_size=5, padding=2),  # output[256,27,27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # output[256,13,13]
            nn.Conv2d(256, 384, kernel_size=3, padding=1),  # output[384,13,13]
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),  # output[384,13,13]
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # output[256,13,13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # output[256,6,6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(256 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)  # batch这一维度不用,从channel开始
        x = self.classifier(x)
        return x

train.py 

import os
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
from model import AlexNet
import torch.optim as optim
from tqdm import tqdm

def main():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print("using:{}".format(device))
    data_transform = {
        'train': transforms.Compose([
            # 将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为指定大小
            transforms.RandomResizedCrop(224),
            # 水平方向随机翻转
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    }
    # get data root path
    data_root = os.path.abspath(os.getcwd())  # D:\Code\AlexNet
    # get flower data set path
    image_path = os.path.join(data_root, 'data_set', 'flower_data')  # D:\Code\AlexNet\data_set\flower_data
    # 使用assert断言语句:出现错误条件时,就触发异常
    assert os.path.exists(image_path), '{} path does not exist!'.format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'), transform=data_transform['train'])
    val_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'val'), transform=data_transform['val'])
    train_num = len(train_dataset)
    val_num = len(val_dataset)

    # write class_dict into json file
    flower_list = train_dataset.class_to_idx
    class_dict = dict((v, k) for k, v in flower_list.items())
    json_str = json.dumps(class_dict)
    with open('class_indices.json', 'w') as file:
        file.write(json_str)

    batch_size = 32
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0)

    net = AlexNet(num_classes=5)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)
    epochs = 5
    save_path = './model/AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)  # train_num / batch_size
    train_bar = tqdm(train_loader)
    val_bar = tqdm(val_loader)

    for epoch in range(epochs):
        # train
        net.train()
        epoch_loss = 0.0
        # 加入进度条
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()  # update x by optimizer

            # print statistics
            epoch_loss += loss.item()
            train_bar.desc = 'train eporch[{}/{}] loss:{:.3f}'.format(epoch + 1, epochs, loss)

        # validate
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(val_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]  # [1]取每行最大值的索引
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_acc = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' % (epoch + 1, epoch_loss / train_steps, val_acc))

        # find best accuracy
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), save_path)
        print('Train finished!')


if __name__ == '__main__':
    main()

class_indices.json

{"0": "daisy", "1": "dandelion", "2": "roses", "3": "sunflowers", "4": "tulips"}

predict.py 

import os
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import json
from model import AlexNet

def main():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # load image
    img_path = './2.jpg'
    assert os.path.exists(img_path), "file:'{}' does not exist".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)

    # input [N,C,H,W]
    img = transform(img)
    img = torch.unsqueeze(img, dim=0)

    # read class_indices
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file:'{}' does not exist".format(json_path)
    with open(json_path, 'r') as file:
        class_dict = json.load(file)  # {'0': 'daisy', '1': 'dandelion', '2': 'roses', '3': 'sunflowers', '4': 'tulips'}

    # load model
    net = AlexNet(num_classes=5).to(device)
    # load model weights
    weight_path = './model/AlexNet.pth'
    assert os.path.exists(weight_path), "file:'{}' does not exist".format(weight_path)
    net.load_state_dict(torch.load(weight_path))

    # predict
    net.eval()
    with torch.no_grad():
        output = torch.squeeze(net(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_class = torch.argmax(predict).numpy()
    print_res = 'class:{} probability:{:.3}'.format(class_dict[str(predict_class)], predict[predict_class].numpy())
    plt.title(print_res)
    plt.show()

    for i in range(len(predict)):
        print('class:{:10} probability:{:.3}'.format(class_dict[str(i)], predict[i]))


if __name__ == '__main__':
    main()

Result:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1018544.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深入学习 Redis Cluster - 基于 Docker、DockerCompose 搭建 Redis 集群,处理故障、扩容方案

目录 一、基于 Docker、DockerCompose 搭建 Redis 集群 1.1、前言 1.2、编写 shell 脚本 1.3、执行 shell 脚本&#xff0c;创建集群配置文件 1.4、编写 docker-compose.yml 文件 1.5、启动容器 1.6、构建集群 1.7、使用集群 1.8、如果集群中&#xff0c;有节点挂了&am…

基于springboot+vue的养老院管理系统 前后端分离项目

养老院管理系统 前后端分离项目 后端技术&#xff1a;springbootmybatis-plusredismysql 前端&#xff1a;vue3.0elementui-plus 【人员管理】 用户管理 客户管理【药品食品管理】 药品管理 食品管理【报修管理】【外出管理】外出管理 访客管理【留言管理】【新闻管理】新闻…

服务器杀掉死进程

清掉服务器上的死进程 查看服务器使用情况 nvidia-smi发现并没有显示进程&#xff0c;那应该是有死进程。 查看占用情况 ps aux会显示如上图的进程列表。 PID进程号STAT状态(S休眠R运行Z死掉了)TIME占用时间(过长的有问题)COMMAND进程启动命令(应该能帮助回忆起这是什么程…

IDEA——工程项目的两种窗口开发模式

文章目录 引言一、多项目窗口模式的便利1.1 源码 debug 二、多项目窗口模式的弊端三、多项目窗口的版本管理四、单项目、多项目窗口模式转换 引言 idea编辑器有两种窗口模式&#xff0c;一种是单项目窗口&#xff0c;另一种是多项目窗口。 我个人使用较多的是单项目窗口&#…

指针和数组笔试题讲解(3)

&#x1f435;本篇文章将对指针相关笔试题进行讲解&#xff0c;同时也是指针和数组笔试题讲解的最后一篇文章&#xff0c;那么接下来将会对8道笔试题进行逐一讲解 笔试题1&#x1f4bb; int main() {int a[5] { 1, 2, 3, 4, 5 };int* ptr (int*)(&a 1);printf("%d…

C#调用C++ dll 返回数组

先看一下C语言函数返回数组的问题&#xff1b; 先看一个错误的示范&#xff1b; 因为 a 是局部变量&#xff0c;只存在函数 function() 中&#xff0c;返回给main中的b是错误的&#xff1b; 函数返回数组的一种写法如下&#xff1b; #include<stdio.h> int function(in…

二维凸包(Graham) 模板 + 详解

&#xff08;闲话&#xff09; 上了大学后没怎么搞oi&#xff0c;从土木跑路到通信了&#xff08;提桶开润大成功&#xff01;&#xff09;&#xff0c;但是一年上两年的课&#xff08;补的&#xff09;&#xff0c;保研也寄掉了&#xff08; 说起来自从博客被大学同学发现并…

地牢大师问题(bfs提高训练 + 免去边界处理的特殊方法)

地牢大师问题 文章目录 地牢大师问题前言题目描述题目分析输入处理移动方式【和二维的对比】边界判断问题的解决 代码总结 前言 在之前的博客里面&#xff0c;我们介绍了bfs 基础算法的模版和应用,这里我们再挑战一下自己&#xff0c;尝试一个更高水平的题目&#xff0c;加深一…

vue2——电商项目 黑马

创建项目 初始化 router app.vue vant 组件库 Viewport 布局 vw适配 路由配置 底部导航组件 二级路由配置 登录页面 新建默认样式 main.js 引入commonless 登录静态页面—头部组件NavBar 导入navbar 引用 axios封装 图形验证码 获取 get 渲染 api接口模块 toast轻提示 使用 …

Flutter图标

https://fluttericon.cn/ Flutter 内置了丰富的图标。 Icon(Icons.ac_unit)

智能批量重命名,轻松删除文件名后缀数字并添加编号!

亲爱的用户们&#xff0c;您是否曾经为繁琐而重复的文件重命名工作而感到头疼&#xff1f;现在&#xff0c;我们为您提供一款智能化的工具&#xff0c;让文件重命名变得如此简单&#xff01; 首先&#xff0c;我们要进入文件批量改名高手&#xff0c;并在板块栏里选择“文件批…

overleaf 插入图片,引用图片,图标标题Fig与文章引用Figure不一致解决

目录 1.一般插图 2.插入双栏图片 3 插入子图 4. 引用出现问题 问题1 &#xff1a; pdf 文中引用只出现了图片序号&#xff0c;如“3”。没有出现“Fig.3 或者Figure.3” 问题2&#xff1a;文中引用的标题和图片下面的标题不一致 1 首先&#xff0c;在导言区添加以下行…

七天学会C语言-第二天(数据结构)

1. If 语句&#xff1a; If 语句是一种条件语句&#xff0c;用于根据条件的真假执行不同的代码块。它的基本形式如下&#xff1a; if (条件) {// 条件为真时执行的代码 } else {// 条件为假时执行的代码 }写一个基础的If语句 #include<stdio.h> int main(){int x 10;…

HarmonyOS开发环境搭建

一 鸿蒙简介&#xff1a; 1.1 HarmonyOS是华为自研的一款分布式操作系统&#xff0c;兼容Android&#xff0c;但又区别Android&#xff0c;不仅仅定位与手机系统。更侧重于万物物联和智能终端&#xff0c;目前已更新到4.0版本。 1.2 HarmonyOS软件编程语言是ArkTS&#xff0c…

STM32DMA原理和应用

目录 1.什么是DMA 2.DMA的意义 3.DMA搬运的数据和方式 4.DMA 控制器和通道 5.DMA通道的优先级 6.DMA传输方式 7.DMA应用 实验一: 内存到内存搬运 CubeMX配置&#xff1a; ​编辑用到的库函数&#xff1a; 代码实现思路&#xff1a; 实验二: 内存到外设搬运 CubeMX…

简单返回封装实体类(RespBean)

RespBean的作用 返回状态码&#xff0c;返回信息&#xff0c;返回数据 package com.example.entity;import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor;Data AllArgsConstructor NoArgsConstructor public class RespBean {private lon…

基于springboot实现的极验校验

概述 在系统业务中&#xff0c;需要想客户发送手机验证码&#xff0c;进行验证后&#xff0c;才能提交。但为了防止不正当的短信发送&#xff08;攻击&#xff0c;恶意操作等&#xff09;&#xff0c;需要在发送短信前添加一个行为验证&#xff08;这里使用的是极验&#xff0…

利用Python将dataframe格式的所有列的数据类型转换为分类数据类型

一、样例理解 import pandas as pd import numpy as np# 创建测试数据 feature_names [col1 , col2, col3, col4, col5, col6] values np.random.randint(20, size(10,6))dataset pd.DataFrame(data values, columns feature_names)print("转换前的数据为\n",d…

【C进阶】指针和数组笔试题解析

做题之前我们先来回顾一下 对于数组名的理解&#xff1a;除了以下两种情况&#xff0c;数组名表示的都是数组首元素的地址 &#xff08;1&#xff09;sizeof&#xff08;数组名&#xff09;&#xff1a;这里的数组名表示整个数组 &#xff08;2&#xff09;&&#xff08;数…

Maven3.6.1下载和详细配置

1.下载maven 说明&#xff1a;以下载maven3.6.1为例 1.1网址 Maven – Welcome to Apache Maven 1.2点击下载 1.3点击Maven 3 archives 1.4 点击相应的版本 1.5 点击binaries下载 说明&#xff1a;binaries是二进制的意思 1.6点击zip格式 1.7 蓝奏云获取 说明&#xff1a…