深度学习基础案例7--马铃薯病识别,对VGG16进行轻量级优化,计算量减少了99%,但是准确率下降4%

news2025/1/18 16:53:15
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

前言

  • 本来想继续优化的,但是我看论文和查阅一些资料,涉及到了知识蒸馏、量化的知识,这些知识我需要花一点时间去研究一下,才能进一步优化。

🍺 实验要求:

  1. 自己搭建VGG-16网络框架
  2. 如何查看模型的参数量以及相关指标

🍻 拔高(可选):

  1. 验证集准确率达到100%
  2. 使用PPT画出VGG-16算法框架图

🔎 探索(难度有点大)

  1. 在不影响准确率的前提下轻量化模型
  • 目前VGG16的Total params是134,272,835

🏩 完成度:

  • 三个任务均已完成
  • 自己搭建试验成果:训练集准确率100%,测试集准确率99.1%
  • 提高:分别对分类层、卷积层进行了优化,但是结果却没有啥变化
  • 探索:全面修改分类层,计算量减少了99%,但是Train_acc:95.8%, Test_acc:94.9%,准确率降低了4%左右

📟 论文参考

[1]方宇伦,陈雪纯,杜世昌,等.基于轻量化深度学习VGG16网络模型的表面缺陷检测方法[J].机械设计与研究,2023,39(02):143-147.DOI:10.13952/j.cnki.jofmdr.2023.0068.

文章目录

  • vgg16算法简介
  • 实验验证
    • 1、数据处理
      • 1、导入库
      • 2、查看文件要分类的类型
      • 3、展示图片
      • 4、导入全部数据与数据处理
      • 5、数据划分
      • 6、动态加载数据
    • 2、VGG16模型的搭建
    • 3、模型训练
      • 1、模型训练搭建
      • 2、模型测试构建
      • 3、设置超参数
      • 4、正式训练
    • 4、结果可视化
    • 5、优化
      • 优化一
      • 优化二
    • 6、轻量级优化
        • 优化前的分类层
          • 参数计算
          • 总参数数量
        • 优化后的分类层
          • 参数计算
          • 总参数数量
        • 轻量化程度
        • 参数减少比例
        • 计算复杂度

vgg16算法简介

vgg16是CNN卷积神经网络的经典架构,它拥有13层卷积,3层池化构成,下图是本人第一次画的神经网络🤠🤠🤠🤠

在这里插入图片描述

实验验证

1、数据处理

1、导入库

import torch 
import torchvision 
import torch.nn as nn
import numpy as np 

device = ('cuda' if torch.cuda.is_available() else 'cpu')
device

输出:

'cuda'

2、查看文件要分类的类型

import os, PIL, pathlib

data_dir = './data/'
data_dir = pathlib.Path(data_dir)   # 转换为 pathlib 类型

# 查看./data/下的文件夹名称
data_path = data_dir.glob('*')
classNames = [str(path).split('/')[1] for path in data_path]
classNames

输出:

['Early_blight', 'healthy', 'Late_blight']

3、展示图片

import matplotlib.pyplot as plt 
from PIL import Image 

# 获取图片路径名,展示早衰图片
data_look_dir = './data/Early_blight/'
data_path_list = [f for f in os.listdir(data_look_dir) if f.endswith(('JPG', 'png'))]

# 创建画板
fig, axes = plt.subplots(2, 8, figsize=(16, 4))

# 图片展示
for ax, img_name in zip(axes.flat, data_path_list):
    path_name = os.path.join(data_look_dir, img_name)
    img = Image.open(path_name)
    ax.imshow(img)
    ax.axis('off')
    
plt.show()


在这里插入图片描述

4、导入全部数据与数据处理

from torchvision import transforms, datasets 

# 图片处理--> 统一
data_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(   # 归一化
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225] 
    )
])

# datasets.ImageFolder(): 专门加载图像分类的API
total_dir = './data/'
total_data = datasets.ImageFolder(root=total_dir, transform=data_transforms)

5、数据划分

train_size = int(len(total_data) * 0.8)
test_size = len(total_data) - train_size

# 随机划分
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])

print(train_dataset)
print(test_dataset)
<torch.utils.data.dataset.Subset object at 0x7f17d1326e80>
<torch.utils.data.dataset.Subset object at 0x7f18e029db20>

6、动态加载数据

