VGG网络的代码实现

news2025/1/13 3:08:45

VGG网络的程序实现完全根据配置表来实现。

全连接层之前的部分属于特征提取部分,后三部分全连接层用来分类。

1、模型

import torch.nn as nn
import torch

# official pretrain weights
#预训练的权重下载地址
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}

#进行分类的代码
class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

#进行特征提取部分的代码
def make_features(cfg: list):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name="vgg16", **kwargs):
    assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
    cfg = cfgs[model_name]

    model = VGG(make_features(cfg), **kwargs)
    return model

定义了VGG11、VGG13、VGG16和VGG19。调用的时候只需要输入模型名字就可以。比如model=vgg("vgg16")

2、预处理

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])}

3、数据集及可视化

数据集使用的是眼睛疾病的数据集,分别包括四类:cataract(白内障)、diabetic_retinopathy(糖尿病性视网膜病变)、glaucoma(青光眼)、normal(正常)。

可视化:

代码:

    fig = plt.figure()
    for i in range(4):
        plt.subplot(1,4,i+1)
        # plt.tight_layout()
        # plt.imshow(test_image[i][0],cmap='CMRmap', interpolation='none')
        plt.imshow(test_image[i][0])
        # plt.title("Ground Truth: {}".format(test_label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

输出:

4、加载数据:

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

5、加载模型

model_name = "vgg16"
net = vgg(model_name=model_name, num_classes=5, init_weights=True)
net.to(device)

6、损失函数:

loss_function = nn.CrossEntropyLoss()

7、优化器:

optimizer = optim.Adam(net.parameters(), lr=0.0001)

8、迁移学习

 #迁移学习
    model_weight_path = "./vgg16-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'),False)
    for param in net.parameters():
        param.requires_grad = False
    n_inputs=net.classifier[6].in_features
    last_layer=nn.Linear(n_inputs,4)
    net.classifier[6]=last_layer

9、训练:

    epochs = 30
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')

结果:

对vgg11、vgg13、vgg16、vgg19分别进行测试。

vgg模型训练的时间都比较长,从损失值看vgg19效果好一些。从精度上看,vgg13、vgg11、vgg19都有不错的精确率。

完整代码:

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

参考资料:

在Pytorch中使用VGG16进行迁移学习-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1519699.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用opencv进行图片分析

opencv学习 一、配置环境并打开编译器 配置opencv在你的任意一个盘里创建一个专属于opencv的文件夹便于学习与整理 打开控制台winr输入cmd,进入后输入conda activate opencv,进入环境以后进入你所设置的opencv文件的盘,我的是D盘&#xff0…

fastreport循环数据表

1.创建数据源 2.将数据源关联到数据区 3.配置控件及属性 拖拽文本控件,设置文本控件的style属性为Data 4.比对效果 比对数据库和报表的数据一致。且循环显示。 数据库数据 报表展示数据

【Qt】从QMainWindow到UI框架

目录 简介UI布局元素Central WidgetMenu BarToolbarsStatus BarDock Widgets 参考文档 简介 如下图所示,我们常见的一些desktop软件,比如VS Code、Smart VCI等,一般都会包含顶部的菜单栏,底部的状态栏,以及一些其他UI…

工业制造领域系统:SCADA、PLC、DCS、MES、HMI、ERP等,一文秒懂

