卷积神经网络实现彩色图像分类 - P2

news2024/9/29 13:23:47
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第P2周:彩色识别
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

  • 环境
  • 步骤
    • 环境设置
      • 包引用
      • 硬件设备
    • 数据准备
      • 数据集下载与加载
      • 数据集预览
      • 数据集准备
    • 模型设计
    • 模型训练
      • 超参数设置
      • helper函数
      • 正式训练
    • 结果呈现
  • 总结与心得体会

上周使用Pytorch构建卷积神经网络,实现了MNIST手写数字的识别,这周的目标是CIFAR10中复杂的彩色图像分类。


环境

  • 系统:Linux
  • 语言: Python 3.8.10
  • 深度学习框架:PyTorch 2.0.0+cu118

步骤

环境设置

包引用

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
import matplotlib.pyplot as plt
from torchinfo import summary # 方便像tensorflow一样打印模型

硬件设备

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

数据准备

数据集下载与加载

train_dataset = datasets.CIFAR10(root='data', train=True, 
					download=True, transform=transforms.ToTensor()) # 不要忘记这个transform
test_dataset = datasets.CIFAR10(root='data', train=False, 
					download=True, transform=transforms.ToTensor())

数据集预览

image, label = train_dataset[0]
print(image.shape)
plt.figure(figsize=(20,4))
for i in range(20):
	image, label = train_dataset[i]
	plt.subplot(2, 10, i+1)
	plt.imshow(image.numpy().transpose(1,2,0)
	plt.axis('off')
	plt.title(label) # 加载的数据集没有对应的名称,暂时展示它们的id

数据集预览

数据集准备

batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

模型设计

class Model(nn.Module):
	def __init__(self, num_classes):
		super().__init__()
		# 3x3的卷积无padding每次宽高-2
		# 2x2的最大池化,每次宽高缩短为原来的一半
		# 32x32 -> conv1 -> 30x30 -> maxpool -> 15x15
		self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
		# 15x15 -> conv2 -> 13x13 -> maxpool -> 6x6
		self.conv2 = nn.Conv2d(64, 64, kernel_size=3)
		# 6x6 -> conv3 -> 4x4 -> maxpool -> 2x2
		self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
		self.maxpool = nn.MaxPool2d(2),
		self.flatten = nn.Flatten(),
		self.fc1 = nn.Linear(2*2*128, 256)
		self.fc2 = nn.Linear(256, num_classes)

	def forward(self, x):
		x = F.relu(self.conv1(x))
		x = self.maxpool(x)

		x = F.relu(self.conv2(x))
		x = self.maxpool(x)

		x = F.relu(self.conv3(x))
		x = self.maxpool(x)

		x = self.flatten(x)

		x = F.relu(self.fc1(x))
		x = self.fc2(x)
		return x

model = Model(10).to(device)
summary(model, input_size=(1, 3, 32, 32))

模型结构图

模型训练

超参数设置

learning_rate = 1e-2
epochs = 10
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

helper函数

def train(train_loader, model, loss_fn, optimizer):
	size = len(train_loader.dataset)
	num_batches = len(train_loader)

	train_loss, train_acc = 0, 0
	for x, y in train_loader:
		x, y = x.to(device), y.to(device)

		preds = model(x)
		loss = loss_fn(preds, y)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		train_loss += loss.item()
		train_acc += (preds.argmax(1) == y).type(torch.float).sum().item()

	train_loss /= num_batches
	train_acc /= size

	return train_loss, train_acc

def test(test_loader, model, loss_fn):
	size = len(test_loader.dataset)
	num_batches = len(test_loader)

	test_loss, test_acc = 0, 0
	with torch.no_grad():
		for x, y in test_loader:
			x, y = x.to(device), y.to(device)
			
			preds = model(x)
			loss = loss_fn(preds, y)
			
			test_loss += loss.item()
			test_acc += (preds.argmax(1) == y).type(torch.float).sum().item()

	test_loss /= num_batches
	test_acc /= size

	return test_loss, test_acc

def fit(train_loader, test_loader, model, loss_fn, optimizer, epochs):
	train_loss, train_acc = [], []
	test_loss, test_acc = [], []
	for epoch in range(epochs):
		model.train()
		epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)
		model.eval()
		epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)

		train_loss.append(epoch_train_loss)
		train_acc.append(epoch_train_acc)
		test_loss.append(epoch_test_loss)
		test_acc.append(epoch_test_acc)
	return train_loss, train_acc, test_loss, test_acc