batch_size = 32

train_dl = torch.utils.data.DataLoader(train_dataset, 
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=1)

test_dl = torch.utils.data.DataLoader(test_dataset,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=1)
# 展示一个加载后的数据格式
for X, data in test_dl:
    print("shape: ", X.shape)
    print("data: ", data)
    break
shape:  torch.Size([32, 3, 224, 224])
data:  tensor([1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 2, 0, 1, 1, 1, 1, 1,
        0, 0, 1, 1, 0, 1, 2, 0])

2、VGG16模型的搭建

class vgg16_model(nn.Module):
    def __init__(self):
        super(vgg16_model, self).__init__()
        
        # block1
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        
        # block2
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        
        # block3
        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        
        # block4
        self.block4 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        
        # block5
        self.block5 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        
        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(in_features=512 * 7 * 7, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=3)
        )
        
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        
        return x
model = vgg16_model().to(device)
model

输出:

vgg16_model(
  (block1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (block2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (block3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU()
    (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (block4): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU()
    (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (block5): Sequential(
    (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU()
    (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU()
    (2): Linear(in_features=4096, out_features=4096, bias=True)
    (3): ReLU()
    (4): Linear(in_features=4096, out_features=3, bias=True)
  )
)

继续输出模型详细参数:

# 统计模型的参数以及其他指标
import torchsummary as summary
summary.summary(model, (3, 244, 244))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 244, 244]           1,792
       BatchNorm2d-2         [-1, 64, 244, 244]             128
              ReLU-3         [-1, 64, 244, 244]               0
            Conv2d-4         [-1, 64, 244, 244]          36,928
       BatchNorm2d-5         [-1, 64, 244, 244]             128
              ReLU-6         [-1, 64, 244, 244]               0
         MaxPool2d-7         [-1, 64, 122, 122]               0
            Conv2d-8        [-1, 128, 122, 122]          73,856
              ReLU-9        [-1, 128, 122, 122]               0
           Conv2d-10        [-1, 128, 122, 122]         147,584
             ReLU-11        [-1, 128, 122, 122]               0
        MaxPool2d-12          [-1, 128, 61, 61]               0
           Conv2d-13          [-1, 256, 61, 61]         295,168
             ReLU-14          [-1, 256, 61, 61]               0
           Conv2d-15          [-1, 256, 61, 61]         590,080
             ReLU-16          [-1, 256, 61, 61]               0
           Conv2d-17          [-1, 256, 61, 61]         590,080
             ReLU-18          [-1, 256, 61, 61]               0
        MaxPool2d-19          [-1, 256, 30, 30]               0
           Conv2d-20          [-1, 512, 30, 30]       1,180,160
             ReLU-21          [-1, 512, 30, 30]               0
           Conv2d-22          [-1, 512, 30, 30]       2,359,808
             ReLU-23          [-1, 512, 30, 30]               0
           Conv2d-24          [-1, 512, 30, 30]       2,359,808
             ReLU-25          [-1, 512, 30, 30]               0
        MaxPool2d-26          [-1, 512, 15, 15]               0
           Conv2d-27          [-1, 512, 15, 15]       2,359,808
             ReLU-28          [-1, 512, 15, 15]               0
           Conv2d-29          [-1, 512, 15, 15]       2,359,808
             ReLU-30          [-1, 512, 15, 15]               0
           Conv2d-31          [-1, 512, 15, 15]       2,359,808
             ReLU-32          [-1, 512, 15, 15]               0
        MaxPool2d-33            [-1, 512, 7, 7]               0
           Linear-34                 [-1, 4096]     102,764,544
             ReLU-35                 [-1, 4096]               0
           Linear-36                 [-1, 4096]      16,781,312
             ReLU-37                 [-1, 4096]               0
           Linear-38                    [-1, 3]          12,291
================================================================
Total params: 134,273,091
Trainable params: 134,273,091
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.68
Forward/backward pass size (MB): 316.39
Params size (MB): 512.21
Estimated Total Size (MB): 829.28
----------------------------------------------------------------

3、模型训练

1、模型训练搭建

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    batch_size = len(dataloader)
    
    train_acc, train_loss = 0, 0 
    
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        
        # 训练
        pred = model(X)
        loss = loss_fn(pred, y)
        
        # 梯度下降法
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 记录
        train_loss += loss.item()
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        
    train_acc /= size
    train_loss /= batch_size
    
    return train_acc, train_loss

2、模型测试构建

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    batch_size = len(dataloader)
    
    test_acc, test_loss = 0, 0 
    
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
        
            pred = model(X)
            loss = loss_fn(pred, y)
        
            test_loss += loss.item()
            test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        
    test_acc /= size
    test_loss /= batch_size
    
    return test_acc, test_loss

3、设置超参数

loss_fn = nn.CrossEntropyLoss()  # 损失函数     
learn_lr = 1e-4             # 超参数
optimizer = torch.optim.Adam(model.parameters(), lr=learn_lr)   # 优化器

4、正式训练

train_acc = []
train_loss = []
test_acc = []
test_loss = []

epoches = 40

for i in range(epoches):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    # 输出
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}')
    print(template.format(i + 1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
    
print("Done")
Epoch: 1, Train_acc:45.4%, Train_loss:0.927, Test_acc:51.7%, Test_loss:0.909
Epoch: 2, Train_acc:63.6%, Train_loss:0.744, Test_acc:85.4%, Test_loss:0.491
Epoch: 3, Train_acc:83.4%, Train_loss:0.456, Test_acc:82.8%, Test_loss:0.423
Epoch: 4, Train_acc:87.3%, Train_loss:0.340, Test_acc:90.3%, Test_loss:0.276
Epoch: 5, Train_acc:89.8%, Train_loss:0.265, Test_acc:87.9%, Test_loss:0.294
Epoch: 6, Train_acc:89.8%, Train_loss:0.260, Test_acc:95.6%, Test_loss:0.136
Epoch: 7, Train_acc:95.4%, Train_loss:0.122, Test_acc:97.4%, Test_loss:0.087
Epoch: 8, Train_acc:95.9%, Train_loss:0.120, Test_acc:94.0%, Test_loss:0.164
Epoch: 9, Train_acc:97.4%, Train_loss:0.080, Test_acc:94.7%, Test_loss:0.117
Epoch:10, Train_acc:96.2%, Train_loss:0.102, Test_acc:97.2%, Test_loss:0.082
Epoch:11, Train_acc:95.0%, Train_loss:0.134, Test_acc:97.9%, Test_loss:0.083
Epoch:12, Train_acc:96.6%, Train_loss:0.093, Test_acc:96.8%, Test_loss:0.069
Epoch:13, Train_acc:98.2%, Train_loss:0.057, Test_acc:97.9%, Test_loss:0.063
Epoch:14, Train_acc:98.4%, Train_loss:0.059, Test_acc:98.4%, Test_loss:0.047
Epoch:15, Train_acc:98.4%, Train_loss:0.043, Test_acc:97.4%, Test_loss:0.058
Epoch:16, Train_acc:98.0%, Train_loss:0.059, Test_acc:95.4%, Test_loss:0.107
Epoch:17, Train_acc:97.5%, Train_loss:0.065, Test_acc:96.3%, Test_loss:0.088
Epoch:18, Train_acc:98.8%, Train_loss:0.037, Test_acc:98.1%, Test_loss:0.049
Epoch:19, Train_acc:98.6%, Train_loss:0.037, Test_acc:98.1%, Test_loss:0.049
Epoch:20, Train_acc:98.4%, Train_loss:0.043, Test_acc:97.9%, Test_loss:0.065
Epoch:21, Train_acc:99.5%, Train_loss:0.010, Test_acc:94.2%, Test_loss:0.300
Epoch:22, Train_acc:97.9%, Train_loss:0.069, Test_acc:93.7%, Test_loss:0.220
Epoch:23, Train_acc:99.4%, Train_loss:0.025, Test_acc:97.9%, Test_loss:0.085
Epoch:24, Train_acc:99.4%, Train_loss:0.019, Test_acc:97.0%, Test_loss:0.090
Epoch:25, Train_acc:98.8%, Train_loss:0.033, Test_acc:98.4%, Test_loss:0.060
Epoch:26, Train_acc:99.4%, Train_loss:0.018, Test_acc:97.0%, Test_loss:0.125
Epoch:27, Train_acc:98.9%, Train_loss:0.031, Test_acc:97.7%, Test_loss:0.091
Epoch:28, Train_acc:99.8%, Train_loss:0.005, Test_acc:97.4%, Test_loss:0.080
Epoch:29, Train_acc:99.8%, Train_loss:0.006, Test_acc:94.4%, Test_loss:0.360
Epoch:30, Train_acc:97.7%, Train_loss:0.058, Test_acc:97.9%, Test_loss:0.050
Epoch:31, Train_acc:98.6%, Train_loss:0.034, Test_acc:98.1%, Test_loss:0.065
Epoch:32, Train_acc:99.7%, Train_loss:0.014, Test_acc:98.1%, Test_loss:0.056
Epoch:33, Train_acc:99.4%, Train_loss:0.013, Test_acc:99.3%, Test_loss:0.047
Epoch:34, Train_acc:99.5%, Train_loss:0.020, Test_acc:98.6%, Test_loss:0.069
Epoch:35, Train_acc:99.4%, Train_loss:0.021, Test_acc:98.4%, Test_loss:0.055
Epoch:36, Train_acc:99.7%, Train_loss:0.010, Test_acc:98.1%, Test_loss:0.080
Epoch:37, Train_acc:99.7%, Train_loss:0.010, Test_acc:98.4%, Test_loss:0.059
Epoch:38, Train_acc:99.3%, Train_loss:0.017, Test_acc:98.4%, Test_loss:0.055
Epoch:39, Train_acc:99.3%, Train_loss:0.017, Test_acc:98.8%, Test_loss:0.043
Epoch:40, Train_acc:99.9%, Train_loss:0.004, Test_acc:98.8%, Test_loss:0.038
Done

我第一次、第二次跑的训练集准确率达到了100%,这是后面跑的,也正常,因为深度学习每一次更新的梯度、权重、偏置等都不一样。

4、结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息

epochs_range = range(epoches)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training= Loss')
plt.show()


在这里插入图片描述

  • 但从效果来看,这个以及够了,极好了,没什么必要提高了,但是从我初学的角度上来看,提高优化以下还是很有必要的

5、优化

两次优化都没啥提升,大差不差,有些反而提高了计算量

优化一

在每一层卷积层后添加BN层,提高稳定性,结果如下:

在这里插入图片描述

优化二

在分类层中,降低分类层数量,512*7*7 —> 1024 —> 512 --> 3,添加Dropout层,这个思路效果大差不大

6、轻量级优化

优化:全面修改全连接层,首先添加平均池化层,后添加Dropout层、Linear(512, 3),本来还想给添加L1正则化的,但是效果极差。

结果Train_acc:95.8%, Test_acc:94.9%但是,参数数量减少了约 99.9955%,效果极好

在这里插入图片描述

分析过程:

通过优化 VGG16 的分类层,您显著减少了模型的参数数量和计算复杂度。下面是对优化前后模型的对比分析:

优化前的分类层
self.classifier = nn.Sequential(
    nn.Linear(in_features=512 * 7 * 7, out_features=4096),
    nn.ReLU(),
    nn.Linear(in_features=4096, out_features=4096),
    nn.ReLU(),
    nn.Linear(in_features=4096, out_features=3)
)
参数计算
  1. 第一层nn.Linear(in_features=512 * 7 * 7, out_features=4096)
    • 输入大小:512 * 7 * 7 = 25088
    • 输出大小:4096
    • 参数数量:25088 * 4096 + 4096 = 103219200 + 4096 = 103223296
  2. 第二层nn.Linear(in_features=4096, out_features=4096)
    • 输入大小:4096
    • 输出大小:4096
    • 参数数量:4096 * 4096 + 4096 = 16781312 + 4096 = 16785408
  3. 第三层nn.Linear(in_features=4096, out_features=3)
    • 输入大小:4096
    • 输出大小:3
    • 参数数量:4096 * 3 + 3 = 12288 + 3 = 12291
总参数数量
  • 总参数数量:103223296 + 16785408 + 12291 = 120020995
优化后的分类层
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

self.classifier = nn.Sequential(
    nn.Dropout(),
    nn.Linear(512, 3)
)
参数计算
  1. 全局平均池化层nn.AdaptiveAvgPool2d((1, 1))
    • 不增加参数,只是将特征图的大小从 7x7 压缩到 1x1。
  2. 线性层nn.Linear(512, 3)
    • 输入大小:512
    • 输出大小:3
    • 参数数量:512 * 3 + 3 = 1536 + 3 = 1539
总参数数量
  • 总参数数量:1539
轻量化程度
  • 优化前的参数数量:120020995
  • 优化后的参数数量:1539
参数减少比例
  • 参数减少比例:(120020995 - 1539) / 120020995 ≈ 0.999955
  • 即优化后的模型参数数量减少了约 99.9955%。
计算复杂度
  • 优化前:涉及多个全连接层,计算量较大。
  • 优化后:只涉及一个全连接层,计算量显著减少。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2165513.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

4. 数据结构: 对象和数组

数字、布尔值和字符串是构建数据结构的原子。不过&#xff0c;许多类型的信息需要不止一个原子。对象允许我们对值&#xff08;包括其他对象&#xff09;进行分组&#xff0c;从而构建更复杂的结构。到目前为止&#xff0c;我们所构建的程序都受到限制&#xff0c;因为它们只能…

MyBatis 入门教程-搭建入门工程

Maven作为一个优秀的项目构建和管理工具,在日常的开发中被大多数开发者使用,后续的项目也是基于Maven来构建。 创建一个Maven项目 利用IDEA创建项目工具来创建一个Maven项目 添加MyBatis的依赖 这里可以从Maven仓库地址中进行查看, https://mvnrepository.com/ 从这里可…

SUB1G无线通信模块赋能对讲机无线联网

一、模组介绍&#xff1a; ANS TKM-220是一款专为LPWAN物联网应用而研制的SUB1G无线模组&#xff0c;使用全新的TurMassTM 技术&#xff0c;具有超大容量 、高速率 、广覆盖和低成本的特点&#xff0c;处于国际领先水平 。 二、模组特点&#xff1a; ◉ 采用独创的TurMass™…

Oracle 19c 使用EMCC 监控当前所有数据库

一.EMCC简介 EMCC&#xff0c;全称Oracle Enterprise Manager Cloud Control&#xff0c;是Oracle提供的一套集中化监控工具&#xff0c;可以对数据库、操作系统、中间件等进行监控&#xff0c;通过OMS&#xff08;Oracle Management Service&#xff09;收集监控数据并将监控信…

Golang | Leetcode Golang题解之第421题数组中两个数的最大异或值

题目&#xff1a; 题解&#xff1a; const highBit 30type trie struct {left, right *trie }func (t *trie) add(num int) {cur : tfor i : highBit; i > 0; i-- {bit : num >> i & 1if bit 0 {if cur.left nil {cur.left &trie{}}cur cur.left} else …

C# 数据校验与控件绑定

在上一篇中&#xff0c;写了使用特性对一个对象的值进行校验&#xff1b;虽然代码比较简单&#xff0c;但依然不是最优解&#xff0c;在做数据新增校验的时候&#xff0c;倒也没什么问题&#xff0c;毕竟这是WinForm&#xff1b;但是如果是做数据编辑&#xff0c;代码就会变得更…

遗忘的数学(拉格朗日乘子法、牛顿法)

目录 拉格朗日乘子法定理 证明&#xff1a;​编辑 应用条件与符号选择 雅可比矩阵 黑塞矩阵 牛顿法 解方程的根的牛顿法 解方程组的根的牛顿法 数值优化的牛顿法&#xff08;求最值&#xff09; 拉格朗日乘子法定理 证明&#xff1a; dSi这一段没看懂…… 应用…

“AI+Security”系列第3期(四):360安全大模型业务实践

近日&#xff0c;由安全极客、Wisemodel 社区、InForSec 网络安全研究国际学术论坛和海升集团联合主办的“AI Security”系列第 3 期技术沙龙&#xff1a;“AI 安全智能体&#xff0c;重塑安全团队工作范式”活动顺利举行。此次活动吸引了线上线下超过千名观众参与。 活动中&…

C++——关联式容器(5):哈希表

7.哈希表 7.1 哈希表引入 哈希表的出现依旧是为了查找方便而设计的。在顺序结构中&#xff0c;查询一个值需要一一比较&#xff0c;复杂度为O(N)&#xff1b;在平衡树中&#xff0c;查询变为了二分查找&#xff0c;复杂度为O(logN)&#xff1b;而对于哈希表&#xff0c;我们可…

BST-二叉搜索树

前言 从图的角度出发&#xff0c;树是一种特殊的图。图的大多数算法&#xff0c;树都可以适用。对树操作中&#xff0c;你可以发现有关图算法思想的体现。 不过&#xff0c; 本篇不是完全从图的角度解读树&#xff0c; 重点在初学者视角&#xff08;一般学习数据结构顺序是从树…

码点和码元的区别--Unicode标准的【码点】和【码元】

Unicode是通用字符编码标准是计算机科学领域里的一项业界标准&#xff0c;包括字符集、编码方案等。 Unicode标准定义了一个统一的多语言文本字符集&#xff08;即Unicode字符集&#xff09;。 Unicode标准定义了三种字符编码方案&#xff1a;UTF-8、UTF-16、UTF-32。 因此&…

【Java面向对象高级06】static的应用知识:代码块

文章目录 前言一、代码块概述二、代码块分2种 1、静态代码块2、实例代码块总结 前言 记录static的应用知识&#xff1a;代码块 一、代码块概述 代码块是类的5大成分之一&#xff08;成员变量&#xff0c;构造器&#xff0c;方法&#xff0c;代码块&#xff0c;内部类&#xf…

「Python教程」vscode的安装和python插件下载

粗浅之言&#xff0c;如有错误&#xff0c;欢迎指正 文章目录 前言Python安装VSCode介绍VSCode下载安装安装python插件 前言 Python目前的主流编辑器有多个&#xff0c;例如 Sublime Text、VSCode、Pycharm、IDLE(安装python时自带的) 等。个人认为 vscode 虽然在大型项目上有…

一个好用的MP3音乐下载网,我推荐给你(免费)

点击访问->https://www.gequbao.com/ 或用Bing搜索歌曲宝即可。 主页面长这样子~ 以最近大火的悲鸣海为例&#xff0c;搜索&#xff1b; 以第一个为例&#xff0c;点击&#xff1b; 它既支持下载.mp3格式的音乐文件&#xff0c;还支持下载.lrc的歌词文件。 非常好用&…

使用ChatGPT引导批判性思维,提升论文的逻辑与说服力的全过程

学境思源&#xff0c;一键生成论文初稿&#xff1a; AcademicIdeas - 学境思源AI论文写作 批判性分析&#xff08;Critical Analysis&#xff09; 是论文写作中提升质量和说服力的重要工具。它不仅帮助作者深入理解和评价已有研究&#xff0c;还能指导作者在构建自己论点时更加…

网络工程师学习笔记——网络互连与互联网(三)

TCP三次握手 建立TCP连接是通过三次握手实现的&#xff0c;采用三报文握手主要是为了防止已失效的连接请求报文突然又传送到了&#xff0c;因而产生错误 主动发起TCP连接建立的称为客户端 被动等待的为TCP服务器&#xff0c;二者之间需要交换三个TCP报文段 首先是客户端主动…

jQuery——对象的使用

1、理解&#xff1a;即执行 jQuery 核心函数返回的对象 2、jQuery 对象内部包含的是 dom 元素对象的伪数组&#xff08;可能只有一个元素&#xff09; 3、jQuery 对象是一个包含所有匹配的任意多个 dom 元素的伪数组对象 4、基本行为&#xff1a; ① size&#xff08;&#xf…

Java_Se 数组与数据的存储

数组是相同类型数据的有序集合。其中&#xff0c;每一个数据称作一个元素&#xff0c;每个元素可以通过一个索引&#xff08;下标&#xff09;来访问它们。 数组的四个基本特点&#xff1a; 1.长度是确定的。数组一旦被创建&#xff0c;它的大小就是不可以改变的。 2.其元素…

【Java 问题】基础——面相对象

面向对象 15. 面向对象和面向过程的区别&#xff1f;16. 面向对象的基本特征17.重载&#xff08;overload&#xff09;和重写&#xff08;override&#xff09;的区别&#xff1f;18.访问修饰符public、private、protected、以及不写&#xff08;默认&#xff09;时的区别&…

2024低代码大赛火热进行,豪礼抢先看~

2024 网易低代码大赛正火热进行中&#xff0c;其中“网易云信低代码”专区吸引了众多开发者参与。 通过低代码高效、灵活的应用构建方式&#xff0c;结合云信的即时通讯和音视频能力&#xff0c;开发者既可以大幅缩短开发周期&#xff0c;还能提升应用的互动性和用户体验。 为…