PyTorch代码实战入门

news2024/9/23 13:21:16

人这辈子千万不要马虎两件事

一是找对爱人、二是选对事业

因为太阳升起时要投身事业

太阳落山时要与爱人相拥

一、准备数据集

蚂蚁蜜蜂数据集

蚂蚁蜜蜂的图片,文件名就是数据的label

二、使用Dataset加载数据

打开pycharm,选择Anaconda创建的pytorch环境

将数据集放在项目的根目录下,并修改文件名

 新建 read_data.py文件,编写如下代码

from torch.utils.data import Dataset
from PIL import Image
import os


class MyData(Dataset):

    # 获取训练数据的list
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)

    # 获取每一张图片及其label
    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)


root_dir = "dataset/train"
ant_label_dir = "ants"
bee_label_dir = "bees"
ants_dataset = MyData(root_dir, ant_label_dir)
bees_dataset = MyData(root_dir, bee_label_dir)

train_dataset = ants_dataset + bees_dataset

解释:

1. MyData继承Dataset类

2. 重写里面的__init__方法,在MyData类初始化时,通过拼接数据路径,os.listdir()加载出数据的list列表

3. 重写里面的__getitem__方法,将一张一张地将list的图片和对应的label加载出来

4.train_dataset = ants_dataset + bees_dataset,得到训练数据集

三、TensorBoard的使用

安装TensorBoard需要的包

pip install tensorboard

编写如下代码:

from torch.utils.tensorboard import SummaryWriter

# 指定log文件生成的位置
writer = SummaryWriter("logs")

for i in range(100):
    '''
    第一个参数:图像的title
    第二个参数:纵坐标的值 
    第三个参数:横坐标的值
    '''
    writer.add_scalar("y=3x", 3 * i, i)

# 关闭资源
writer.close()

运行代码,会在SummaryWriter指定的位置生成log文件

在Terminal运行下面语句:

可以自己指定端口,防止冲突

tensorboard --logdir=logs --port=6007

运行输出

在浏览器打开

使用 writer.add_image 加载图片

编写下面代码:

from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np

# 指定log文件生成的位置
writer = SummaryWriter("logs")

image_path = "dataset/train/ants/7759525_1363d24e88.jpg"
image_PIL = Image.open(image_path)
image_array = np.array(image_PIL)

'''
第一个参数:图像的title
第二个参数:图片的numpy值
第三个参数:步数
第四个参数:将图片进行转换,3通道放在前面
'''
writer.add_image("test", image_array, 1, dataformats='HWC')

for i in range(100):
    '''
    第一个参数:图像的title
    第二个参数:纵坐标的值 
    第三个参数:横坐标的值
    '''
    writer.add_scalar("y=3x", 3 * i, i)

# 关闭资源
writer.close()

 同样在控制台打开运行生成的日志文件

四、Transforms的使用

图片转换工具

ToTensor() 把 PIL格式或者numpy格式转换成tensor

编写如下代码:

from PIL import Image
from torchvision import transforms

img_path = "dataset/train/ants/0013035.jpg"
img = Image.open(img_path)

# 使用transforms
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)

print(tensor_img)

输出:

tensor([[[0.3137, 0.3137, 0.3137,  ..., 0.3176, 0.3098, 0.2980],
         [0.3176, 0.3176, 0.3176,  ..., 0.3176, 0.3098, 0.2980],
         [0.3216, 0.3216, 0.3216,  ..., 0.3137, 0.3098, 0.3020],
         ...,
         [0.3412, 0.3412, 0.3373,  ..., 0.1725, 0.3725, 0.3529],
         [0.3412, 0.3412, 0.3373,  ..., 0.3294, 0.3529, 0.3294],
         [0.3412, 0.3412, 0.3373,  ..., 0.3098, 0.3059, 0.3294]],

        [[0.5922, 0.5922, 0.5922,  ..., 0.5961, 0.5882, 0.5765],
         [0.5961, 0.5961, 0.5961,  ..., 0.5961, 0.5882, 0.5765],
         [0.6000, 0.6000, 0.6000,  ..., 0.5922, 0.5882, 0.5804],
         ...,
         [0.6275, 0.6275, 0.6235,  ..., 0.3608, 0.6196, 0.6157],
         [0.6275, 0.6275, 0.6235,  ..., 0.5765, 0.6275, 0.5961],
         [0.6275, 0.6275, 0.6235,  ..., 0.6275, 0.6235, 0.6314]],

        [[0.9137, 0.9137, 0.9137,  ..., 0.9176, 0.9098, 0.8980],
         [0.9176, 0.9176, 0.9176,  ..., 0.9176, 0.9098, 0.8980],
         [0.9216, 0.9216, 0.9216,  ..., 0.9137, 0.9098, 0.9020],
         ...,
         [0.9294, 0.9294, 0.9255,  ..., 0.5529, 0.9216, 0.8941],
         [0.9294, 0.9294, 0.9255,  ..., 0.8863, 1.0000, 0.9137],
         [0.9294, 0.9294, 0.9255,  ..., 0.9490, 0.9804, 0.9137]]])

