深度学习pytorch实战三:VGG16图像分类篇自建数据集图像分类三类

news2024/11/24 18:27:14

1.自建数据集与划分训练集与测试集
2.模型相关知识
3.model.py——定义AlexNet网络模型
4.train.py——加载数据集并训练,训练集计算损失值loss,测试集计算accuracy,保存训练好的网络参数
5.predict.py——利用训练好的网络参数后,用自己找的图像进行分类测试

一、自建数据集与划分训练集与测试集

1.自建数据文件夹

  首先我们确定这次分类种类,采用爬虫、官网数据集和自己拍照的照片获取三类,准备个文件夹,里面包含三个文件夹,文件夹名字随便取,最好是所属种类英文,每个文件夹照片数量最好一样多,五百多张以上。如我选了蒲公英,玫瑰,郁金香三类,如data_set包含flowers_data,它包含flowers_photos,它包含三个文件夹,分别是三个类文件夹。

2.划分训练集与测试集

这里需要使用通用的划分数据代码,这次是与flowers_data同一目录下运行

import os
from shutil import copy
import random


def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


# 获取 photos 文件夹下除 .txt 文件以外所有文件夹名(即3种分类的类名)
file_path = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file_path) if ".txt" not in cla]

# 创建 训练集train 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/' + cla)

# 创建 验证集val 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/' + cla)

# 划分比例,训练集 : 验证集 = 9 : 1
split_rate = 0.1

# 遍历3种花的全部图像并按比例分成训练集和验证集
for cla in flower_class:
    cla_path = file_path + '/' + cla + '/'  # 某一类别动作的子目录
    images = os.listdir(cla_path)  # iamges 列表存储了该目录下所有图像的名称
    num = len(images)
    eval_index = random.sample(images, k=int(num * split_rate))  # 从images列表中随机抽取 k 个图像名称
    for index, image in enumerate(images):
        # eval_index 中保存验证集val的图像名称
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)  # 将选中的图像复制到新路径

        # 其余的图像保存在训练集train中
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
    print()

print("processing done!")



最后运行,在flowers_data会多两个文件,是train和val(训练集和测试集)

二、模型相关知识

之前有文章介绍模型,如果不清楚可以点下链接转过去学习

深度学习卷积神经网络CNN之 VGGNet模型主vgg16和vgg19网络模型详解说明(理论篇)

在这里插入图片描述

三、model.py——定义AlexNet网络模型

这里还是直接复制给出原模型,不用改参数。

import torch.nn as nn
import torch

# official pretrain weights
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_features(cfg: list):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name="vgg16", **kwargs):
    assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
    cfg = cfgs[model_name]

    model = VGG(make_features(cfg), **kwargs)
    return model

四、train.py——模型训练,加载数据集并训练,训练集计算损失值loss,测试集计算accuracy,保存训练好的网络参数

在63行修改为3,因为只有三类

net = vgg(model_name=model_name, num_classes=3, init_weights=True)

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set1", "flower_data1")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 64
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=3, init_weights=True)%%%%%%%%这一行
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 10
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

训练结果截图如下

五、predict.py——利用训练好的网络参数后,用自己找的图像进行分类测试

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "1.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)
    
    # create model
    model = vgg(model_name="vgg16", num_classes=5).to(device)
    # load model weights
    weights_path = "./vgg16Net.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

在网上下载了一郁金香的图片,使用VGG16网络查看是否可以将图片种类正确识别。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/392552.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Taro开发】-文字展开收起组件(十五)

