深度学习12—VGG19实现

news2024/9/30 3:24:24

目录

VGG19实现

1.为数据打标签的generate_txt.py

2.对图像进行预处理的data_process.py

3.VGG19的网络构建代码net_VGG19.py

4.训练得到pth模型参数文件的get_pth_file.py

5.预测代码predict.py

6.预测VGG16与VGG19结果对比


VGG19实现

1.为数据打标签的generate_txt.py

这里的程序设计思想还是可以学一下的。

import os
from os import getcwd # 文件夹操作

# 写入数据集对应的文件夹
classes = ['cat','dog']
sets = ['train']

# 主程序执行
if __name__ == "__main__":
    wd = getcwd() # 获取当前工作目录

    # 说明:当前代码的目录关系是:sets-->subset(当前只有一个“train”)-->type_name(“train”下的关于类别的文件夹,比如“dog”)-->具体样本数据
    # 遍历sets中的每个文件夹,当前sets中只有"train"一个文件夹
    for subset in sets:
        list_file = open("cls_"+subset+".txt",'w')

        # 拿到每个子文件夹的目录
        path_subset = subset
        type_names = os.listdir(path_subset) # 拿到subset文件夹下的所有动物分类文件夹type_name,存到type_names列表中

        # 遍历subset中的每个文件夹type_name
        """
        它遍历名为type_names的列表中的每个元素。
        代码的目的是检查每个元素是否存在于名为classes的集合中。
        如果存在,代码会继续执行下一次循环,处理下一个元素;
        如果不存在,代码会跳过当前循环并继续执行下一次循环。
        """
        for type_name in type_names:
            if type_name not in classes:
                continue

            # 打标签
            type_id = classes.index(type_name) # 按type_name文件夹在classes文件夹中的索引,为type_name编号

            # 生成每个type_name文件夹的路径
            type_path = os.path.join(path_subset,type_name)
            photo_names = os.listdir(type_path) # 拿到type_name文件夹下的所有图片,组成一个列表phto_names

            # 处理每一张图片
            """
            这段代码的作用是遍历名为photos_name的列表中的每个元素,
            并根据文件名的扩展名来过滤文件。代码会判断文件的扩展名是否为.jpg、.png或.jpeg,
            如果不是这些扩展名之一,则跳过当前文件的处理。对于符合条件的文件,
            代码会将其写入到名为list_file的文件中,并写入文件的类别ID和路径信息。
            """
            for photo_name in photo_names:

                """
                这一行代码使用os.path.splitext()函数将文件名photo_name分成文件名部分和扩展名部分,
                并将扩展名赋值给变量postfix。下划线_表示不使用文件名部分,只关注扩展名。
                """
                _,postfit = os.path.splitext(photo_name) # #该函数用于分离文件名与拓展名

                # 如果拓展名不在如下的列表中,则跳过当前循环;如果在,则继续
                if postfit not in ['.jpg','.png','.jpeg']:
                    continue

                # 将文件的类别ID和完整路径信息写入到名为list_file的文件中
                photo_path = os.path.join(type_path,photo_name)
                # print(wd) # C:\Users\ZARD\PycharmProjects\pythonProject\AAA_FX\revise_VGG19
                # 如上可知,wd为该项目的路径
                list_file.write(str(type_id)+';'+'%s/%s'%(wd,photo_path))
                list_file.write('\n') # 这一行代码写入一个换行符,将下一个文件的记录写入到新的一行

        list_file.close()

2.对图像进行预处理的data_process.py

对数据做一些基本操作,可根据实际需求进行更改。

import cv2
import numpy as np
import torch.utils.data as data
from PIL import  Image

def preprocess_input(x):
    x/=127.5
    x-=1.
    return x
def cvtColor(image):
    if len(np.shape(image))==3 and np.shape(image)[-2]==3:
        return image
    else:
        image=image.convert('RGB')
        return image


