Pytorch第一部分数据模块

news2024/10/1 16:32:30

数据划分:

从数据集中将数据划分为训练集,测试集,验证集

# -*- coding: utf-8 -*-
"""
# @file name  : 1_split_dataset.py
# @author     : tingsongyu
# @date       : 2019-09-07 10:08:00
# @brief      : 将数据集划分为训练集,验证集,测试集
"""

import os
import random
import shutil


def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)


if __name__ == '__main__':

    random.seed(1)

    dataset_dir = "F:\\depthlearning data\\RMB_data"
    split_dir = "F:\\depthlearning data\\rmb_split"
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")
    test_dir = os.path.join(split_dir, "test")

    train_pct = 0.8
    valid_pct = 0.1
    test_pct = 0.1

    for root, dirs, files in os.walk(dataset_dir):
        for sub_dir in dirs:

            imgs = os.listdir(os.path.join(root, sub_dir))
            imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
            random.shuffle(imgs)
            img_count = len(imgs)

            train_point = int(img_count * train_pct)
            valid_point = int(img_count * (train_pct + valid_pct))

            for i in range(img_count):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sub_dir)
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sub_dir)
                else:
                    out_dir = os.path.join(test_dir, sub_dir)

                makedir(out_dir)

                target_path = os.path.join(out_dir, imgs[i])
                src_path = os.path.join(dataset_dir, sub_dir, imgs[i])

                shutil.copy(src_path, target_path)

            print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point-train_point,
                                                                 img_count-valid_point))

整体代码:人民币二分类训练,这里只关注数据部分

