卷积神经网络| 猫狗系列【AlexNet】

news2024/12/22 21:29:26

首先,搭建网络:

AlexNet神经网络原理图:

net代码:【根据网络图来搭建网络,不会的看看相关视频会好理解一些】

import torchfrom torch import nnimport torch.nn.functional as Fclass MyAlexNet(nn.Module):    def __init__(self):        super(MyAlexNet, self).__init__()#继承        self.c1 = nn.Conv2d(in_channels=3, out_channels=48, kernel_size=11, stride=4, padding=2)#搭建第一层网络,输入通道3层,输出通道48,核11        self.ReLU = nn.ReLU()#激活函数        self.c2 = nn.Conv2d(in_channels=48, out_channels=128, kernel_size=5, stride=1, padding=2)#上面输出48,下面输入也是48,输出125,卷积核5        self.s2 = nn.MaxPool2d(2)#池化层        self.c3 = nn.Conv2d(in_channels=128, out_channels=192, kernel_size=3, stride=1, padding=1)        self.s3 = nn.MaxPool2d(2)        self.c4 = nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, stride=1, padding=1)        self.c5 = nn.Conv2d(in_channels=192, out_channels=128, kernel_size=3, stride=1, padding=1)        self.s5 = nn.MaxPool2d(kernel_size=3, stride=2)        self.flatten = nn.Flatten()#平展层        self.f6 = nn.Linear(4608, 2048)        self.f7 = nn.Linear(2048, 2048)        self.f8 = nn.Linear(2048, 1000)        self.f9 = nn.Linear(1000, 2)#输出二分类网络    def forward(self, x):        x = self.ReLU(self.c1(x))        x = self.ReLU(self.c2(x))        x = self.s2(x)#池化层        x = self.ReLU(self.c3(x))        x = self.s3(x)        x = self.ReLU(self.c4(x))        x = self.ReLU(self.c5(x))        x = self.s5(x)        x = self.flatten(x)#平展层        x = self.f6(x)        x = F.dropout(x, p=0.5)#防止过拟合,有50%的网络随机失效        x = self.f7(x)        x = F.dropout(x, p=0.5)        x = self.f8(x)        x = F.dropout(x, p=0.5)        x = self.f9(x)        return xif __name__ == '__mian__':    x = torch.rand([1, 3, 224, 224])#张量形式数组    model = MyAlexNet()    y = model(x)

测试一下这个网络:

