分类网络搭建示例

news2024/12/25 22:35:37

搭建CNN网络

本章我们来学习一下如何搭建网络,初始化方法,模型的保存,预训练模型的加载方法。本专栏需要搭建的是对分类性能的测试,所以这里我们只以VGG为例。

请注意,这里定义的只是一个简陋的版本,后续一些经典网络的学习,我们会在另外单独去开一个专栏讲解。

1. 网络搭建

在PyTorch中,你可以使用 torchvision.models 中的 vgg16 来加载预定义的VGG16模型,也可以手动定义。以下是手动定义的一个简化版本:

import torch
import torch.nn as nn

class VGG16(nn.Module):
    def __init__(self, num_classes=1000):
        super(VGG16, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

2. 初始化方法

在这里,我们不再手动初始化每一层,因为PyTorch的默认初始化通常足够好。你可以选择手动初始化,如果需要,可以使用 torch.nn.init 中的不同方法。

3. 模型的保存

使用 torch.save 保存VGG16模型:

vgg16 = VGG16()

torch.save(vgg16.state_dict(), 'vgg16_model.pth')

4. 预训练模型的加载

要加载预训练的VGG16模型,你可以使用 torchvision.models 中的 vgg16(pretrained=True),或者手动加载预训练权重:

vgg16 = VGG16()

vgg16.load_state_dict(torch.load('pretrained_vgg16.pth'))

请确保路径 'pretrained_vgg16.pth' 是你预训练模型文件的实际路径。你可以从PyTorch的官方模型库或其他来源下载预训练权重。

上面是最简单的一种模型全部加载的方式,但也有一些情况下,只是想加载其中一部分层的参数。剩下一部分由于已经改变参数了,无法加载预训练模型,所以要选择随机初始化。 、

这里我们来观察网络怎么去表示的:

if __name__ == "__main__":
    model = VGG16()
    for name, value in model.named_parameters():
        print(name)

下面就是控制台打印出的部分信息。 

这两行的输出就是打印网络层的名字,实际上加载预训练模型时,也是按照这个名字来加载的。

# 加载预训练 VGG16 模型的参数
pretrained_dict = torch.load('pretrained_vgg16.pth')

# 剔除预训练模型中全连接层的参数
pretrained_dict.pop('classifier.0.weight')
pretrained_dict.pop('classifier.0.bias')
pretrained_dict.pop('classifier.3.weight')
pretrained_dict.pop('classifier.3.bias')
pretrained_dict.pop('classifier.6.weight')
pretrained_dict.pop('classifier.6.bias')

# 获取自定义模型的参数字典
model_dict = model.state_dict()

# 更新自定义模型的参数字典,加载预训练模型的参数值
model_dict.update(pretrained_dict)

# 加载更新后的参数字典到自定义模型中
model.load_state_dict(model_dict)

自己定义的一些层是不会出现在pretrained_dict中,因此会将其剔除,从而只加载了 pretrained_dict中有的层。

总结

本章只是对网络的定义进行一个简单的示例,具体的部分我们会在另外一个专栏讲解,这里只是为了让读者了解网络定义的流程。在实际项目中,通常需要更详细的网络结构,包括适当的初始化方法、损失函数的选择、优化器的设置等。如果读者了解掌握了基本的网络定义过程,你可以在本专栏中深入讲解这些方面,以及如何训练和评估模型等内容。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1199622.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于猕猴感觉运动皮层的神经元Spike信号分析

公开数据集中文版详细描述参考前文:https://editor.csdn.net/md/?not_checkout1&spm1011.2124.3001.6192 目录 0. 公开数据集1. 神经元的raster和PSTH图1.1 Raster1.2 PSTH 2. 运动轨迹图 (center_out)3. 神经元的运动调制曲线 (tuning curve) 0. 公开数据集 …

Leetcode100120. 找出强数对的最大异或值 I

Every day a Leetcode 题目来源:100120. 找出强数对的最大异或值 I 解法1:模拟 枚举 2 遍数组 nums 的元素,更新最大异或值。 代码: /** lc appleetcode.cn id100120 langcpp** [100120] 找出强数对的最大异或值 I*/// lc c…

火爆进行中的抖音双11好物节,巨量引擎助5大行业商家开启爆单之路!

抖音双11好物节目前正在火热进行中,进入爆发期,各大商家“好招”频出,都想要实现高速增长。依托“人群、货品、流量”三大优势,巨量引擎一直都是商家生意增长的给力伙伴,在今年的抖音双11好物节,巨量引擎就…

SparkSQL之Catelog体系

按照SQL标准的解释,在SQL环境下Catalog和Schema都属于抽象概念。在关系数据库中,Catalog是一个宽泛的概念,通常可以理解为一个容器或数据库对象命名空间中的一个层次,主要用来解决命名冲突等问题。 在Spark SQL系统中,…

Django基础介绍及HTTP请求

文章目录 Django框架的介绍Django的安装 Django框架开发创建项目的指令Django项目的目录结构URL 介绍视图函数(view)Django 中的路由配置带有分组的路由和视图函数带有命名分组的路由和视图函数 HTTP协议的请求和响应HTTP 请求HTTP 响应GET方式传参POST传递参数form 表单的name…

泉峰控股发布业务白皮书, 释放中国芯片企业要发展成全球领先的信心与决心

近日,多元化芯片全球供应商泉峰控股发布了一份题为《致力于中国芯片产业自立自强》的业务白皮书。白皮书系统阐述了泉峰控股企业发展策略和业务规划,充分体现出中国芯片企业要在全球范围内实现技术突破、市场扩张的信心与决心。 白皮书首先分析了当前全球芯片产业的发展态势。在…

leetcode(力扣) 51. N 皇后 (回溯,纸老虎题)

文章目录 题目描述思路分析对于问题1对于问题2 完整代码 题目描述 按照国际象棋的规则,皇后可以攻击与之处在同一行或同一列或同一斜线上的棋子。 n 皇后问题 研究的是如何将 n 个皇后放置在 nn 的棋盘上,并且使皇后彼此之间不能相互攻击。 给你一个整数…

业务出海之服务器探秘

这几年随着国内互联网市场的逐渐饱和,越来越多的公司加入到出海的行列,很多领域都取得了很不错的成就。虽然出海可以获得更加广阔的市场,但也需要面对很多之前在国内可能没有重视的一些问题。集中在海外服务器的选择维度上就有很大的变化。例…

rocksdb中测试工具Benchmark.sh用法(基准、性能测试)

1.首先要安装db_bench工具,这个工具在成功安装rocksdb之后就自动存在了,主要是在使用make命令之后就成功安装了,详情请见我之前的文章 2.确保成功安装db_bench之后,找到安装的rocksdb目录下面的tools文件夹,查看里面是…

怎么改变容易紧张的性格?

容易紧张的性格是比较通俗的说法,在艾森克人格测试中,容易紧张的性格就属于神经症人格,神经质不是神-经-病,而是一种人格特征,这种特征包括:敏感,情绪不稳定,易焦虑和紧张。有兴趣的…

(一)七种元启发算法(DBO、LO、SWO、COA、LSO、KOA、GRO)求解无人机路径规划MATLAB

一、七种算法(DBO、LO、SWO、COA、LSO、KOA、GRO)简介 1、蜣螂优化算法DBO 蜣螂优化算法(Dung beetle optimizer,DBO)由Jiankai Xue和Bo Shen于2022年提出,该算法主要受蜣螂的滚球、跳舞、觅食、偷窃和繁…

SpringBoot 监控

概述 SpringBoot自带监控功能Actuator&#xff0c;可以帮助实现对程序内部运行情况监控&#xff0c;比如监控状况、Bean加载情况、配置属性、日志信息等。 使用步骤 导入依赖坐标 <dependency><groupId>org.springframework.boot</groupId><artifactI…

HCIA-经典综合实验(一)

经典综合实验&#xff08;一&#xff09; 实验拓扑配置步骤第一步&#xff1a;配置二层VLAN第二步&#xff1a;配置IP地址第三步&#xff1a;配置DHCP服务第四步&#xff1a;配置路由协议OSPF第五步&#xff1a;配置ACLNATTelnet 配置验证测试PC1能不能telnet登录到R1测试所有P…

Leetcode—765.情侣牵手【困难】

2023每日刷题&#xff08;二十七&#xff09; Leetcode—765.情侣牵手 并查集置换环思路 参考自ylb 实现代码 class Solution { public:int minSwapsCouples(vector<int>& row) {int n row.size();int len n / 2;vector<int> p(len);iota(p.begin(), p.…

postman连接数据库

参考&#xff1a;https://blog.csdn.net/qq_45572452/article/details/126620210 1、安装node.js 2、配置环境变量 3、安装xmysql连接数据库cmd窗口输入"npm install -g xmysql"后回车cmd窗口输入"xmysql"后回车,验证xmysql是否安装成功(下图代表安装成功)…

算法通关村第八关-白银挑战二叉树的深度和高度问题

大家好我是苏麟 , 今天说说几道二叉树深度和高度相关的题目 . LeetCode给我们造了一堆的题目&#xff0c;研究一下104、110和111三个题&#xff0c;这三个颗看起来挺像的&#xff0c;都是关于深度、高度的。 最大深度问题 描述 : 二叉树的 最大深度 是指从根节点到最远叶子…

一文读懂微前端

1 语雀文档 https://www.yuque.com/chanwj/vlkwxk/qvpv3kqws5hno3qt?singleDoc# 《微前端》本文使用的参考文档均以链接方式粘贴于文章内&#xff0c;十分感谢~ 2 项目github链接 如果你觉得本文档对你有用&#xff0c;恳请github仓库给个star~https://github.com/OmegaCh…

在ant构建脚本中调用maven的命令

有时候想用maven管理依赖&#xff0c;用ant构建。 在ant的build.xml文件中可以使用exec这个task来调用系统命令&#xff0c;也就可以调用maven的命令。 例如&#xff0c;执行maven的命令mvn dependency:copy-dependencies&#xff0c;可以将项目的依赖提取出来&#xff0c;放…

菜单栏管理软件 Bartender 3 mac中文版功能介绍

​Bartender 3 mac是一款菜单栏管理软件&#xff0c;该软件可以将指定的程序图标隐藏起来&#xff0c;需要时呼出即可。 Bartender 3 mac功能介绍 Bartender 3完全支持macOS Sierra和High Sierra。 更新了macOS High Sierra的用户界面 酒吧现在显示在菜单栏中&#xff0c;使其…

TMSRL

Z是学到的子空间表征 辅助信息 作者未提供代码