正式训练

train_loss, train_acc, test_loss, test_acc = 
				fit(train_loader, test_loader, model, loss_fn, optimizer, 20)

训练结果

结果呈现

series = range(len(train_loss))
plt.figure(figsize=(12,4))
plt.subplot(1,2,1)
plt.plot(series, train_loss, label='train loss')
plt.plot(series, test_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')
plt.subplot(1,2,2)
plt.plot(series, train_acc, label='train accuracy')
plt.plot(series, test_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

实验结果
从结果图可以发现,模型应该还没收敛,将epoch设置为30,重新跑一遍模型。
实验结果2
可以看出20个epoch后,训练集上的正确率持续增长,在验证集上的正确率几乎就不再增长了,符合过拟合的特征。需要对模型进行改进才能提升正确率了。


总结与心得体会

通过本周的学习,掌握了使用pytorch编写一个完整深度学习的过程,包括环境的配置、数据的准备、模型定义与训练、结果分析呈现等步骤,并且掌握了通过pytorch的API组建一个简单的卷积神经网络的过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/852956.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络嗅探,大神都在用这10个抓包工具

下午好,我的网工朋友。 前两天发了一篇网工能干的工作大科普,没看过的看这:《不得不说,网工能干的活也太多了吧》。 然后有小友就说,里面的有些工作岗位要求,自己不知道从哪去补充知识,希望可…

【Java并发】synchronized关键字的底层原理

文章目录 1.synchronized作用2.synchronized加锁原理3.monitor锁4.synchronized锁的优化4.1.自适应性自旋锁4.2.偏向锁4.3.轻量级锁4.3.重量级锁 5.总结 1.synchronized作用 synchronized是Java提供一种隐式锁,无需开发者手动加锁释放锁。保证多线程并发情况下数据…

dubbo之高可用

负载均衡 概述 负载均衡是指在集群中,将多个数据请求分散到不同的单元上执行,主要是为了提高系统的容错能力和对数据的处理能力。 Dubbo 负载均衡机制是决定一次服务调用使用哪个提供者的服务。 策略 在Dubbo中提供了7中负载均衡策略,默…

1.0 Python 标准输入与输出

python 是一种高级、面向对象、通用的编程语言,由Guido van Rossum发明,于1991年首次发布。python 的设计哲学强调代码的可读性和简洁性,同时也非常适合于大型项目的开发。python 语言被广泛用于Web开发、科学计算、人工智能、自动化测试、游…

C#实现三菱FX-3U SerialOverTcp

设备信息 测试结果 D值测试 Y值写入后读取测试 协议解析 三菱FX 3U系列PLC的通信协议 1. 每次给PLC发送指令后,必须等待PLC的应答完成才能发送下一条指令; 2. 报文都是十六进制ASCII码的形式 3. 相关指令 指令 命令码(ASCII码) 操作原件 …

七牛云获取qn(url、bucket、access-key、secret-key)

1.注册账号 2.access-key和secret-key: 点击“密钥管理” 复制AK和SK即可 域名: bucket: 这个就是对象存储空间名字 先新建一个空间(没买需要先购买),步骤如下: 填写存储空间名字&#xff0…

“OSError: [WinError 1455]页面文件太小,无法完成操作。”解决方案

按照下面步骤, 会发现电脑默认情况下是没有给D盘分配虚拟内存的, 所以将Python装在D盘的朋友, 在跑程序时, 没有分配虚拟内存, 自然就遇到了上面的问题, 所以根本操作只要给D盘分配虚拟内存即可 第一步:鼠标右击我的电脑 (此电脑),点击属性进入以下界面 …

PostMessage/SendMessage在不同线程的调用探究

PostMessage和SendMessage是我们比较常用的windows API,最近也探究这两个api在调用之后,执行的线程问题,发现如下结论: 仅仅是RegisterClass注册类之后,调用createwindow第一个参数通过:(const TCHAR*)base…

three.js 地球与卫星绕地飞行【真实三维数据】

&#xff08;真实经纬度运行轨迹&#xff09; 完整代码 <template><div class"home3dMap" id"home3dMap" v-loading"loading"></div> </template><script> import * as THREE from three import { OrbitControl…

【K8S 的二进制搭建】

目录 一、二进制搭建 Kubernetes v1.201、准备环境 二、操作系统初始化配置三、部署 etcd 集群1、准备签发证书环境2、在 master01 节点上操作1、生成Etcd证书 3、在 node01 节点上操作4、在 node02 节点上操作 四、部署 Master 组件五、部署 docker引擎六、部署 Worker Node 组…

【算法|数组】手撕经典二分法

算法|数组——二分查找 文章目录 算法|数组——二分查找引言二分查找左闭右闭写法左闭右开写法 总结 引言 首先学习这个算法之前需要了解数组知识&#xff1a;数组。 大概介绍以下&#xff1a; 数组是存储在连续内存空间上的相同类型数据的集合。数组下标都是从0开始。数组在…

【SpringCloud】Gateway服务网关

Spring Cloud Gateway 是 Spring Cloud 的一个全新项目&#xff0c;该项目是基于 Spring 5.0&#xff0c;Spring Boot 2.0 和 Project Reactor 等响应式编程和事件流技术开发的网关&#xff0c;它旨在为微服务架构提供一种简单有效的统一的 API 路由管理方式。 1.为什么需要网关…

“实现数字化转型:探索会议OA项目的高级技术与创新应用“

文章目录 引言&#xff1a;1.项目背景和需求分析&#xff1a;2.技术选型和架构设计&#xff1a;3.项目实现和功能亮点&#xff1a;3.0 layui实现登录及注册3.1 会议管理模块3.1.1 会议发布3.1.2 我的会议3.1.3 我的审批3.1.4 会议通知3.1.5 待开会议3.1.6 历史会议3.1.7 所有会…

【Leetcode】层次遍历||树深度||队列

step by step. 题目&#xff1a; 给定一个二叉树 root &#xff0c;返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 示例 1&#xff1a; 输入&#xff1a;root [3,9,20,null,null,15,7] 输出&#xff1a;3示例 2&#xff1a; 输入&am…

【LeetCode】数据结构题解(13)[设计循环链表]

设计循环链表 &#x1f609; 1.题目来源&#x1f440;2.题目描述&#x1f914;3.解题思路&#x1f973;4.代码展示 所属专栏&#xff1a;玩转数据结构题型❤️ &#x1f680; >博主首页&#xff1a;初阳785❤️ &#x1f680; >代码托管&#xff1a;chuyang785❤️ &…

WordPress使用【前端投稿】功能时为用户怎么添加插入文章标签

在使用Wordpress做前端投稿功能的时候&#xff0c;可能需要用户填写文章标签&#xff0c;在插入文章的时候很多人不知道怎么把这些标签插入进去&#xff0c;下面这篇文章来为大家带来WordPress使用前端投稿功能时插入文章标签方法。 在Wordpress里 wp_insert_post 此函数的作…

Vue3核心笔记

文章目录 1.Vue3简介2.Vue3带来了什么1.性能的提升2.源码的升级3.拥抱TypeScript4.新的特性 一、创建Vue3.0工程1.使用 vue-cli 创建2.使用 vite 创建 二、常用 Composition API1.拉开序幕的setup2.ref函数3.reactive函数4.Vue3.0中的响应式原理vue2.x的响应式Vue3.0的响应式 5…

冠达管理:夜盘是什么意思?

跟着现代社会的快节奏开展&#xff0c;股票商场的运营也因而发生了相应的改变。夜盘是指在正常的买卖时刻外特别敞开的买卖商场&#xff0c;为出资者供给更多的买卖时刻和机会。夜盘的概念在我国自2002年引入以来逐渐被越来越多的出资者所熟知。那么夜盘究竟是什么&#xff1f;…

Collection 中的简单命令

一、list 集合常用的方法 LIst&#xff1a;元素存取有序、可以保存重复元素、有下标。可以存储 null 值。 ArrayList&#xff1a;底层数组结构&#xff0c;有 index&#xff0c;查询快 LinkList&#xff1a; 底层链表结构&#xff0c;增删快 添加 //确保此集合包含指定的元素…

操作系统—进程管理

这里写目录标题 进程特点并行和并发进程的状态进程的控制结构PCB包含什么信息 进程上下文切换进程上下文切换发生的场景 进程的通信方式 线程什么是线程线程的优缺点进程线程的区别线程的上下文切换线程的实现用户线程和内核线程的对应关系 用户线程如何理解、优缺点内核线程如…