PyTorch搭建AlexNet训练集

news2024/11/18 7:39:19

本次项目是使用AlexNet实现5种花类的识别。

训练集搭建与LeNet大致代码差不多,但是也有许多新的内容和知识点。

1.导包,不必多说。

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time

2.指定设备

device函数用来指定在训练过程中所使用的设备:如果有可用的GPU,那么使用第一块GPU,如果没有就默认使用cpu。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

 3.数据预处理函数

单独定义出来,当key为“train”或为“val”时,返回数据集要使用的一系列预处理方法。

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),   # 把图片重新裁剪为224*224
                                 transforms.RandomHorizontalFlip(),  # 水平方向随机翻转
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)])}

4.获取数据集的路径

os.getcwd()方法获取当前文件所在的目录

os.path.join()方法将当前路径与上两级路径链接起来

image_path:获取到flower_data所在路径

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = data_root + "/data_set/flower_data"
# train set
train_dataset = datasets.ImageFolder(root=image_path + "/train", # 获取训练集的路径
                                     transform=data_transform["train"])  # 训练预处理
train_num = len(train_dataset)  # 打印训练集有多少张照片

5.加载数据集分类文件 

{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4} :数据集共分为五类
flower_list = train_dataset.class_to_idx 获取分类的名称所对应的索引值
cla_dict = dict((val, key) for key, val in flower_list.items()) 将字典中键与值的位置对换

?为什么要换位置

=>这样在预测后可以直接通过值给到我们最后的测试类别
json_str = json.dumps(cla_dict, indent=4) :将字典编码成json格式
with open('class_indices,json', 'w') as json_file:
        json_file.write(json_str)  :将键值对保存到json文件中,方便后续在预测时读取信息

下面是生成的json文件

# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# 把文件写入接送文件
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices,json', 'w') as json_file:
    json_file.write(json_str)

 6.载入测试集

代码大致与LeNet网络差不多,载入测试集的图片路径需要自己定义并进行预处理。

在使用matplotlib查看图片时,注意修改为batch_size=4,shuffle=True参数。

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)
#
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validata_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,
                                              shuffle=False, num_workers=0)

暂时的全部代码,训练集还没有完全实现,我后续会补充上的,因为课真的是太多了。

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = data_root + "/data_set/flower_data"
# train set
train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)


# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# 把文件写入接送文件
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices,json', 'w') as json_file:
    json_file.write(json_str)


batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)
#
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validata_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=4,
                                              shuffle=True, num_workers=0)

学习碎碎念:

学习的道路上总会是遇到困难和麻烦的,不要心急,不要烦躁,一步一步的解决问题,慢慢来总会好的!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1509390.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电子价签前景璀璨,汉朔科技革新零售行业的数字化新篇章

新型商超模式数字化“秘密武器”——电子价签 传统纸质价签,只要商品价格、日期等信息发生变化,就必须重新打印进行手动替换。电子价签的应用使传统的人工申请、调价、打印、营业员去货架前端更换等变价流程均可省略,所有门店的价格由后台统…

Linux命令深入学习——列出帮助手册,开机关机

linux中有多种方法查看一个不熟悉命令的详细信息,如 ls --help,help ls,man ls,info ls 在linux系统中可以使用命令进行开关机以及相关基础操作 同时在进行写入操作时,可以使用快捷键进行操作

图论(二)之最短路问题

最短路 Dijkstra求最短路 文章目录 最短路Dijkstra求最短路栗题思想题目代码代码如下bellman-ford算法分析只能用bellman-ford来解决的题型题目完整代码 spfa求最短路spfa 算法思路明确一下松弛的概念。spfa算法文字说明:spfa 图解: 题目完整代码总结ti…

【LeetCode每日一题】2129. 将标题首字母大写