class DataGenerator(data.Dataset):
    def __init__(self,annotation_lines,inpt_shape,random=True):
        self.annotation_lines=annotation_lines
        self.input_shape=inpt_shape
        self.random=random

    def __len__(self):
        return len(self.annotation_lines)
    def __getitem__(self, index):
        annotation_path=self.annotation_lines[index].split(';')[1].split()[0]
        image=Image.open(annotation_path)
        image=self.get_random_data(image,self.input_shape,random=self.random)
        image=np.transpose(preprocess_input(np.array(image).astype(np.float32)),[2,0,1])
        y=int(self.annotation_lines[index].split(';')[0])
        return image,y
    def rand(self,a=0,b=1):
        return np.random.rand()*(b-a)+a

    def get_random_data(self,image,inpt_shape,jitter=.3,hue=.1,sat=1.5,val=1.5,random=True):

        image=cvtColor(image)
        iw,ih=image.size
        h,w=inpt_shape
        if not random:
            scale=min(w/iw,h/ih)
            nw=int(iw*scale)
            nh=int(ih*scale)
            dx=(w-nw)//2
            dy=(h-nh)//2

            image=image.resize((nw,nh),Image.BICUBIC)
            new_image=Image.new('RGB',(w,h),(128,128,128))

            new_image.paste(image,(dx,dy))
            image_data=np.array(new_image,np.float32)
            return image_data
        new_ar=w/h*self.rand(1-jitter,1+jitter)/self.rand(1-jitter,1+jitter)
        scale=self.rand(.75,1.25)
        if new_ar<1:
            nh=int(scale*h)
            nw=int(nh*new_ar)
        else:
            nw=int(scale*w)
            nh=int(nw/new_ar)
        image=image.resize((nw,nh),Image.BICUBIC)
        #将图像多余的部分加上灰条
        dx=int(self.rand(0,w-nw))
        dy=int(self.rand(0,h-nh))
        new_image=Image.new('RGB',(w,h),(128,128,128))
        new_image.paste(image,(dx,dy))
        image=new_image
        #翻转图像
        flip=self.rand()<.5
        if flip: image=image.transpose(Image.FLIP_LEFT_RIGHT)
        rotate=self.rand()<.5
        if rotate:
            angle=np.random.randint(-15,15)
            a,b=w/2,h/2
            M=cv2.getRotationMatrix2D((a,b),angle,1)
            image=cv2.warpAffine(np.array(image),M,(w,h),borderValue=[128,128,128])
        #色域扭曲
        hue=self.rand(-hue,hue)
        sat=self.rand(1,sat) if self.rand()<.5 else 1/self.rand(1,sat)
        val=self.rand(1,val) if self.rand()<.5 else 1/self.rand(1,val)
        x=cv2.cvtColor(np.array(image,np.float32)/255,cv2.COLOR_RGB2HSV)#颜色空间转换
        x[...,1]*=sat
        x[...,2]*=val
        x[x[:,:,0]>360,0]=360
        x[:,:,1:][x[:,:,1:]>1]=1
        x[x<0]=0
        image_data=cv2.cvtColor(x,cv2.COLOR_HSV2RGB)*255
        return image_data

3.VGG19的网络构建代码net_VGG19.py

其实该代码可以直接去torch的官网下,而且如果想改动VGG网络结构,只需微调一下,就可以实现其代码了。

在torch官网能够下载到一些预训练的模型:

https://pytorch.org/vision/stable/models/vgg.html

比如如下的VGG16:

逐条解析该程序:

导包和下载网络权重:

import torch
import torch.nn as nn
​
model_urls = {
    "vgg19":  "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
}#权重下载网址,该地址在torch官网上可下载

VGG网络的类:

class VGG(nn.Module):
    def __init__(self, features, num_classes = 1000, init_weights = True, dropout = 0.5):
        #继承
        super(VGG,self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) # AdaptiveAvgPool2d使处于不同大小的图片也能进行分类
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),  # 完成4096的全连接
            nn.Linear(4096, num_classes), #对 num_classes的分类
        )
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)
​
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

网络结构设置:

def make_layers(cfg, batch_norm = False): # make_layers对输入的cfg进行循环
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M": # 对cfg进行输入循环,取第一个v
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)] # 把输入图像进行缩小
        else:
            #v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
​
​
cfgs = {
    "VGG19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}
# 这部分代码比较讲究,可参考B站视频,地址贴在下面了
def vgg19(pretrained=False, progress=True,num_classes=2):
    model = VGG(make_layers(cfgs["VGG19"]))
    if pretrained:
        from torch.hub import load_state_dict_from_url
        state_dict = load_state_dict_from_url(model_urls['vgg19'],model_dir='./model' ,progress=progress)#预训练模型地址
        model.load_state_dict(state_dict)
    if num_classes != 1000:
        model.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),  # 随机删除一部分不合格
            nn.Linear(4096, 4096),
            nn.ReLU(True),  # 防止过拟合
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes),
        )
    return model