Taro小程序开发 系列文章的所有文章的目录 【Taro开发】-初始化项目(一) 【Taro开发】-路由传参及页面事件调用(二) 【Taro开发】-taro-ui(三) 【Taro开发】-带token网络请求封装(四&#x…

Sonar:Win10搭建SonarQube9.8服务

需求描述 公司为项目代码配置了Sonar检测,最初只是想调研在VSCode中同步远程检测的方法(现在请参考Sonar:VSCode配置SonarLint/SonarLint连接SonarQube);结果并没有找到靠谱的教程。。在度娘的信息海洋胡乱扑腾两天后…

Docker(七)--Docker数据卷管理及插件

文章目录一、Docker 数据卷管理1.bind mount2.docker managed volume3.bind mount与docker managed volume对比二、跨节点存储convoy卷插件一、Docker 数据卷管理 在实际使用过程中,我们需要把容器和数据进行隔离,因为容器在使用过程中可能随时要进行销…

C++:set和map(模拟实现)

目录 关联式容器 键值对 树形结构的关联式容器 set的介绍 set的使用 map的介绍 map的使用 multiset的介绍 multimap的介绍 底层结构 AVL树的概念 AVL树节点的定义 AVL树的旋转 左单旋 右单旋 先右单旋再左单旋 先左单旋再右单旋 模拟实现AVL树 红黑树 红黑树…

【C++】30h速成C++从入门到精通(stack、queuepriority_queue以及deque介绍)

stackstack的介绍https://cplusplus.com/reference/stack/stack/?kwstackstack是一种容器适配器,专门在具有后进先出操作的上下文环境中,其删除只能从容器的一端进行元素的插入与提取操作。stack是作为容器适配器被实现的,容器适配器即是对特…

详解一致性哈希算法

在单机系统中,所有的数据都存储在同一个服务器下,当数据量越来越多的时候,超过了单机存储容量的上限,就需要使用分布式存储系统,在分布式存储系统重,数据会被拆分到不同的存储服务下,减少单机服…

[数据结构]:12-快速排序(顺序表指针实现形式)(C语言实现)

目录 前言 已完成内容 快速排序实现 01-开发环境 02-文件布局 03-代码 01-主函数 02-头文件 03-PSeqListFunction.cpp 04-SortCommon.cpp 05-SortFunction.cpp 结语 前言 此专栏包含408考研数据结构全部内容,除其中使用到C引用外,全为C语言代…

【Linux】canal1.1.7同步MySQL8.0.3和Redis

目录前言一、MySQL8配置1. 修改my.cnf2. 重启mysql3. 建用户、授权二、Canal服务端配置1. 下载2. 修改配置3. 启动服务与验证三、Canal客户端编写1. yml配置文件添加canal服务端配置信息和Redis信息2. 配置pom文件3. 代码4. MySQL建表storage.storage5. 启动客户端与验证参考前…

中微8S6990使用过程的一些记录--GPIO初始化、定时器、PWM、ADC、休眠等外设的配置和使用

前言 最近把一款产品的代码从新唐MS51移植到了中微8S6990平台上,记录下移植过程遇到的各种情况。 目录前言定时器初始化、中断服务函数GPIO配置ADC模数转换初始化PWM初始化Main函数休眠的一些注意事项最后定时器初始化、中断服务函数 void TMR0_Config(void) {/*(…

keepalived+nginx 双机热备搭建

keepalivednginx 双机热备搭建一、准备工作1.1 准备两台centos7.91.2 nginx 与 keepalived软件 双机安装1.3 ip分配1.4 修改主机名1.5 关闭selinux(双机执行)1.6 修改hosts(双机执行)二、安装keepalived2.1 执行一下命令安装keepa…

MidiaPipe +stgcn(时空图卷积网络)实现人体姿态判断(单目标)

文章目录前言Midiapipe关键点检测stgcn 姿态评估效果前言 冒个泡,年少无知吹完的牛皮是要还的呀。 那么这里的话要做的一个东西就是一个人体的姿态判断,比如一个人是坐着还是站着还是摔倒了,如果摔倒了我们要做什么操作,之类的。…

【模型复现】-alexnet,nn.Sequential顺序结构构建网络

深度卷积神经网络(AlexNet) 在LeNet提出后的将近20年里,神经网络一度被其他机器学习方法超越,如支持向量机。虽然LeNet可以在早期的小数据集上取得好的成绩,但是在更大的真实数据集上的表现并不尽如人意。一方面&#…

第五章 事务管理

1.事务概念 *什么是事务:事务是数据库操作最基本单元,逻辑上是一组操作,要么都成功,要么都失败 *事务的特性(ACID):原子性、隔离性、一致性、持久性 2.搭建事务操作环境 *模拟场景&#xff…

uart串口接收模块

uart串口接收模块 1、UART(异步串行接口) 串行通信:指利用一条数据线将资料一位位的顺序传输。   异步通信:以一个字符为传输单位,通信中两个字符间的时间间隔是不固定的,然而在同一个字符的两个相邻位代…

【微信小程序】-- 页面事件 - 下拉刷新(二十五)

💌 所属专栏:【微信小程序开发教程】 😀 作  者:我是夜阑的狗🐶 🚀 个人简介:一个正在努力学技术的CV工程师,专注基础和实战分享 ,欢迎咨询! &…

高盐废水除钙镁的技术解析

高盐废水指含有机物和至少总溶解固体(totaldissolvedsolids,tds)的质量分数大于3.5%的废水,具有水量大,无机盐离子k、na、ca2、mg2、cl-、so42-等含量高,水质水量变化大,成分复杂,难生化降解等特…

2023年中职网络安全竞赛——CMS网站渗透解析

需求环境可私信博主 解析如下: CMS网站渗透 任务环境说明: 服务器场景:Server2206(关闭链接) 服务器场景操作系统:未知 1.使用渗透机对服务器信息收集,并将服务器中网站服务端口号作为flag提交; Flag:8089

华为套件生态

华为套件生态前言蓝牙设备华为耳机华为鼠标智慧互联超级终端多屏协同远程访问文件共享华为电脑管家我的设备控制中心前言 华为的手机、平板、电脑、耳机、手环、手表等设备可以组成华为生态。以下分享一些生态体验。 蓝牙设备 华为耳机 快速连接 在手机/电脑附近打开华为耳…

里奇RIDGID管线定位仪/探测仪维修SR-20 SR-24 SR-60

美国里奇SeekTech SR-20管线定位仪对于初次使用定位仪的用户或经验丰富的用户,都同样可以轻易上手使用SR-20。SR-20提供许多设置和参数,使得大多数复杂的定位工作变得很容易。此外,当你在不复杂的环境下完成些基本的定位工作时,这…

软件测试7

一 CS和BS软件架构 CS:客户端-服务器端,BS:浏览器端-服务器端 区别总结: 1.效率:c/s效率高,某些内容已经安装在系统中了,b/s每次都要加载最新的数据 2.升级:b/s无缝升级&#xff0c…