常见的Transforms

在机器学习和深度学习中,数据转换(Transforms)是一种常见的操作,用于对数据进行预处理、增强或标准化。下面是一些常见的数据转换操作:

1. 数据标准化(Normalization):将数据按比例缩放,使其落在特定的范围内,通常是将数据映射到0到1之间或者使用均值为0、方差为1的分布(更快收敛)。这可以通过以下方法实现:

- Min - Max标准化:将数据缩放到指定的最小值和最大值之间。

- Z-Score标准化:将数据转化为均值为0、标准差为1的分布。

2. 数据增强(Data Augmentation):用于扩充训练数据集,增强样本的多样性,提高模型的泛华能力。常见的数据增强操作包括:

- 随机裁剪(Random Crop):随机裁剪图像的一部分,以减少位置的依赖性。

- 随机翻转(Random Filp):随机水平或垂直翻转图像,增加数据的多样性。

- 随机旋转(Random Rotation):随机旋转图像的角度,增加数据的多样性。

3. 图像预处理:用于对图像进行预处理,以减少噪声、增强特征或改变图像的外观。一些常见的图像预处理操作包括:

- 图像平滑(Image Smoothing):使用滤波器对图像进行平滑处理,减少噪声。

- 直方图均衡化(Histogram Equalization):增强图像的对比度,使得图像中的像素值更加均匀分布。

- 图像缩放(Image Resizing):改变图像的尺寸,通常用于将图像调整为模型输入的大小。

4. 文本预处理:用于对文本数据进行预处理和清洗,以便更好地适应模型的输入要求。一些常见的文本预处理操作包括:

- 分词(Tokenization):将文本分割成单个的词或字符。

- 去除停用词(Stopword Removal):去除常见的无意义词语,如“a”,“the”等。

- 文本向量化(Text Vectorzation):将文本转换为数值形式,如使用词袋模型或词嵌入。

以上只是一些常见的数据转换操作示例,实际应用中可能会根据任务和数据的特点进行适当的调整和组合。在使用转换操作时,可以使用各种机器学习框架(如PyTorch、TensorFlow)提供的相关库或函数来实现这些操作。

示例代码如下:

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

img_path = "hymenoptera_data/16.jpg"
img = Image.open(img_path)


writer = SummaryWriter("logs")


# totensor使用
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)
writer.add_image("lyy", tensor_img)

# Normalize 使用
trans_norm = transforms.Normalize([111,111,111],[10,10,10])
img_norm = trans_norm(tensor_img)
writer.add_image("normalize", img_norm, 2)

# resize
trans_resize = transforms.Resize( (512, 512))
img_resize = trans_resize(img) # 输入的是Image类型的图像
img_resize_tensor = transforms.ToTensor()

tensor_img = tensor_trans(img_resize)
writer.add_image("resize", tensor_img, 0)

# compose 用法
trans_resize_2 =transforms.Resize(123)
trans_compose = transforms.Compose([trans_resize_2, tensor_trans])
img_resize_2 = trans_compose(img)
writer.add_image("resize2", img_resize_2,1)

# randomCrop
trans_random = transforms.RandomCrop((20,50))
trans_compose_2 = transforms.Compose([trans_random, tensor_trans])
for i in range(10):
    img_crop = trans_compose_2(img)
    writer.add_image("randomCrop", img_crop, i)


writer.close()

print("end")

五、torchvision中数据集使用

pytorch提供了很多的数据集,提供给我们学习使用。

进入官网

 选择Dataset数据集