if __name__ == '__main__':
    in_data = torch.ones(1, 3, 224, 224)
    net = vgg19(pretrained=False, progress=True, num_classes=2)
    out = net(in_data)
    print(out)

4.训练得到pth模型参数文件的get_pth_file.py

生成训练集和测试集:

'''数据集'''
annotation_path = 'cls_train.txt' # 读取数据集生成的文件
with open(annotation_path,'r') as f:
    lines = f.readlines() # 拿到所有图片数据的地址,lines的数据类型是一个列表,其中存下了所有图片地址
    #print(type(lines)) # <class 'list'>
    
import numpy as np
np.random.seed(10101) # 函数用于生成指定随机数
np.random.shuffle(lines) # 数据打乱
num_val = int(len(lines)*0.2) # 用做测试的数据数量
#print(num_val) #-->266
num_train = len(lines)-num_val # 训练的数据的数量
#输入图像大小
input_shape=[224,224]   #导入图像大小
# 生成数据
from AAA_FX.revise_VGG19.data_process import DataGenerator
​
train_data = DataGenerator(lines[:num_train],input_shape,True)
val_data = DataGenerator(lines[num_train:],input_shape,False)
​
val_len=len(val_data)
print(val_len)#返回测试集长度

加载数据:

# 取黑盒子工具
"""加载数据"""
from torch.utils.data import DataLoader#工具取黑盒子,用函数来提取数据集中的数据(小批次)

gen_train=DataLoader(train_data,batch_size=4)#训练集batch_size读取小样本,规定每次取多少样本
gen_test=DataLoader(val_data,batch_size=4)#测试集读取小样本

构建网络:

'''构建网络'''
from net_VGG19 import vgg19

device=torch.device('cuda'if torch.cuda.is_available() else "cpu")#电脑主机的选择
net=vgg19(True, progress=True,num_classes=2)#定于分类的类别
net.to(device)

选择优化器和学习率的调整方法:

'''选择优化器和学习率的调整方法'''
lr=0.0001#定义学习率
optim=torch.optim.Adam(net.parameters(),lr=lr)#导入网络和学习率
sculer=torch.optim.lr_scheduler.StepLR(optim,step_size=1)#步长为1的读取

训练:

'''训练'''
epochs=20#读取数据次数,每次读取顺序方式不同
for epoch in range(epochs):
    total_train=0 #定义总损失
    for data in gen_train:
        img,label=data
        with torch.no_grad():
            img =img.to(device)
            label=label.to(device)
        optim.zero_grad()
        output=net(img)
        train_loss=nn.CrossEntropyLoss()(output,label).to(device)
        train_loss.backward()#反向传播
        optim.step()#优化器更新
        total_train+=train_loss #损失相加
    sculer.step()
    total_test=0#总损失
    total_accuracy=0#总精度
    for data in gen_test:
        img,label =data #图片转数据
        with torch.no_grad():
            img=img.to(device)
            label=label.to(device)
            optim.zero_grad()#梯度清零
            out=net(img)#投入网络
            test_loss=nn.CrossEntropyLoss()(out,label).to(device)
            total_test+=test_loss#测试损失,无反向传播
            accuracy=((out.argmax(1)==label).sum()).clone().detach().cpu().numpy()#正确预测的总和比测试集的长度,即预测正确的精度
            total_accuracy+=accuracy
    print("训练集上的损失:{}".format(total_train))
    print("测试集上的损失:{}".format(total_test))
    print("测试集上的精度:{:.1%}".format(total_accuracy/val_len))#百分数精度,正确预测的总和比测试集的长度

    torch.save(net.state_dict(),"DogandCat{}.pth".format(epoch+1))
    print("模型已保存")

5.预测代码predict.py

导入图像:

from PIL import Image
test_pth='.\\train\cat\cat.6.jpg'#设置可以检测的图像
test=Image.open(test_pth)

处理图片:

'''处理图片'''
from torchvision import transforms

transform = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor()])
image=transform(test)

加载网络:

'''加载网络'''
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")#CPU与GPU的选择
net =vgg19()#输入网络

model=torch.load("./DogandCat8.pth",map_location=device)#已训练完成的结果权重输入
net.load_state_dict(model)#模型导入
net.eval()#设置为推测模式
image=torch.reshape(image,(1,3,224,224))#四维图形,RGB三个通
with torch.no_grad():
    out=net(image)