工业制造控制系统在工业制造领域起到了关键的作用,帮助企业提高生产效率、降低成本、提高产品质量和安全性。不同的企业根据自身需求和规模,可能会选择使用其中的一种或多种系统。 SCADA系统(Supervisory Control and Data Acquisition&…

第2章 信息技术基础

本章学习要点 全面了解医院信息系统建设所涉及的主要信息技术以及这些技术的应用情况。 计算机与网络、信息技术与信息系统、数字媒体与数据存储技术、条形码(二维码)、RFID技术、云计算、APP技术 1.XML 可扩展标记语言与Access,Oracle和SQL Server等数据库不同…

什么软件可以剪辑录音?录音剪辑推荐3款工具

随着数字技术的发展,录音已经成为我们日常生活和工作中不可或缺的一部分。无论是会议记录、课堂笔记,还是音乐创作、语音聊天,我们都需要用到录音功能。然而,单纯的录音往往不能满足我们的需求,我们还需要对录音进行剪…

常用芯片学习——TP4057电源管理芯片

TP40578 500mA线性锂离子电池充电器 芯片介绍 TP4057是一款性能优异的单节锂离子电池恒流/恒压线性充电器。TP4057采用S0T23-6封装配合较少的外围原件使其非常适用于便携式产品,并且适合给USB电源以及适配器电源供电。 基于特殊的内部MOSFET架构以及防倒充电路&a…

图像处理与视觉感知---期末复习重点(3)

文章目录 一、空间域和频率域二、傅里叶变换三、频率域图像增强 一、空间域和频率域 1. 空间域:即所说的像素域,在空间域的处理就是在像素级的处理,如在像素级的图像叠加。通过傅立叶变换后,得到的是图像的频谱,表示图…

BMP280学习

1.Forced mode模式&#xff0c;单次采集后进入休眠&#xff0c;适用于低采样率。 2.normal mode模式&#xff0c;持续采集&#xff0c;我们使用这种 采集事件基本都是ms级&#xff0c;所以我们1s更新一次。 温度和压力的计算 #include <SPI.h> //定义数据类型 #define s3…

就业班 2401--3.12 Linux Day16 PXE布置——自动化装系统

什么是PXE&#xff1f; PXE&#xff0c;全名Pre-boot Execution Environment&#xff0c;预启动执行环境&#xff1b;通过网络接口启动计算机&#xff0c;不依赖本地存储设备&#xff08;如硬盘&#xff09;或本地已安装的操作系统&#xff1b;由Intel和Systemsoft公司于1999年…

2024/03/15(网络编程·day3)

一、思维导图 二、模拟面试题 什么是IP地址&#xff1f; IP地址是主机在网络中的唯一标识。 分为IPv4和IPv6&#xff0c; IPv4由4字节32位二进制数组成&#xff0c;通常使用点分十进制表示&#xff0c;例如192.168.117.85 &#xff0c;其中每个十进制数的范围都在0-255. IPv6由…

技术资讯:Volar正式更名为Vue-Official

大家好&#xff0c;我是大澈&#xff01; 本文约700字&#xff0c;整篇阅读大约需要1分钟。 关注微信公众号&#xff1a;“程序员大澈”&#xff0c;免费加入问答群&#xff0c;一起交流技术难题与未来&#xff01; 现在关注公众号&#xff0c;免费送你 ”前后端入行大礼包“…

2024年【广东省安全员C证第四批(专职安全生产管理人员)】考试总结及广东省安全员C证第四批(专职安全生产管理人员)模拟试题

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 广东省安全员C证第四批&#xff08;专职安全生产管理人员&#xff09;考试总结是安全生产模拟考试一点通总题库中生成的一套广东省安全员C证第四批&#xff08;专职安全生产管理人员&#xff09;模拟试题&#xff0c;…

ArrayBlockingQueue与LinkedBlockingQueue底层原理

ArrayBlockingQueue与LinkedBlockingQueue底层原理 在线程池中&#xff0c;等待队列采用ArrayBlockingQueue或者LinkedBlockingDeque&#xff0c;那他们是怎么实现存放线程、阻塞、取出的呢&#xff1f; 一、ArrayBlockingQueue底层原理 1.1 简介 ArrayBlockingQueue是一个阻塞…

[嵌入式AI从0开始到入土]16_ffmpeg_ascend编译安装及性能测试

[嵌入式AI从0开始到入土]嵌入式AI系列教程 注&#xff1a;等我摸完鱼再把链接补上 可以关注我的B站号工具人呵呵的个人空间&#xff0c;后期会考虑出视频教程&#xff0c;务必催更&#xff0c;以防我变身鸽王。 第1期 昇腾Altas 200 DK上手 第2期 下载昇腾案例并运行 第3期 官…

力扣题目训练(21)

2024年2月14日力扣题目训练 2024年2月14日力扣题目训练605. 种花问题617. 合并二叉树628. 三个数的最大乘积289. 生命游戏299. 猜数字游戏149. 直线上最多的点数 2024年2月14日力扣题目训练 2024年2月14日第二十一天编程训练&#xff0c;今天主要是进行一些题训练&#xff0c;…

2.2 HTML5保留的常用标签

2.2.1 基础标签 1. 段落标签<p> 段落标签<p>和</p>用于形成一个新的段落&#xff0c;段落与段落之间默认为空一行进行分割。 2. 标题标签<h1>-<h6> HTML5使用<hn>和</hn>来标记文本中的标题&#xff0c;其中n需要替换为数字&#x…

关于OPC-UA客户端调用服务端方法CallMethod节点的问题

在OpcUaClient中可以通过CallMethodByNodeId调用方法节点 //// 摘要:// call a server method//// 参数:// tagParent:// 方法的父节点tag//// tag:// 方法的节点tag//// args:// 传递的参数//// 返回结果:// 输出的结果值public object[] CallMetho…

淘宝商品详情数据采集(商品属性,规格,价格,详情图等)

淘宝商品详情数据采集是一个涉及多个步骤的过程&#xff0c;主要目的是获取商品的各种详细信息&#xff0c;如商品属性、规格、价格、详情图等。以下是一个基本的采集流程&#xff1a; 确定采集目标&#xff1a;首先&#xff0c;需要明确要采集的淘宝商品范围&#xff0c;例如…

Filter实现请求日志记录

将锁有得外部访问都记录在日志文件里面&#xff0c;设计这个功能是为了&#xff08;为什么&#xff09;&#xff1a; 1. 在不引入Promentheus进行接口监控时&#xff0c;基于日志文件就可以实现整个项目得监控。 2. 当出现问题时&#xff0c;可以基于此进行流量重放。 效果如…