VGG(pytorch)

news2024/9/21 4:00:14

VGG:达到了传统串型结构深度的极限

学习VGG原理要了解CNN感受野的基础知识

model.py

import torch.nn as nn
import torch

# official pretrain weights
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        #features参数是特征层模型,传入这个参数直接使用构造的特征层模型
        self.features = features
        self.classifier = nn.Sequential(
            
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_features(cfg: list):
    layers = []
    in_channels = 3

    #传入参数cfg是一个列表,遍历参数列表构造VGG特征层
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)
    #特征层函数返回一个nn.Sequential(*layers),
    #这段代码中的 return nn.Sequential(*layers) 使用了 nn.Sequential 类来创建一个神经网络模型。
    # 在这里,layers 是一个可迭代对象,包含了神经网络模型的各个层或模块。
    #这段代码的作用是封装一个神经网络模型,该模型按照 layers 中层或模块的顺序连接起来,并作为 nn.Sequential 对象返回。

cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}



#在函数定义中的 **kwargs 是一个特殊的参数形式,它允许函数接受任意数量的关键字参数(keyword arguments)。
# 这个参数形式使用了双星号 ** 来表示。
#在上述代码中,**kwargs 的作用是允许函数 vgg() 接受额外的关键字参数,并将这些参数收集到 kwargs 字典中
#如vgg(model_name="vgg16", num_classes=10, pretrained=True) pretrained就是一个**kwargs参数
def vgg(model_name="vgg16", **kwargs):
    assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
    cfg = cfgs[model_name]

    model = VGG(make_features(cfg), **kwargs)
    return model

train.py

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    #定义一个数据加载器用于迭代提取数据
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

这里由于训练时间太长,运行了19个epoch中断。结果如下

using cuda:0 device.
Using 8 dataloader workers every process
using 3306 images for training, 364 images for validation.
train epoch[1/30] loss:1.542: 100%|██████████| 104/104 [08:39<00:00,  4.99s/it]
100%|██████████| 12/12 [01:13<00:00,  6.15s/it]
[epoch 1] train_loss: 1.605  val_accuracy: 0.245
train epoch[2/30] loss:1.399: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.12s/it]
[epoch 2] train_loss: 1.476  val_accuracy: 0.401
train epoch[3/30] loss:1.310: 100%|██████████| 104/104 [08:34<00:00,  4.94s/it]
100%|██████████| 12/12 [01:18<00:00,  6.53s/it]
[epoch 3] train_loss: 1.293  val_accuracy: 0.456
train epoch[4/30] loss:0.958: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.11s/it]
[epoch 4] train_loss: 1.185  val_accuracy: 0.519
train epoch[5/30] loss:1.327: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.11s/it]
[epoch 5] train_loss: 1.135  val_accuracy: 0.527
train epoch[6/30] loss:1.209: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.12s/it]
[epoch 6] train_loss: 1.077  val_accuracy: 0.571
train epoch[7/30] loss:0.725: 100%|██████████| 104/104 [1:25:27<00:00, 49.30s/it]
100%|██████████| 12/12 [01:21<00:00,  6.82s/it]
[epoch 7] train_loss: 1.051  val_accuracy: 0.596
train epoch[8/30] loss:1.146: 100%|██████████| 104/104 [08:50<00:00,  5.10s/it]
100%|██████████| 12/12 [01:27<00:00,  7.31s/it]
[epoch 8] train_loss: 1.008  val_accuracy: 0.615
train epoch[9/30] loss:1.381: 100%|██████████| 104/104 [08:48<00:00,  5.08s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 9] train_loss: 0.995  val_accuracy: 0.640
train epoch[10/30] loss:0.466: 100%|██████████| 104/104 [08:34<00:00,  4.95s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 10] train_loss: 0.966  val_accuracy: 0.673
train epoch[11/30] loss:0.867: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.13s/it]
[epoch 11] train_loss: 0.926  val_accuracy: 0.659
train epoch[12/30] loss:0.804: 100%|██████████| 104/104 [08:34<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 12] train_loss: 0.916  val_accuracy: 0.665
train epoch[13/30] loss:0.377: 100%|██████████| 104/104 [08:35<00:00,  4.96s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 13] train_loss: 0.879  val_accuracy: 0.648
train epoch[14/30] loss:0.588: 100%|██████████| 104/104 [08:35<00:00,  4.95s/it]
100%|██████████| 12/12 [01:13<00:00,  6.16s/it]
[epoch 14] train_loss: 0.841  val_accuracy: 0.676
train epoch[15/30] loss:0.725: 100%|██████████| 104/104 [08:35<00:00,  4.96s/it]
100%|██████████| 12/12 [01:13<00:00,  6.13s/it]
[epoch 15] train_loss: 0.830  val_accuracy: 0.687
train epoch[16/30] loss:0.977: 100%|██████████| 104/104 [08:35<00:00,  4.96s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 16] train_loss: 0.811  val_accuracy: 0.720
train epoch[17/30] loss:0.923: 100%|██████████| 104/104 [08:34<00:00,  4.95s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 17] train_loss: 0.796  val_accuracy: 0.703
train epoch[18/30] loss:1.150: 100%|██████████| 104/104 [08:34<00:00,  4.95s/it]
100%|██████████| 12/12 [01:13<00:00,  6.15s/it]
[epoch 18] train_loss: 0.794  val_accuracy: 0.720
train epoch[19/30] loss:0.866:  19%|█▉        | 20/104 [01:54<07:59,  5.71s/it]

predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import GoogLeNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = GoogLeNet(num_classes=5, aux_logits=False).to(device)

    # load model weights
    weights_path = "./googleNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device),
                                                          strict=False)

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

预测结果

class: daisy        prob: 0.00207
class: dandelion    prob: 0.00144
class: roses        prob: 0.101
class: sunflowers   prob: 0.00535
class: tulips       prob: 0.89

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1316428.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ArrayList与LinkLIst

ArrayList 在Java中&#xff0c;ArrayList是java.util包中的一个类&#xff0c;它实现了List接口&#xff0c;是一个动态数组&#xff0c;可以根据需要自动增长或缩小。下面是ArrayList的一些基本特性以及其底层原理的简要讲解&#xff1a; ArrayList基本特性&#xff1a; 动…

前端框架的虚拟DOM(Virtual DOM)

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 欢迎来到前端入门之旅&#xff01;感兴趣的可以订阅本专栏哦&#xff01;这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

JVM学习之类加载子系统

类加载子系统 类加载子系统负责从文件或者网络中加载Class文件&#xff0c;class文件在开头有特定的标识 ClassLoader只负责class文件的加载&#xff0c;是否可运行是执行引擎决定的 加载的类信息放在方法区。除了类信息之外&#xff0c;方法区也会放运行时常量池&#xff0c…

漏刻有时数据可视化Echarts组件开发(43)纹理填充和HTMLImageElement知识说明

在 ECharts 中&#xff0c;纹理填充可以通过自定义系列&#xff08;series&#xff09;的 itemStyle 属性来实现。itemStyle 属性用于设置系列中每个数据项的样式&#xff0c;包括填充颜色、边框颜色、边框线宽等。 纹理填充 // 纹理填充 {image: imageDom, // 支持为 HTMLIm…

VM虚拟机打不开原来保存的虚拟机文件夹ubuntu

VMWare虚拟机打不开原来保存的虚拟机文件夹ubuntu 换了电脑把之前的虚拟机克隆的文件夹直接拿来用 报这个错&#xff1a; 指定的文件不是虚拟磁盘 打不开磁盘“D:\ubuntu_iso\ubuntu_location\Ubuntu 64 位-s002.vmdk”或它所依赖的某个快照磁盘。 模块“Disk”启动失败。 未…

HiveSql语法优化三 :join优化

前面提到过&#xff1a;Hive拥有多种join算法&#xff0c;包括Common Join&#xff0c;Map Join&#xff0c;Bucket Map Join&#xff0c;Sort Merge Buckt Map Join等&#xff1b;每种join算法都有对应的优化方案。 Map Join 在优化阶段&#xff0c;如果能将Common Join优化为…

PAT 乙级 1008 数组元素循环右移问题

解题思路:这种循环题有一个经典的O(N)解法&#xff0c;就是前后对称交换&#xff0c;举例&#xff0c;我要循环右移 123456 的后俩个&#xff0c;我们的算法是将56&#xff0c;变成65&#xff0c;把前面的1234变成4321,然后将432165 对称交换就变成了561234 c语言代码如下&…

【Proteus仿真】【51单片机】电子称重秤

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真51单片机控制器&#xff0c;使LCD1602液晶&#xff0c;矩阵按键、蜂鸣器、HX711称重模块等。 主要功能&#xff1a; 系统运行后&#xff0c;LCD1602显示HX711称重模块检测重量&…

Python基础06-异常

零、文章目录 Python基础06-异常 1、异常的基本概念 &#xff08;1&#xff09;异常是什么 当检测到一个错误时&#xff0c;解释器就无法继续执行了&#xff0c;反而出现了一些错误的提示&#xff0c;这就是所谓的"异常"。 &#xff08;2&#xff09;异常演示 …

