霹雳吧啦Wz《pytorch图像分类》-p2AlexNet网络

news2024/12/22 20:38:52

《pytorch图像分类》p2AlexNet网络基础及代码

  • 一、零碎知识点
    • 1.过拟合
    • 2.使用dropout后的正向传播
    • 3.正则化regularization
    • 4.代码中所用的知识点
  • 二、总体架构分析
    • 1.ReLU激活函数
    • 2.手算
    • 3.模型代码
  • 三、训练花分类课程代码
    • 1.model.py
    • 2.train.py
    • 3.predict.py

一、零碎知识点

1.过拟合

模型假设过于复杂,参数过多,训练数据过少,噪声过多,导致拟合的函数完美的预测训练集,但对新数据的测试集预测结果差。 过度的拟合了训练数据,而没有考虑到泛化能力。
举个栗子:当我们在学习一门新的学科时会做一些例题,我们把这些例题完完整整背下来,但是一旦给出新的题目,还是不会做,这就是学习没有泛化能力,只记住了例题的细节而忽视了更普遍的规律。

2.使用dropout后的正向传播

dropout会在每一层当中随机失活一部分神经元,从而减少了神经元之间的共适应性,防止过拟合。
nn.Dropout(p= )p代表的是随机失活的比例,默认p=0.5

3.正则化regularization

是一种通过添加额外的约束或惩罚项来控制模型的复杂度的技术,其目的也是防止过拟合。
假设我们的模型是一个二次多项式,我们的目标是最小化损失函数(均方误差MSE),让预测的曲线与真实值尽可能接近。
L2正则化将惩罚项加到损失函数中,使用权重的平方和乘以一个正则化系数λ。
带正则化的损失函数为:
L o s s = M S E + λ ∗ ∣ ∣ w ∣ ∣ 2 Loss = MSE + λ * ||w||^2 Loss=MSE+λ∣∣w2
很抽象,以后再深入学习。

4.代码中所用的知识点

  1. 设备设置
    如果有可以使用的gpu设备,默认使用第一个,没有的话即使用cpu。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  1. RandomResizedCrop随机裁剪
    将裁剪后的图像调整为指定的大小(224x224)
transforms.RandomResizedCrop(224)
  1. root=“. ./. .”
    返回上上层目录

  2. CrossEntropyLoss交叉熵损失函数
    它是PyTorch中的一个损失函数,常用于多分类问题的训练中。它结合了Softmax激活函数和负对数似然损失(NLLLoss)
    详情请见我之前写的博客:多分类问题

  3. net.train( )和net.eval( )
    当调用net.train()时,模型将被设置为训练模式。在训练模式下,模型会启用一些特定的操作,如Batch Normalization归一化处理和Dropout随机失活一些神经元,防止过拟合。而调用net.eval()时,模型将被设置为评估(evaluate)模式。

二、总体架构分析

1.ReLU激活函数

ReLU(Rectified Linear Unit)线性整流函数,又称修正线性单元。在PyTorch中,可以使用torch.nn.ReLU类来表示。
其公式为:
f ( x ) = M a x ( 0 , x ) f(x)=Max(0,x) f(x)=Max(0,x)
它将小于零的输入映射为0,大于等于零的输入保持不变,求导简单方便。
我总结了这些天学习的几个激活函数:激活函数总结

2.手算

在这里插入图片描述

3.模型代码

nn.sequential将一系列的层结构打包组合成一个新的结构
classifer包括了三个全连接层
在这里插入图片描述
1.Conv1 第一个卷积层
代码表示为: nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding)
彩色图像有rgb三个通道,所以输入通道数为3,输出通道数=卷积核的个数kernel_num
为了方便训练,卷积核个数只取一半,padding直接取2,训练效果与原本的影响不大

 nn.Conv2d(3,48,kernel_size=11,stride=4,padding=2)

2.Maxpool1 第一个池化层
代码表示为:nn.MaxPool2d(kernel_size,stride)

nn.MaxPool2d(kernel_size=3,stride=2)

3.后续代码