下面就是常用的数据集

 CIFAR10数据集使用示例:

 代码示例:

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transfrom = torchvision.transforms.Compose([torchvision.transforms.ToTensor])

train_set = torchvision.datasets.CIFAR10(root="./cifar10", train=True, transform=dataset_transfrom, download=True)
test_set = torchvision.datasets.CIFAR10(root="./cifar10", train=False, transform=dataset_transfrom, download=True)

writer = SummaryWriter("test_torchvision")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_torchvision", img, i)

writer.close()

 提示:这样下载数据集很慢,可以使用迅雷下载, ctrl + CIFAR10  进入类里面,里面有下载地址,如下:

 把下载好的数据集压缩包,放在指定路径下即可。

 然后再tensorboard面板就能看到下载的数据集

六、DataLoader使用

Dataset是整理好的数据集

DataLoader是把这个数据集加载到神经网络中去训练

使用示例:

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10("./cifar10", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

writer = SummaryWriter("dataloader")

step = 0
for data in test_loader:
    imgs, targets = data
    writer.add_images("dataloader", imgs, step)
    step = step + 1

writer.close()

解释:

通过创建 test_loader,您可以使用 for 循环迭代它来逐批获取测试数据

  • dataset: 指定要加载的数据集test_data
  • batch_size: 指定每个批次中的样本数量。在这里,每个批次中有64个样本。
  • shuffle: 指定是否对数据进行洗牌(随机重排)。如果设置为 True,每个 epoch(训练周期)开始时,数据将被随机打乱顺序。这对于增加数据的随机性、减少模型对输入顺序的依赖性很有用。
  • num_workers: 指定用于数据加载的子进程数量。在这里,设置为0表示在主进程中加载数据,没有额外的子进程参与。如果设置为大于0的值,将使用多个子进程并行加载数据,可以加快数据加载速度。
  • drop_last: 指定当数据样本数量不能被 batch_size 整除时,是否丢弃最后一个不完整的批次。如果设置为 True,最后一个不完整的批次将被丢弃;如果设置为 False,最后一个不完整的批次将保留。

运行之后在tensorboard面板查看

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/823497.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FTP Server

简介 FTP:File Transfer Protocol 文件传输协议;它工作在 OSI 模型的第七层, TCP 模型的第四层, 即应用层, 使用 TCP 传输而不是 UDP, 客户在和服务器建立连接前要经过一个“三次握手”的过程,…

捷码低代码|FreeContainer 自由布局组件详解

背景知识: 1、布局组件: 布局组件是一种用于在用户界面中安排和组织其他组件的组件。它们提供了一种简单的方法来控制和管理页面上组件的位置、大小和层次结构。布局组件可以是容器,可以包含其他组件,并确定它们在界面上的显示方式…

【MyBatis】 框架原理

目录 10.3【MyBatis】 框架原理 10.3.1 【MyBatis】 整体架构 10.3.2 【MyBatis】 运行原理 10.4 【MyBatis】 核心组件的生命周期 10.4.1 SqlSessionFactoryBuilder 10.4.2 SqlSessionFactory 10.4.3 SqlSession 10.4.4 Mapper Instances 与 Hibernate 框架相比&#…

深入理解MVVM架构模式

MVVM原理 MVVM是一种用于构建用户界面的软件架构模式,它的名称代表着三个组成部分:Model(模型)、View(视图)和ViewModel(视图模型)。MVVM的主要目标是将应用程序的UI与其底层数据模…

SERDES关键技术

目录 一、SERDES介绍 二、SERDES关键技术 2.1 多重相位技术 2.2 线路编解码技术 2.2.1 8B/10B编解码 2.2.2 控制字符(Control Characters) 2.2.3 Comma检测 2.2.4 扰码(Scrambling) 2.2.5 4B/5B与64B/66B编解码技术 2.3 包传…

Halcon学习之一维测量实战之测量矩形(一)

一、采集图像 (1)测量充电器 测量充电器的引脚,然后每次旋转充电器,让测量矩形都跟着它转,这就是定位+测量, (2)测量钥匙 (3)测量瓶盖 我们后面还会涉及到拟合的问

牛客网Verilog刷题——VL53

牛客网Verilog刷题——VL53 题目答案 题目 设计一个单端口RAM,它有: 写接口,读接口,地址接口,时钟接口和复位;存储宽度是4位,深度128。注意rst为低电平复位。模块的接口示意图如下。 输入输出描…

HDFS的QJM方案

Quorum Journal Manager仲裁日志管理器 介绍主备切换,脑裂问题解决---ZKFailoverController(zkfc)主备切换,脑裂问题解决-- Fencing(隔离)机制主备数据状态同步问题解决 HA集群搭建集群基础环境准备HA集群规…

解决git仓库无效问题

解决fatal: … not valid: is this a git repository?问题 凭证编辑修改成自己的账号密码即可解决

2023年第四届“华数杯”数学建模思路 - 复盘:校园消费行为分析

文章目录 0 赛题思路1 赛题背景2 分析目标3 数据说明4 数据预处理5 数据分析5.1 食堂就餐行为分析5.2 学生消费行为分析 0 赛题思路 (赛题出来以后第一时间在CSDN分享) https://blog.csdn.net/dc_sinor?typeblog 1 赛题背景 校园一卡通是集身份认证…

LEARNING TO EXPLORE USING ACTIVE NEURAL SLAM 论文阅读

论文信息 题目:LEARNING TO EXPLORE USING ACTIVE NEURAL SLAM 作者:Devendra Singh Chaplot, Dhiraj Gandhi 项目地址:https://devendrachaplot.github.io/projects/Neural-SLAM 代码地址:https://github.com/devendrachaplot/N…

python 统计所有的 仓库 提交者的提交次数

字典去重 YYDS 然后再写入excel 表 yyds #!/bin/env python3 from git.repo import Repo import os import pandas as pdspath "/home/labstation/workqueue/sw" url "git10.0.128.128" date [str(x) for x in range(202307, 202308)] datefmt "%…

用html+javascript打造公文一键排版系统11:改进单一附件说明排版

一、用htmljavascript打造公文一键排版系统10中的一个bug 在 用htmljavascript打造公文一键排版系统10:单一附件说明排版 中,我们对附件说明的排版函数是: function setAtttDescFmt(p) {var t p;var a ;if (-1 ! t.indexOf(:))//是半角冒…

SQL注入之sqlmap

SQL注入之sqlmap 6.1 SQL注入之sqlmap安装 sqlmap简介: sqlmap是一个自动化的SQL注入工具,其主要功能是扫描,发现并利用给定的URL的SQL注入漏洞,目前支持的数据库是MS-SQL,MYSQL,ORACLE和POSTGRESQL。SQLMAP采用四种独特的SQL注…

Moonbeam:开发者的多链教科书

了解波卡的技术架构,只需掌握3个关键词: Relay Chain(中继链):Polkadot将自身视作多核计算机,承载区块链底层安全架构的辐射中心。Parachain(平行链):在“Layer 0”架构…

现货白银投资中的头寸是什么

头寸是现货白银市场上的一个投资术语。建立头寸就是建仓的意思,投资者所持有的头寸也叫敞口。投资如果看涨做多,就是持有多头头寸,如果看跌做空,就持有空头头寸。计算交易的头寸的大小并不复杂,关键是在于投资者要设定…

Linux(New)---历史与虚拟机安装CentOS7.6

前言 其实之前已经学过一遍Linux了,但是感觉学的不够深入和成体系(某节的教学视频不完整),所以这次打算完整的跟一遍韩顺平老师的Linux课程,Linux从入门到精通,就从现在开始! Linux历史概述 L…

【音频分离】demucs V3的环境搭建及训练(window)

文章目录 一、环境搭建(1)新建虚拟环境,并进入(2)安装pyTorch(3)进入代码文件夹,批量安装包(4)安装其他需要的包 二、数据集准备(1)下…

flask中的flask-login

flask中的flask-login 在 Flask 中,用户认证通常是通过使用扩展库(例如 Flask-Login、Flask-HTTPAuth 或 Flask-Security)来实现的。 本文详细地解释下 Flask 中的用户认证。这里是用 Flask-Login 插件为例,这是一个处理用户会话…

count(列名) ,count(1)与count(*) 有何区别?

Mysql版本:8.0.26 可视化客户端:sql yog 文章目录 一、Mysql之count函数简介二、count(列名) ,count(常量)与count(*) 有何区别?2.1 统计字段上的区别2.2 执行效率上的区别 一、Mysql之count函数简介 👉表达式 COUNT(…