out=F.softmax(out,dim=1)#softmax 函数确定范围
out=out.data.cpu().numpy()
print(out)
a=int(out.argmax(1))#输出最大值位置
plt.figure()
list=['Cat','Dog']
plt.suptitle("Classes:{}:{:.1%}".format(list[a],out[0,a]))#输出最大概率的道路类型
plt.imshow(test)
plt.show()

VGG16的道理是一样的,这里略。 

6.预测VGG16与VGG19结果对比

VGG19的预测结果:

VGG16的预测结果:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/585850.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【git教程】

这里写目录标题 git是什么集中式版本控制系统和分布式版本控制系统git的优势git能做什么(常用)基础教程流程图介绍小节 常用Git命令速查表详解1、HEAD2、add3、commit4、branch5、merge6、rebasemerge和rebase区别7、reset8、revertrevert与reset的区别 git是什么 git是目前世…

【Java算法题】剑指offer_数据结构之02树

前言 刷题链接&#xff1a; https://www.nowcoder.com/exam/oj/ta?page2&tpId13&type265 2. 树 JZ55 二叉树的深度 思路&#xff1a;dep max_deepth(left,right)1&#xff0c;二叉树的深度为根节点到叶子节点&#xff0c;使用递归访问根节点的左孩子和右孩子&…

想要让数据更生动?试试这5种图表工具

在当今大数据时代&#xff0c;数据的利用和分析在各个领域的工作中起着重要的作用。因此&#xff0c;数据可视化图形工具已经成为数据分析的好帮手。事实上&#xff0c;数据可视化的本质是视觉对话。它通过图形手段清晰直观地表达信息&#xff0c;从数据中获得价值。然而&#…

Netty实战(九)

单元测试 一、什么是单元测试二、EmbeddedChannel 概述三、 使用 EmbeddedChannel 测试 ChannelHandler3.1 测试入站消息3.2 测试出站消息 一、什么是单元测试 单元测试的基本思想是&#xff1a;以尽可能小的区块测试代码&#xff0c;并且尽可能地和其他的代码模块以及运行时的…

Java: IO流

1.定义 IO流:存储和读取数据的解决方案 用于读写文件中的数据&#xff08;可以读写文件&#xff0c;或网络中的数据...) 2.IO流的分类 1.按着流的方向 1.输入流&#xff1a;读取 2.输出流&#xff1a;写出 2.按照操作文件类型 1.字节流&#xff1a;所有类型文件 体系&…

Redis:缓存击穿、缓存穿透与缓存雪崩的区别、解决方案

0、前言 近期学习redis相关原理&#xff0c;记录一下开发过程中Redis的一些常见问题及应对方法。 1、缓存穿透 一句话总结&#xff1a;先查redis发现没数据&#xff0c;再去数据库查发现还是没数据。 这种情况下缓存永远不会生效&#xff0c;数据库将承担巨大压力。 我们知道&…

前端食堂技术周刊第 84 期:第 96 届 TC39 会议、Deno 五周年、JavaScript 安全最佳实践、2023 Node.js 性能现状

By Midjournery 美味值&#xff1a;&#x1f31f;&#x1f31f;&#x1f31f;&#x1f31f;&#x1f31f; 口味&#xff1a;葡萄冰萃美式 食堂技术周刊仓库地址&#xff1a;https://github.com/Geekhyt/weekly 本期摘要 第 96 届 TC39 会议Deno 五周年JavaScript 安全最佳…

FreeRTOS:信号量

目录 一、信号量是什么二、二值信号量2.1二值信号量简介2.2创建二值信号量2.2.1函数 vSemaphoreCreateBinary()2.2.2函数xSemaphoreCreateBinary()2.2.3 函数 xSemephroeCreateBinaryStatic()2.2.4二值信号量创建过程分析 2.3释放信号量2.3.1函数 xSemaphoreGive ()2.3.2函数 x…

【MySQL学习6:多行输入函数——聚合函数及SQL书写和执行规则】

之前做的笔记都在有道云&#xff0c;之后会一点点将以前的笔记分享出来~ &#xff08;配图在笔记中查看&#xff09; MySQL学习6&#xff1a;多行输入函数——聚合函数及SQL书写和执行规则 SQL书写顺序&#xff1a;SQL99执行顺序&#xff1a;一、常见的聚合函数1. 常见的聚合函…

算法当中的时间、空间复杂度?