nn.Conv2d(48,128,kernel_size=5,padding=2)
nn.MaxPool2d(kernel_size=3,stride=2)
nn.Conv2d(128,192,kernel_size=3,padding=1)
nn.Conv2d(192,192,kernel_size=3,padding=1)
nn.Conv2d(192,128,kernel_size=3,padding=1)
nn.MaxPool2d(kernel_size=3,stride=2)

三、训练花分类课程代码

1.model.py

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

2.train.py

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=True,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    net = AlexNet(num_classes=5, init_weights=True)

    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

3.predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "1..jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)

    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = AlexNet(num_classes=5).to(device)

    # load model weights
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

1.jpg是郁金香
在这里插入图片描述
2.jpg是向日葵
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1352737.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode刷题---旋转图像

解题思路: 首先对主对角线两边的元素进行交换 接着走一轮遍历,将第1列和第n列进行交换,第2列和第n-1列进行交换,直至得到最终的矩阵。 代码实现: public void rotate(int[][] matrix) {//首先对主对角线的元素进行交换…

企业如何做好客户管理?有哪些关键因素?

客户管理是建立和维护客户关系的重要组成部分,对于企业的发展至关重要。下面就让我们来看看在做好客户管理时有哪些关键因素吧。 第一个关键因素是提供优质的客户服务。无论是线上还是线下,当客户需要帮助时,他们希望能够得到有效且及时的支持…

Spring ApplicationEvent事件处理

Spring的事件 ApplicationEvent以及Listener是Spring为我们提供的一个事件监听、订阅的实现,内部实现原理是观察者设计模式,设计初衷也是为了系统业务逻辑之间的解耦,提高可扩展性以及可维护性。 ApplicationEvent就是Spring的事件接口Applic…

南昌找工作用什么APP或者招聘网站

南昌找工作用吉鹿力招聘网 通过吉鹿力招聘网,可以随时查看最新职位,跟踪简历投递动态,与正在进行招聘的CEO、部门负责人、HR在线沟通,查看其他候选人面试该职位后对面试官、公司环境的面试评价等,为求职者提供参考。 …

阿里云2核2G3M服务器能放几个网站?有限制吗?

阿里云2核2g3m服务器可以放几个网站?12个网站,阿里云服务器网的2核2G服务器上安装了12个网站,甚至还可以更多,具体放几个网站取决于网站的访客数量,像阿里云服务器网aliyunfuwuqi.com小编的网站日访问量都很少&#xf…

2023年度最热 AI 应用 TOP 50,除了 ChatGPT 还有这么多宝藏

原文章链接:年度最热 AI 应用 TOP 50,除了 ChatGPT 还有这么多宝藏 - IT之家 更多消息:AI人工智能行业动态,aigc应用领域资讯 在 AI 工具激烈竞争的一年中,尽管ChatGPT在访问量上遥遥领先,但单次使用时长未…

扫码出入库让仓库管理更加智能化

一、扫码出入库管理的概念 扫码出入库管理系统是一种基于条码技术的仓库管理系统。它通过扫描条码的方式,实现货物的自动化入库、存储、打包和出库。条码技术是一种利用光学识别原理对字符进行识别的技术,具有高效、准确、便捷的特点。 二、扫码出入库能…

B端产品经理学习-需求挖掘

B端产品需求挖掘 目录 识别和管理干系人 决策人和负责人需求挖掘 针对用户进行需求挖掘 用户访谈结果整理 B端产品的需求来源是非常复杂的,要考虑多个方面;如果你是一个通用性的产品,要考虑市场、自身优劣势、干系人。而定制型B端产品会…

react useEffect 内存泄漏

componentWillUnmount() {this.setState (state, callback) > {return;};// 清除reactionthis.reaction();}useEffect 使用AbortController useEffect(() > { let abortController new AbortController(); // your async action is here return () > { abortCo…

2024年阿里云优惠券领取及使用教程

阿里云作为国内领先的云计算服务提供商,一直致力于为客户提供优质、高效的服务。为了更好地回馈客户,阿里云经常会推出各种优惠活动,其中就包括阿里云优惠券。本文将详细介绍如何领取及使用阿里云优惠券。 一、阿里云优惠券介绍 阿里云优惠券…

勒索检测能力升级,亚信安全发布《勒索家族和勒索事件监控报告》

评论员简评 近期(12.08-12.14)共发生勒索事件119起,相较之前呈现持平趋势。 与上周相比,本周仍然流行的勒索家族为lockbit3和8base。在涉及的勒索家族中,活跃程度Top5的勒索家族分别是:lockbit3、siegedsec、dragonforce、8base和…

RocketMQ5.0顺序消息设计实现

前言 顺序消息是 RocketMQ 提供的一种高级消息类型,支持消费者按照发送消息的先后顺序获取消息,从而实现业务场景中的顺序处理。 顺序消息的顺序关系通过消息组(MessageGroup)判定和识别,发送顺序消息时需要为每条消息…

whl is not a supported wheel on this platform.解决办法

1.问题: 安装torch产生 2.解决办法: 使用pip debug --verbose查看 对应的torch版本号 Compatible tags字样,这些就是当前Python版本可以适配的标签。例如,我的Python版本是3.11,可以匹配下面这些文件名:…

YOLOv8改进 更换轻量化模型MobileNetV3

一、MobileNetV3论文 论文地址:1905.02244.pdf (arxiv.org) 二、 MobileNetV3网络结构 MobileNetV3引入了一种新的操作单元,称为"Mobile Inverted Residual Bottleneck",它由一个1x1卷积层和一个3x3深度可分离卷积层组成。这个操…

简易电子琴

#include<reg51.h> //包含51单片机寄存器定义的头文件 sbit P14P1^4; //将P14位定义为P1.4引脚 sbit P15P1^5; //将P15位定义为P1.5引脚 sbit P16P1^6; //将P16位定义为P1.6引脚 sbit P17P1^7; //将P17位定义为P1.7引脚 unsigned char keyval; …

Sam Altman的一天被曝光!每天15小时禁食、服用小剂量安眠药,尽可能避免开会

Sam Altman在经历了几天混乱的管理重组后&#xff0c;重新回到了OpenAI的CEO位置。在日常生活中&#xff0c;奥特曼与许多科技行业高管一样&#xff0c;痴迷于延长自己的寿命。 据报道&#xff0c;他还为应对末日场景&#xff08;致命合成病毒的释放、核战争和人工智能攻击等&…

学习Go语言Web框架Gee总结--前缀树路由Router(三)

学习Go语言Web框架Gee总结--前缀树路由Router router/gee/trie.gorouter/gee/router.gorouter/gee/context.gorouter/main.go 学习网站来源&#xff1a;Gee 项目目录结构&#xff1a; router/gee/trie.go 实现动态路由最常用的数据结构&#xff0c;被称为前缀树(Trie树) 关…

用Audio2Face驱动UE - MetaHuman

新的一年咯&#xff0c;很久没发博客了&#xff0c;就发两篇最近的研究吧。 开始之前说句话&#xff0c;别轻易保存任何内容&#xff0c;尤其是程序员不要轻易Ctrl S 在UE中配置Audio2Face 先检查自身电脑配置看是否满足&#xff0c;按最小配置再带个UE可能会随时崩&#x…

类的生命周期/加载

理一下&#xff0c;java 编译后的字节码文件&#xff0c;我们已经熟悉了 字节码文件长什么样&#xff0c;字节码文件中有哪些内容&#xff0c;那么下一步就是使用类加载器 把字节码文件加载到JVM中 类的生命周期 类的生命周期是 JVM类加载的基础。 加载 所谓加载&#xff0c;…

贝叶斯推断:细谈贝叶斯变分和贝叶斯网络

1. 贝叶斯推断 统计推断这件事大家并不陌生&#xff0c;如果有一些采样数据&#xff0c;我们就可以去建立模型&#xff0c;建立模型之后&#xff0c;我们通过对这个模型的分析会得到一些结论&#xff0c;不管我们得到的结论是什么样的结论&#xff0c;我们都可以称之为是某种推…