持续集成交付CICD:Jenkins使用GitLab共享库实现基于SaltStack的CD流水线部署前后端应用

目录 一、实验 1.Jenkins使用GitLab共享库实现基于SaltStack的CD流水线部署前后端应用 2.优化共享库代码 二、问题 1.Jenkins手动构建后端项目流水线报错 一、实验 1.Jenkins使用GitLab共享库实现基于SaltStack的CD流水线部署前后端应用 &#xff08;1&#xff09;GitLa…

MySQL,分组order by

一、创建分组 ## 创建分组 -- 返回每个发布会的参会人数 SELECT event_id,COUNT(*) as canjia_num FROM sign_guest GROUP BY event_id; 1、group by子句可以包含任意个列&#xff0c;但是但指定的所有列都是一起计算的。 group by 后2个字段一起计算的 2、group by后面可以跟…

Kafka-日志索引

Kafka的Log日志梳理 Topic下的消息是如何存储的&#xff1f; 在搭建Kafka服务时&#xff0c;在server.properties配置文件中通过log.dir属性指定了Kafka的日志存储目录。 实际上&#xff0c;Kafka的所有消息就全都存储在这个目录下。 这些核心数据文件中&#xff0c;.log结尾…

某60内网渗透之frp实战指南2

内网渗透 文章目录 内网渗透frp实战指南2实验目的实验环境实验工具实验原理实验内容frp实战指南2 实验步骤(1)确定基本信息。(2)查看frp工具的基本用法(3)服务端frp的配置(4)客户端frp的配置(5)使用frp服务 frp实战指南2 实验目的 让学员通过该系统的练习主要掌握&#xff1a…

方差分析实例

目录 方差分析步骤 相关概念 基本思想 随机误差 系统误差 组内方差 组间方差 方差的比较 方差分析的前提 1.每个总体都应服从正态分布 2.各个总体的方差必须相同 3.观察值是独立的 原假设成立 备择假设成立 单因素方差分析 提出假设 检验的统计量 水平的均值…

云原生之深入解析Linkerd Service Mesh的功能和使用

一、简介 Linkerd 是 Kubernetes 的一个完全开源的服务网格实现&#xff0c;它通过为你提供运行时调试、可观测性、可靠性和安全性&#xff0c;使运行服务更轻松、更安全&#xff0c;所有这些都不需要对代码进行任何更改。Linkerd 通过在每个服务实例旁边安装一组超轻、透明的…

【卡塔尔世界杯数据可视化与新闻展示】

卡塔尔世界杯数据可视化与新闻展示 前言数据获取与处理可视化页面搭建功能实现新闻信息显示详情查看登录注册评论信息管理 创新点结语 前言 随着卡塔尔世界杯的临近&#xff0c;对于足球爱好者来说&#xff0c;对比赛的数据分析和新闻报道将成为关注的焦点。本文将介绍如何使用…

Ubuntu安装蓝牙模块pybluez以及问题解决方案【完美解决】

文章目录 简介问题及解决办法总结 简介 近期因工程需要在Ubuntu中使用蓝牙远程一些设备。安装Bluetooth的Python第三方软件包pybluez时遇到很多问题&#xff0c;一番折腾后完美解决。此篇博客进行了梳理和总结&#xff0c;供大家参考。 问题及解决办法 pip install pybluez安…

nodejs微信小程序+python+PHP技术下的音乐推送系统-计算机毕业设计推荐

音乐推送系统采取面对对象的开发模式进行软件的开发和硬体的架设&#xff0c;能很好的满足实际使用的需求&#xff0c;完善了对应的软体架设以及程序编码的工作&#xff0c;采取MySQL作为后台数据的主要存储单元&#xff0c;  本文设计了一款音乐推送系统&#xff0c;系统为人…

解决vue3+ts打包,ts类型检查报错导致打包失败

最近拉的开源大屏项目goview&#xff0c;在打包的过程中一直报Ts类型报错导致打包失败&#xff0c;项目的打包命令为&#xff1a; "build": "vue-tsc --noEmit && vite build" 是因为 vue-tsc --noEmit 是 TypeScript 编译器&#xff08;tsc&#…

054:vue工具 --- BASE64加密解密互相转换

第054个 查看专栏目录: VUE ------ element UI 专栏目标 在vue和element UI联合技术栈的操控下&#xff0c;本专栏提供行之有效的源代码示例和信息点介绍&#xff0c;做到灵活运用。 &#xff08;1&#xff09;提供vue2的一些基本操作&#xff1a;安装、引用&#xff0c;模板使…