1.究竟什么是时间复杂度 时间复杂度是一个函数&#xff0c;它定性描述该算法的运行时间 时间复杂度就是用来方便开发者估算出程序运行的答题时间。 通常会估算算法的操作单元数量来代表程序消耗的时间&#xff0c;这里默认CPU的每个单元运行消耗的时间都是相同的。 假设算法的…

微服务架构之服务监控与追踪

与单体应用相比&#xff0c;在微服务架构下&#xff0c;一次用户调用会因为服务化拆分后&#xff0c;变成多个不同服务之间的相互调用&#xff0c;每个服务可能是由不同的团队开发&#xff0c;使用了不同的编程语言&#xff0c;还有可能部署在不同的机器上&#xff0c;分布在不…

【MySQL】MySQL间隙锁--幻读解决原理

文章目录 一、间隙锁概念二、测试间隙锁范围加锁三、测试等值间隙锁 一、间隙锁概念 当我们用范围条件而不是相等条件检索数据&#xff0c; 并请求共享或排他锁时&#xff0c;InnoDB 会给符合条件的已有数据记录的索引项加锁&#xff1b;对于键值在条件范围内但并不存在的记录…

八、视图集ModelViewSet(重点)

上一章&#xff1a; 七、Django DRF框架GenericAPIView--搜索&排序&分页&返回值_做测试的喵酱的博客-CSDN博客 下一章&#xff1a; 九、DRF生成API文档_做测试的喵酱的博客-CSDN博客 一、视图集ModelViewSet与ReadOnlyViesSet ModelViewSet视图集 与 ReadOnly…

第13届蓝桥杯Scratch国赛真题集锦

编程题 第 1 题 问答题 LED屏幕 题目说明 编程实现 LED屏幕 具体要求: 1).点击绿旗,在舞台中心区域出现由10 x 10方格组成的LED屏幕; 2).按下空格键,LED屏幕最外环方格全部点亮 (方格变为黄色) 3).LED屏幕每秒向内点亮一层,其它LED灯熄灭; 4).直到LED灯在最中心点亮2秒…

games101作业5

作业要求 • Renderer.cpp 中的 Render()&#xff1a;这里你需要为每个像素生成一条对应的光 线&#xff0c;然后调用函数 castRay() 来得到颜色&#xff0c;最后将颜色存储在帧缓冲区的相 应像素中。 • Triangle.hpp 中的 rayTriangleIntersect(): v0, v1, v2 是三角形的三个…

字节跳动测开岗面试居然这么简单....

因为读者里有不少刚入门测试的同学&#xff0c;这两天抽空整理了一份字节测开实习的面试题答案&#xff0c;说实话这个题目真挺简单的&#xff0c;如果你面大厂碰到此类面试题&#xff0c;也算是运气很好啦。大家也可以先自测一下&#xff0c;看看自己能不能答上来。 如果觉得…

vscode 插件 remote-ssh的安装及使用

文章目录 vscode 插件 remote-ssh的安装及使用windows VSCode(clangdremote-ssh) linux(clangd server) 开发环境问题问题1 : getPlatformForHost was canceled问题2 : host 主机不能联互联网问题3 : 每次都要输入密码 vscode 插件 remote-ssh的安装及使用 需要用到的东西1. r…

安装nodejs环境搭建vue项目的框架

说明&#xff1a;想要搭建一个vue项目的框架&#xff0c;先要安装nodejs环境&#xff1b;我的电脑已经安装过&#xff0c;先卸载掉&#xff0c;重新装一遍&#xff08;卸载nodejs参考&#xff1a;http://t.csdn.cn/jHmCU&#xff09; 一、安装nodejs环境 第一步&#xff1a;下…

ssh无密码链接

ssh的基本语法为 ssh host 然后输入密码才可以 如果a想要免密登录b&#xff0c;则a需要将自己的公钥放在b上&#xff0c;原理如下图&#xff1a; 例如&#xff0c;以aricoder这个用户登录的情况下&#xff0c;在服务器01上生成密钥&#xff0c;命令为 ssh-keygen -t rsa然后连…

前端面试题汇总大全二(含答案超详细,Vue,TypeScript,React,Webpack 汇总篇)-- 持续更新

前端面试题汇总大全&#xff08;含答案超详细&#xff0c;HTML,JS,CSS汇总篇&#xff09;-- 持续更新 前端面试题汇总二 五、Vue 篇1. 谈谈你对MVVM开发模式的理解&#xff1f;2. v-if 和 v-show 有什么区别&#xff1f;3. r o u t e 和 route和 route和router区别4.vue自定义…