# -*- coding: utf-8 -*-
"""
# @file name  : train_lenet.py
# @author     : tingsongyu
# @date       : 2019-09-07 10:08:00
# @brief      : 人民币分类模型训练
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed()  # 设置随机种子
rmb_label = {"1": 0, "100": 1}

# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 数据 ============================

split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 模型 ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()                                                   # 选择损失函数

# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # 设置学习率下降策略

# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # 打印训练信息
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # 更新学习率

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            loss_val_epoch = loss_val / len(valid_loader)
            valid_curve.append(loss_val_epoch)
            # valid_curve.append(loss.item())    # 20191022改,记录整个epoch样本的loss,注意要取平均
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_epoch, correct_val / total_val))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

# ============================ inference ============================

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")

test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)

for i, data in enumerate(valid_loader):
    # forward
    inputs, labels = data
    outputs = net(inputs)
    _, predicted = torch.max(outputs.data, 1)

    rmb = 1 if predicted.numpy()[0] == 0 else 100
    print("模型获得{}元".format(rmb))






数据部分:

# ============================ step 1/5 数据 ============================

#读取数据路径
split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")


norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

#训练集 数据预处理  缩放 裁剪 转换为张量 
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
#验证集 少了裁剪的方法
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 构建MyDataset实例  传入数据路径 数据预处理
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

dataloader分为sampler(索引)和dataset(标签)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1613644.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

20.Unity飞机大战游戏

1任务&#xff1a;使背景图动起来 2任务&#xff1a;飞机换帧动画 3任务&#xff1a;让飞机发射子弹 4任务&#xff1a;敌机出现 5任务&#xff1a;控制飞机 6任务&#xff1a;游戏碰撞逻辑 7任务&#xff1a;另外两种类型的敌机 8任务&#xff1a;拾取奖励物品换枪 9…

230基于matlab的布谷鸟(COA)多目标优化算法

基于matlab的布谷鸟&#xff08;COA&#xff09;多目标优化算法&#xff0c;以 满意度、成本、时间、质量为目标的多目标优化求解代码。程序已调通&#xff0c;可直接运行。 230 matlab 布谷鸟&#xff08;COA&#xff09;多目标优化 - 小红书 (xiaohongshu.com)

操作符不存在:sde.st_geometry ^ !sde.st_geometry建议 SQL函 数st_intersects在内联inlining期间

操作符不存在&#xff1a;sde.st_geometry ^ &#xff01;sde.st_geometry建议 SQL函 数st_intersects在内联inlining期间 问题&#xff1a;最近在使用SQL图形处理函数处理图形时&#xff0c;莫名奇妙报如下错误&#xff0c;甚是费解 于是开始四处"寻医问药" 1、nav…

叶子相似的树

题目链接 叶子相似的树 题目描述 注意点 给定的两棵树结点数在 [1, 200] 范围内给定的两棵树上的值在 [0, 200] 范围内 解答思路 深度优先遍历按顺序找到两棵树各自的叶子节点并存储到两个list中&#xff0c;随后比较两个list是否相同即可 代码 /*** Definition for a b…

mysql基础20——数据备份

数据备份 数据备份有2种 一种是物理备份 一种是逻辑备份 物理备份 物理备份 通过把数据文件复制出来 达到备份的目的 用得比较少 逻辑备份 逻辑备份 把描述数据库结构和内容的信息保存起来 达到备份的目的 是免费的 数据备份工具 mysqldump &#xff08;3种模式&#x…

力扣283. 移动零

Problem: 283. 移动零 文章目录 题目描述思路复杂度Code 题目描述 思路 1.定义一个int类型变量index初始化为0&#xff1b; 2.遍历nums当当前的元素nums[i]不为0时使nums[i]赋值给nums[index]&#xff1b; 3.从index开始将nums中置对应位置的元素设为0&#xff1b; 复杂度 时间…

【数据结构-树和二叉树-森林-哈夫曼树】

目录 1 树1.1 树的描述&#xff08;基本术语&#xff09; 2 二叉树&#xff08;树的度最大为2&#xff09;2.1 注意事项-五种基本形态2.2 二叉树的抽象数据类型定义 3 二叉树的性质3.1 两种特殊形式的二叉树-重点会计算3.2 题目练习&#xff1a; 4 二叉树的存储结构4.1 顺序存储…

卷积神经网络(CNN)基础

目录 卷积神经网络介绍 卷积神经网络原理 卷积层&#xff1a;通过在原始图片上平移来提取特征 激活层&#xff1a;增加非线性分割能力 池化层polling&#xff08;下采样层&#xff09;&#xff1a;减少学习参数&#xff0c;去掉不重要的样本&#xff0c;降低网络的复杂度 卷…

面试(06)————MySQL篇

目录 问题一&#xff1a;在MySQL中&#xff0c;如何定位慢查询&#xff1f; 方案一&#xff1a;开源工具 方案二&#xff1a;MySQL自带慢日志 模拟面试 问题二&#xff1a;这个SQL语句执行很慢&#xff0c;如何分析的呐&#xff1f; 模拟面试 问题三&#xff1a;了解过索引…

【GlobalMapper精品教程】076:基于高程和影像数据创建电子沙盘(真实三维地形)

影像与数字高程模型叠加,可以构建三维真是地形。本文讲解在Globalmapper中基于高程和影像数据创建电子沙盘(真实三维地形)。 文章目录 一、加载数据二、创建三维网格三、三维叠加显示一、加载数据 本实验的数据(配套实验数据资料包中的data076.rar,订阅专栏,获取全文及数…

Java Web 网页设计(1)

不要让追求之舟停泊在幻想的港湾 而应扬起奋斗的风帆 驶向现实生活的大海 网页设计 1.首先 添加框架支持 找到目录右键添加 找到Web Application选中 点击OK 然后 编辑设置 找到Tomcat--local 选中 点击OK 名称可以自己设置 找到对应文件夹路径 把Tomcat添加到项目里面 因为…

分享几个申请免费SSL证书的平台

随着数字网络蓬勃发展&#xff0c;人们在享受互联网时代带来的便利生活外&#xff0c;网络安全问题也是日益变得严重&#xff1b;越来越多企业或个人选择通过安装SSL证书来保护网站的数据安全和提高企业的品牌形象&#xff0c;好在很多证书服务机构都有提供免费SSL证书申请的服…

【机器学习-15】决策树(Decision Tree,DT)算法介绍:原理与案例实现

前言 决策树算法是机器学习领域中的一种重要分类方法&#xff0c;它通过树状结构来进行决策分析。决策树凭借其直观易懂、易于解释的特点&#xff0c;在分类问题中得到了广泛的应用。本文将介绍决策树的基本原理&#xff0c;包括熵和信息熵的相关概念&#xff0c;以及几种经典的…

Modern CSV for Mac:强大的CSV文件编辑器

Modern CSV for Mac是一款功能强大的CSV文件编辑器&#xff0c;专为Mac用户设计&#xff0c;提供直观易用的界面和丰富的功能&#xff0c;使用户能够轻松编辑和管理CSV文件。 Modern CSV for Mac v2.0.6激活版下载 这款软件支持快速导入和导出CSV文件&#xff0c;方便用户与其他…

Docker - Compose

原文地址&#xff0c;使用效果更佳&#xff01; Docker - Compose | CoderMast编程桅杆Docker - Compose 在部署应用时&#xff0c;常常使用到不止一个容器&#xff0c;那么在部署容器的时候就需要一个一个进行部署&#xff0c;这样的部署过程也相对来说比较繁琐复杂&#xff…

​「Python大数据」VOC数据清洗

前言 本文主要介绍通过python实现数据清洗、脚本开发、办公自动化。读取voc数据,存储新清洗后的voc数据数据。 一、业务逻辑 读取voc数据采集的数据批处理,使用jieba进行分词,去除停用词,清洗后的评论存储到新的列中保存清洗后的数据到新的Excel文件中二、具体产出 三、执…

实验 2--创建数据库和表

文章目录 实验 2--创建数据库和表实验目的3.3.2 实验准备3.3.3 实验内容2.在 SSMS 图形界面中创建和删除数据库和数据表。(1)在 SSMS 图形界面中创建 YGKQ 数据库;(2)在 SSMS 图形界面中删除 YGKQ 数据库;(3)在 SSMS 图形界面中创建、删除 BMXX表;(4)在 SSMS 图形界面中分别创建…

【探讨】RocketMQ消息灰度方案-消息逻辑隔离

vivo 鲁班平台 RocketMQ 消息灰度方案 - 稀土掘金分布式- vivo鲁班RocketMQ平台的消息灰度方案MQ消息在生产环境和灰度环境隔离一般怎么实现?消息隔离的原则 中心正常消费者,可以同时消费正常的消息和特定标签的消息(自动识别);特定标签的消费者,只能消费特定标签的消息。灰…

连连看游戏页面网站源码

首页&#xff0c;可以上传自己喜欢的图片 游戏页面 通关页面

Python程序设计 字典

教学案例十 字典 1. 判断出生地 sfz.txt文件中存储了地区编码和地区名称 身份证的前6位为地区编码&#xff0c;可以在sfz.txt文件中查询到地区编号对应的地区名称 编写程序&#xff0c;输入身份证号&#xff0c;查询并显示对应的地区名称 若该地区编码不在文件中&#xff0c;…