Pytorch读取自己的数据集

news2024/9/24 11:31:05

数据集

在这里插入图片描述

流程图

  1. 导包
  2. 设置tfs
  3. 创建datasets.ImageFolder
  4. 创建torch.utils.data.DataLoader()
import time
import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# 忽略烦人的红色提示
import warnings
warnings.filterwarnings("ignore")

# 获取计算硬件
# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

from torchvision import transforms

# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                     ])

# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

# 数据集文件夹路径
dataset_dir = "C:/Users/Administrator/Desktop/ResNet/data/xiaobo"

train_path = os.path.join(dataset_dir, 'train_img')
test_path = os.path.join(dataset_dir, 'test_img')
print('训练集路径', train_path)
print('测试集路径', test_path)

from torchvision import datasets
# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)
# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)

BATCH_SIZE = 32

# 训练集的数据加载器
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=4
                         )

# 测试集的数据加载器
test_loader = DataLoader(test_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=4
                        )


print('训练集图像数量', len(train_dataset))
print('类别个数', len(train_dataset.classes))
print('各类别名称', train_dataset.classes)
print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)

# 各类别名称
class_names = train_dataset.classes
n_class = len(class_names)
# 映射关系:类别 到 索引号
train_dataset.class_to_idx
# 映射关系:索引号 到 类别
idx_to_labels = {y:x for x,y in train_dataset.class_to_idx.items()}
print(idx_to_labels)
# 保存为本地的 npy 文件
np.save('idx_to_labels.npy', idx_to_labels)
np.save('labels_to_idx.npy', train_dataset.class_to_idx)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1677962.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UV:展uv

1.3dmax 选中物体 修改器列表选中“UV展开” 打开UV编辑器 断开圆圈 同理断开瓶底 展开侧面 剥离 拉直 排列 纹理盘格 用于查看排列位置 渲染UV模板 PS图片 将不要的部分填充为黑色 复制图片 删除多余 保存图片 添加材质球和位图 按M打开材质球编辑器 重置UV 将uv变为初始…

激光切割机价格多少钱一台?

随着科技的飞速发展,激光切割技术在制造业中的应用越来越广泛。它以高精度、高效率和高质量著称,是金属加工行业的理想选择。然而,对于初次接触或打算购买激光切割机的用户来说,最关心的问题之一就是价格。那么,激光切…

Google Veo发布:AI生成视频的重大突破

在Google I/O 2024大会上,Google推出了Veo,这是一款能够根据文本提示生成1080p视频的AI模型。这次发布标志着Google在生成式AI领域的又一重大突破。 Veo的强大功能 Veo不仅能够生成各种视觉和电影风格的视频片段,包括风景镜头和延时摄影&am…

RPKI资源公共密钥基础架构体系的搭建

Ubuntu系统下RPKI体系的搭建 Ubuntu安装Nginx 一、安装 apt-get update apt-get install nginx nginx -v #查看安装版本二、目录说明 /usr/sbin/nginx:主程序,启动文件 /etc/nginx:存放配置文件 /var/www/html:存放项目目录 …

python数据分析——数据可视化(图形绘制基础)

数据可视化(图形绘制基础) 前言一、图形绘制基础Matplotlib简介使用过程sin函数示例 二、常用图形绘制折线图的绘制plot示例 散点图的绘制plot示例 柱状图的绘制bar示例 箱型图绘制plot.box示例 饼状图的绘制pie示例 三、图形绘制的组合情况多个折线图的…

Google在我的网站显示不同的SEO元标题/描述

Rank Math使您可以比以往更轻松地为您的博客文章、页面和其他自定义帖子类型编写完美的SEO元标题和描述。但正如您可能已经注意到的那样,谷歌(以及其他搜索引擎)经常不简单地选择使用您设置的元描述,并且这种情况正变得越来越普遍…

rocketmq的基础概念

生产者 生产者生产的过程: producer会在接入nameserver时,获取所有topic和队列的信息,然后在每次发送时,根据负载均衡在topic中选择发送的队列。 生产者的消息是发送给具体的queue,而消费者消费是从具体的queue消费 …

Git系列:git add 被忽视的操作技巧

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

C语言——栈和队列