划分数据集:(8:2)(spilit_data

import osfrom shutil import copyimport randomdef mkfile(file):    if not os.path.exists(file):        os.makedirs(file)# 获取data文件夹下所有文件夹名(即需要分类的类名)file_path = 'D:/Users/Twilight/PycharmProjects/AlexNet/data_name'flower_class = [cla for cla in os.listdir(file_path)]# 创建 训练集train 文件夹,并由类名在其目录下创建5个子目录mkfile('data/train')for cla in flower_class:    mkfile('data/train/' + cla)# 创建 验证集val 文件夹,并由类名在其目录下创建子目录mkfile('data/val')for cla in flower_class:    mkfile('data/val/' + cla)# 划分比例,训练集 : 验证集 = 8:2split_rate = 0.2# 遍历所有类别的全部图像并按比例分成训练集和验证集for cla in flower_class:    cla_path = file_path + '/' + cla + '/'  # 某一类别的子目录    images = os.listdir(cla_path)  # iamges 列表存储了该目录下所有图像的名称    num = len(images)    eval_index = random.sample(images, k=int(num * split_rate))  # 从images列表中随机抽取 k 个图像名称    for index, image in enumerate(images):        # eval_index 中保存验证集val的图像名称        if image in eval_index:            image_path = cla_path + image            new_path = 'data/val/' + cla            copy(image_path, new_path)  # 将选中的图像复制到新路径        # 其余的图像保存在训练集train中        else:            image_path = cla_path + image            new_path = 'data/train/' + cla            copy(image_path, new_path)        print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar    print()print("processing done!")

生成的新的文件夹:

由于数据集量太大,划分完了我还删了很多图片,我的data里面只有1000张,训练集猫狗分别400,测试集猫狗分别100。【跑不动根本跑不动】

训练代码:(train)

import torchfrom torch import nnfrom net import MyAlexNetimport numpy as npfrom torch.optim import lr_schedulerimport osfrom torchvision import transformsfrom torchvision.datasets import ImageFolderfrom torch.utils.data import DataLoaderimport matplotlib.pyplot as plt# 解决中文显示问题(乱码)plt.rcParams['font.sans-serif'] = ['SimHei']plt.rcParams['axes.unicode_minus'] = FalseROOT_TRAIN = r'D:/Users/Twilight/PycharmProjects/AlexNet/data/train'#数据集路径训练集ROOT_TEST = r'D:/Users/Twilight/PycharmProjects/AlexNet/data/val'# 将图像的像素值归一化到【-1, 1】之间normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])#train_transform = transforms.Compose([#训练集    transforms.Resize((224, 224)),#224*224    transforms.RandomVerticalFlip(),#随机垂直    transforms.ToTensor(),#转化为张量    normalize])#归一化val_transform = transforms.Compose([#验证集    transforms.Resize((224, 224)),    transforms.ToTensor(),    normalize])train_dataset = ImageFolder(ROOT_TRAIN, transform=train_transform)val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)#批次32,打乱val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=True)device = 'cuda' if torch.cuda.is_available() else 'cpu'#数据导入显卡里面model = MyAlexNet().to(device)#把数据送到神经网络中,然后输到显卡里面# 定义一个损失函数loss_fn = nn.CrossEntropyLoss()# 定义一个优化器optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)#随机梯度下降法,把模型参数传给优化器,学习率0.01# 学习率每隔10轮变为原来的0.5lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)# 定义训练函数def train(dataloader, model, loss_fn, optimizer):#数据,模型,学习率,优化器传入    loss, current, n = 0.0, 0.0, 0#指示器    for batch, (x, y) in enumerate(dataloader):        image, y = x.to(device), y.to(device)        output = model(image)#进行训练        cur_loss = loss_fn(output, y)#看误差        _, pred = torch.max(output, axis=1)#_代表不关心返回的最大值是多少。只需要得到pred        cur_acc = torch.sum(y==pred) / output.shape[0]#算精确率        # 反向传播        optimizer.zero_grad()#梯度降为0        cur_loss.backward()#反向传播        optimizer.step()#更新梯度        loss += cur_loss.item()#Loss值累加起来(一轮很多批次)        current += cur_acc.item()#准确度加起来        n = n+1#轮    train_loss = loss / n#计算这一轮学习的学习率(每一批的)    train_acc = current / n    print('train_loss' + str(train_loss))    print('train_acc' + str(train_acc))#训练精确的    return train_loss, train_acc#返回后面可视化用# 定义一个验证函数def val(dataloader, model, loss_fn):    # 将模型转化为验证模型    model.eval()    loss, current, n = 0.0, 0.0, 0    with torch.no_grad():        for batch, (x, y) in enumerate(dataloader):            image, y = x.to(device), y.to(device)            output = model(image)            cur_loss = loss_fn(output, y)            _, pred = torch.max(output, axis=1)            cur_acc = torch.sum(y == pred) / output.shape[0]            loss += cur_loss.item()            current += cur_acc.item()            n = n + 1    val_loss = loss / n    val_acc = current / n    print('val_loss' + str(val_loss))    print('val_acc' + str(val_acc))    return val_loss, val_acc# 定义画图函数def matplot_loss(train_loss, val_loss):    plt.plot(train_loss, label='train_loss')    plt.plot(val_loss, label='val_loss')    plt.legend(loc='best')    plt.ylabel('loss')    plt.xlabel('epoch')    plt.title("训练集和验证集loss值对比图")    plt.show(block=True)def matplot_acc(train_acc, val_acc):    plt.plot(train_acc, label='train_acc')    plt.plot(val_acc, label='val_acc')    plt.legend(loc='best')    plt.ylabel('acc')    plt.xlabel('epoch')    plt.title("训练集和验证集acc值对比图")    plt.show(block=True)# 开始训练loss_train = []acc_train = []loss_val = []acc_val = []epoch = 20 #20轮min_acc = 0for t in range(epoch):    lr_scheduler.step()#每十步分析一下学习率    print(f"epoch{t+1}\n-----------")    train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)    val_loss, val_acc = val(val_dataloader, model, loss_fn)    loss_train.append(train_loss)#写到集合里头    acc_train.append(train_acc)    loss_val.append(val_loss)    acc_val.append(val_acc)    # 保存最好的模型权重    if val_acc >min_acc:#如果模型精确度大于0        folder = 'save_model'        if not os.path.exists(folder):#如果文件夹不存在            os.mkdir('save_model')#生成        min_acc = val_acc        print(f"save best model, 第{t+1}轮")        torch.save(model.state_dict(), 'save_model/best_model.pth')    # 保存最后一轮的权重文件    if t == epoch-1:        torch.save(model.state_dict(), 'save_model/last_model.pth')matplot_loss(loss_train, loss_val)matplot_acc(acc_train, acc_val)print('Done!')

最好的模型保存,嘻嘻。

生成的loss、acc图​:​

(效果其实是非常不好的,因为数据量太少了哈哈,然后参数某些地方也可以再调一下)

测试代码(test)

import torchfrom net import MyAlexNetfrom torch.autograd import Variablefrom torchvision import datasets, transformsfrom torchvision.transforms import ToTensorfrom torchvision.transforms import ToPILImagefrom torchvision.datasets import ImageFolderfrom torch.utils.data import DataLoaderROOT_TRAIN = r'D:/Users/Twilight/PycharmProjects/AlexNet/data/train'#数据集路径训练集ROOT_TEST = r'D:/Users/Twilight/PycharmProjects/AlexNet/data/val'# 将图像的像素值归一化到【-1, 1】之间normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])train_transform = transforms.Compose([    transforms.Resize((224, 224)),    transforms.RandomVerticalFlip(),    transforms.ToTensor(),    normalize])val_transform = transforms.Compose([    transforms.Resize((224, 224)),    transforms.ToTensor(),    normalize    ])train_dataset = ImageFolder(ROOT_TRAIN, transform=train_transform)#变张量val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True)device = 'cuda' if torch.cuda.is_available() else 'cpu'model = MyAlexNet().to(device)# 加载模型model.load_state_dict(torch.load("D:/Users/Twilight/PycharmProjects/AlexNet/save_model/best_model.pth"))# 获取预测结果classes = [    "cat",    "dog",]# 把张量转化为照片格式,后面可视化show = ToPILImage()# 进入到验证阶段model.eval()for i in range(10):#验证前十张    x, y = val_dataset[i][0], val_dataset[i][1]    show(x).show()    x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=True).to(device)#把值传入到显卡里面    x = torch.tensor(x).to(device)    with torch.no_grad():        pred = model(x)        predicted, actual = classes[torch.argmax(pred[0])], classes[y]        print(f'predicted:"{predicted}", Actual:"{actual}"')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/723685.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux系统之neofetch工具的基本使用

Linux系统之neofetch工具的基本使用 一、neofetch工具介绍1.1 neofetch简介1.2 neofetch特点 二、检查本地环境2.1 检查操作系统版本2.2 检查内核版本 三、安装neofetch工具3.1 配置yum仓库3.2 安装neofetch3.3 查看neofetch版本 四、neofetch工具的基本使用4.1 直接使用neofet…

quilt data-Working with the Catalog

Quilt Catalog 是 Quilt 的第二部分。它提供了一个在您的 S3 存储桶上的界面,将 Quilt 的数据包和搜索等功能带到了 Web 界面上。 请注意,您可以在不使用 Quilt Catalog 的情况下使用 Quilt Python API,但它们是设计为配合使用的。 简要介绍…

【实现openGauss5.0企业版一主一备搭建部署】

【实现openGauss5.0企业版一主一备搭建部署】 🔻 前言🔻 一、安装前准备🔰 1.1 openGauss安装包下载🔰 1.2 安装环境准备⛳ 1.2.1 硬件环境要求⛳ 1.2.2 软件环境要求⛳ 1.2.3 软件依赖要求⛳ 1.2.4 修改 hosts 和 hostname&#…

【react】创建启动react项目和跨域代理:

文章目录 1、创建启动react项目:2、跨域代理:【1】文档:[https://create-react-app.dev/docs/proxying-api-requests-in-development/](https://create-react-app.dev/docs/proxying-api-requests-in-development/)【2】src/setupProxy.js: 1…

阿姆斯特丹大学Max Welling教授-深度学习和自然科学

目录 简介 AI4Science & Science4AI 深度学习简介 AI4Science Science4AI 总结/结束语 参考 简介 人工智能一直与自然科学有着深厚的联系。 人工神经网络最初被认为是生物神经网络的抽象,许多后续算法(例如强化学习)也是如此。 神经…

springcloud actuator暴露端点漏洞修复

前段时间网络安全的同事突然通知系统漏洞,swagger漏洞和暴露多余端点等,可能会泄露信息。刚开始只是修改了相关配置。如下: 更改config配置 management:security:enabled: true security:user:name: xxxpassword: xxxbasic:enabled: trueen…

配置tensorflow1.15版本遇到的问题:conda环境管理/tensorflow历史版本下载/pycharm中如何使用conda中的虚拟环境

0、前言: 我之前在做配置环境,或者不懂的操作时,总是遇到问题在csdn或者网上搜就行了,然后解决问题之后,也不知道期间搜了哪些知识。也记不住一些修改的地方,这就导致,我十分担心好不容易搭好的…

spring系列-SpringCloud

SpringCloud概述 微服务概述 什么是微服务 目前的微服务并没有一个统一的标准,一般是以业务来划分 将传统的一站式应用,拆分成一个个的服务,彻底去耦合,一个微服务就是单功能业务,只做一件事。 与微服务相对的叫巨石 …

“提高个人生产力:思维导图在时间管理和计划中的应用“

在高效成为当今时代职场人高频谈论的一个词后,时间管理和计划的重要性也日渐显现。一个好的时间管理和计划可以在不知不觉中有效帮助我们更加合理的安排时间,保证工作的有序进行和按时完成。通过合理的协调工作与休息之间的关系,避免我们浪费…

15、服务端实战:数据库工具封装

在了解完 NestJS 的基础配置之后,服务端的内容将引来一个比较重要的环节:数据库。 因为数据库的内容比较多,所以相关内容将分为两个章节来展开讨论: 数据库工具封装 - 将封装统一的数据库操作工具类,方便后期开发于集…

识别肿瘤内微生物的生物信息学工具—MEGA

谷禾健康 已有研究证明宿主微生物在癌症预防和治疗反应中的关键作用,了解宿主微生物和癌症之间的相互作用,可以推动癌症诊断和微生物治疗(即用微生物作为药物)。 然而肿瘤内微生物组数据通常是复杂的,想要厘清相互关系也是极为困难的&#xf…

低代码平台——少量编码即可快速生成应用程序

低代码平台,即无需编码或通过少量代码就可以快速生成应用程序的开发平台。 低代码平台面向的是IT或者平民程序员,解决传统软件开发模式带来的周期长、成本高等问题,客户群体主要为软件开发公司或者拥有IT的中大型企业。而零代码(N…

API接口测试工具的几个特色

API接口测试工具在软件开发过程中起着举足轻重的作用。它们帮助测试人员快速发现和解决API接口的问题,并确保系统的稳定性和性能。本文将介绍API接口测试工具的几个特色,以及为什么它们对测试人员来说非常重要。 首先,API接口测试工具的一个特…

【uniapp】学习之【生命周期】

uniapp生命周期 uni-app框架的生命周期分为两种 : 应用中的生命周期 和 页面内的生命周期 uni-app 应用生命周期 uni-app 页面生命周期

微信公众号本地开发调试 - 无公网IP —— 内网穿透

文章目录 前言1. 配置本地服务器2. 内网穿透2.1 下载安装cpolar内网穿透2.2 创建隧道 3. 测试公网访问4. 固定域名4.1 保留一个二级子域名4.2 配置二级子域名 5. 使用固定二级子域名进行微信开发 前言 在微信公众号开发中,微信要求开发者需要拥有自己的服务器资源来…

软考:中级软件设计师:进程死锁,死锁的预防和避免,银行算法家,

软考:中级软件设计师:进程死锁 提示:系列被面试官问的问题,我自己当时不会,所以下来自己复盘一下,认真学习和总结,以应对未来更多的可能性 关于互联网大厂的笔试面试,都是需要细心准备的 &…

STM32:使用RS485和多摩川编码器通信

本文主要讲使用STM32F767和绝对式多摩川TS5700N8501编码器通信的流程和注意事项。 首先使用STM32CubeMX生成RS485驱动部分功能代码,注意该款编码器的波特率是2.5Mbps。 注意使能的GPIO可以使用其他管脚,我们的主控板使用的是PA8。前期可以这么配置。 配…

zabbix的安装

前提 作为一个运维,需要会使用监控系统查看服务器系统性能、应用服务状态和网站流量指标等,利用监控系统的数据去了解网站上线发布的结果和健康状态。 利用一个优秀的监控软件,我们可以: ●通过一个友好的界面进行浏览整个网站所有的服务器…

Linux—实操篇:用户管理

1、基本介绍 Linux系统是一个多用户多任务的操作系统,任何一个要使用系统资源的用户,都必须首先向系统管理员申请一个 账号,然后以这个账号的身份进入系统。 2、添加用户 基本语法: useradd 用户名 细节说明: 1、…

【Kafka】Kafka基础操作笔记

【Kafka】Kafka基础操作笔记 文章目录 【Kafka】Kafka基础操作笔记1. 两种模式1.1 点对点模式1.2 发布/订阅模式 2. 基础架构3. Topic命令行操作3.1 查看 Topic 操作3.2 创建 Topic3.3 查看所有 Topic3.4 查看 Topic 的详情3.5 修改分区数3.6 删除 Topic 1. 两种模式 Kafka作为…