文章目录 [2129. 将标题首字母大写](https://leetcode.cn/problems/capitalize-the-title/)思路:代码: 2129. 将标题首字母大写 思路: 1.先根据空格,将每个单词切割,依次遍历 2.用StringBuilder来对结构进行拼接 3.…

element plust的表格 el-table数据不按列展示

ElementPlus的表格demo代码放到原生的html <template><el-table :data"tableData" style"width: 100%"><el-table-column prop"date" label"Date" width"180" /><el-table-column prop"name"…

使用Python查询和下载Sentinel卫星数据

欢迎学习本教程,了解如何使用 Python 访问和下载 Sentinel 卫星数据。在深入探讨技术方面之前,让我们先了解一下哨兵卫星是什么以及它们为何如此重要。 哨兵家族。资料来源:欧空局。 Sentinel 卫星是欧洲航天局 (ESA) 开发的一组地球观测任务,是哥白尼计划的一部分,该计划…

Spark性能优化指南——高级篇

调优概述 有的时候&#xff0c;我们可能会遇到大数据计算中一个最棘手的问题——数据倾斜&#xff0c;此时Spark作业的性能会比期望差很多。数据倾斜调优&#xff0c;就是使用各种技术方案解决不同类型的数据倾斜问题&#xff0c;以保证Spark作业的性能。 数据倾斜发生时的现…

一文了解Cornerstone3D中窗宽窗位的3种设置场景及原理

&#x1f506; 引言 在使用Cornerstone3D渲染影像时&#xff0c;有一个常用功能“设置窗宽窗位&#xff08;windowWidth&windowLevel&#xff09;”&#xff0c;通过精确调整窗宽窗位&#xff0c;医生能够更清晰地区分各种组织&#xff0c;如区别软组织、骨骼、脑组织等。…

SSM整合项目(校验)

文章目录 1.前端校验1.需求分析2.HomeView.vue的数据池中添加校验规则3.HomeView.vue 绑定校验规则![image-20240311213428771](https://img-blog.csdnimg.cn/img_convert/7770bfa16814a0efd4eb818c9869a5bd.png)4.验证是否生效5.如果验证不通过&#xff0c;阻止用户提交表单1.…

机器学习之分类回归模型(决策数、随机森林)

回归分析 回归分析属于监督学习方法的一种&#xff0c;主要用于预测连续型目标变量&#xff0c;可以预测、计算趋势以及确定变量之间的关系等。 Regession Evaluation Metrics 以下是一些最流行的回归评估指标: 平均绝对误差(MAE):目标变量的预测值与实际值之间的平均绝对差…

webpack5零基础入门-4使用webpack处理less文件

1.安装less npm install less -D 2.创建less文件 .box{width: 100px;height: 100px;background: red; } 3.引入less文件并打包 执行npx webpack 报错无法识别less文件 4.安装less-loader并配置 npm install less-loader9 -D 这里指定一下版本不然会因为node版本过低报错 …

Java 启动参数 -- 和 -D写法的区别

当我们配置启动1个java 项目通常需要带一些参数 例如 -Denv uat , --spring.profiles.activedev 这些 那么用-D 和 – 的写法区别是什么&#xff1f; 双横线写法 其中这种写法基本上是spring 和 spring 框架独有 最常用的无非是就是上面提到的 --spring.profiles.activede…

【golang】28、用 httptest 做 web server 的 controller 的单测

文章目录 一、构建 HTTP server1.1 model.go1.2 server.go1.3 curl 验证 server 功能1.3.1 新建1.3.2 查询1.3.3 更新1.3.4 删除 二、httptest 测试2.1 完整示例2.2 实现逻辑2.3 其他示例2.4 用 TestMain 避免重复的测试代码2.5 gin 框架的 httptest 一、构建 HTTP server 1.1…

如何配置固定TCP公网地址实现远程访问内网MongoDB数据库

文章目录 前言1. 安装数据库2. 内网穿透2.1 安装cpolar内网穿透2.2 创建隧道映射2.3 测试随机公网地址远程连接 3. 配置固定TCP端口地址3.1 保留一个固定的公网TCP端口地址3.2 配置固定公网TCP端口地址3.3 测试固定地址公网远程访问 前言 MongoDB是一个基于分布式文件存储的数…

JDK环境变量配置-jre\bin、rt.jar、dt.jar、tools.jar

我们主要看下rt.jar、dt.jar、tools.jar的作用&#xff0c;rt.jar在​%JAVA_HOME%\jre\lib&#xff0c;dt.jar和tools.jar在%JAVA_HOME%\lib下。 rt.jar&#xff1a;Java基础类库&#xff0c;也就是Java doc里面看到的所有的类的class文件。 tools.jar&#xff1a;是系统用来编…

星星魔方

星星魔方 1&#xff0c;魔方三要素 &#xff08;1&#xff09;组成部件 6个中心块和8个角块和三阶魔方同构&#xff0c;另外每个面还有构成五角星的十个块。 &#xff08;2&#xff09;可执行操作 一共12种操作&#xff0c;其中6种是每个层顺时针旋转90度&#xff0c;另外6…

Gateway(路由映射)

1.SpringCloud Gateway Spring Cloud Gateway组件的核心是一系列的过滤器&#xff0c;通过这些过滤器可以将客户端发送的请求转发(路由)到对应的微服务。 Spring Cloud Gateway是加在整个微服务最前沿的防火墙和代理器&#xff0c;隐藏微服务结点IP端口信息&#xff0c;从而加…

用Vision Pro来控制机器人

【技术框架概述】 - visionOS App + Python Library用于从Vision Pro将头部/手腕/手指跟踪数据流式传输到任何机器人。 【定位】 - 该框架旨在利用Vision Pro控制机器人,并记录用户在环境中导航和操作的方式,以训练机器人。 【核心功能】 1. 提供visionOS应用程序和Py…

TEASEL: A transformer-based speech-prefixed language model

文章目录 TEASEL&#xff1a;一种基于Transformer的语音前缀语言模型文章信息研究目的研究内容研究方法1.总体框图2.BERT-style Language Models&#xff08;基准模型&#xff09;3.Speech Module3.1Speech Temporal Encoder3.2Lightweight Attentive Aggregation (LAA) 4.训练…

大语言模型系列-中文开源大模型

文章目录 前言一、主流开源大模型二、中文开源大模型排行榜 前言 近期&#xff0c;OpenAI 的主要竞争者 Anthropic 推出了他们的新一代大型语言模型 Claude 3&#xff0c;该系列涵盖了三个不同规模的模型&#xff1a;Opus、Sonnet 和 Haiku。 Claude 3声称已经全面超越GPT-4。…