文章目录 一、栈1. 栈的概念2. 栈的基本功能3. 栈的实现 二、 队列1. 队列的概念2. 队列的基本功能3. 队列的实现 一、栈 1. 栈的概念 栈是一种特殊的线性表,限定仅在表尾进行插入和删除的线性表。这一端称之为栈顶,另一端称为栈底。 栈又称为后进先出…

python将两张图片对齐

目录 需要对齐的照片如下: 源码: 结果: 需要对齐的照片如下: 源码: import cv2 import numpy as np from matplotlib import pyplot as plt# 读取两张图片 imgA cv2.imread(./out/out/3.png) imgB cv2.imread(./…

工具:资源包提取

1.提取unity资源包的工具 一定要通过文件夹的方式选择unity文件否则导出来后的资源不完整

锚点组件--支持点击、滚动高亮锚点

实现一个锚点组件&#xff0c;页面滚动时高亮当前位置锚点、点击锚点时跳转到指定冒点位置&#xff0c;同时选中锚点也高亮 效果图 父组件 import ./index.less; import Anchor from ./Anchor; import Content from ./Content;export default function index() {return (<…

rocketmq的流程

生产过程 消费过程 存储 在RocketMQ中&#xff0c;一个Broker的所有Topic的消息都会被写入到同一个CommitLog文件中。 每个队列&#xff08;Queue&#xff09;都有对应的ConsumeQueue文件。 ConsumeQueue每个记录定长&#xff0c;20字节&#xff0c;消息在commitlog中的偏移量…

【软件的安装与基本设置】AD21软件的PCB规则设置

在绘制PCB之前&#xff0c;要进行规则的创建&#xff0c;因为在绘制PCB的过程中&#xff0c;难免会出现很多错误&#xff0c;所以需要先对绘制PCB创建规则&#xff0c;即所有的打孔&#xff0c;走线&#xff0c;铺铜都要基于电气性能规则去设计&#xff0c;等到后期&#xff0c…

[vue] nvm

nvm ls // 看安装的所有node.js的版本nvm list available // 查显示可以安装的所有node.js的版本可以在可选列表里。选择任意版本安装&#xff0c;比如安装16.15.0 执行&#xff1a; nvm install 16.15.0安装好了之后。可以执行&#xff1a; …

云服务器修改端口通常涉及几个步骤

云服务器修改端口通常涉及几个步骤 远程连接并登录到Linux云服务器&#xff1a; 使用SSH工具&#xff08;如PuTTY、SecureCRT等&#xff09;远程连接到云服务器。 输入云服务器的IP地址、用户名和密码&#xff08;或密钥&#xff09;进行登录。 修改SSH配置文件&#xff1a…

智能数据提取:在严格数据治理与安全标准下的实践路径

一、引言 随着信息技术的飞速发展&#xff0c;数据已成为企业最宝贵的资产之一。然而&#xff0c;数据量的爆炸式增长和数据格式的多样化&#xff0c;使得传统的数据提取方法变得效率低下且难以满足业务需求。智能数据提取技术应运而生&#xff0c;它通过应用人工智能和机器学…

Unity里的Time

Time and frame rate management Time类&#xff1a; Time script reference page. 一些常见的属性有&#xff1a; Time.time 返回从游戏开始经历的时间.Time.deltaTime 返回从上帧结束到现在经历的时间&#xff0c;和帧率成反比Time.timeScale 控制时间流逝的因子Time.fixe…

一个制剂生产人眼中的精益管理

精益管理&#xff08;Lean Management&#xff09;是一种通过减少浪费和提高价值创造的方法&#xff0c;广泛应用于各个行业中&#xff0c;包括药品制剂生产领域。 本文&#xff0c;以一个多年从事药品制剂生产的人的角度&#xff0c;从优点、功能以及与其他管理方法的比较等方…

交通灯-设计说明书

设计摘要&#xff1a; 本设计基于单片机技术&#xff0c;旨在实现智能化交通信号控制&#xff0c;并具备夜间模式、禁止通行模式、同行模式切换以及车流量监测功能。通过按键S1和S2实现夜间模式和禁止通行模式的切换&#xff0c;确保夜间交通安全和禁止通行的需要。按键S3和S4…