PyTorch深度学习快速入门教程【土堆】基础知识篇

news2025/1/13 7:41:58

Juptyer

版本:

  • Python 3.9.19
  • Pytorch 2.4.1
(pytorch0) C:\Users\25694>conda install nb_conda_kernels
(pytorch0) C:\Users\25694>jupyter notebook

使用conda环境的pytorch:
在这里插入图片描述
成功解决python.exe无法找到程序入口 无法定位程序输入点

shift+enter:运行这个代码块并跳转到下一个代码块

  1. 将环境写入Notebook的kernel中:
python -m ipykernel install --user --name 环境名称 --display-name "Python (环境名称)"
  1. 打开Jupyter notebook,新建Python文件,这时候你就能看见你的创建的环境

Python学习中的两大法宝函数

在这里插入图片描述
在这里插入图片描述
实战操作:

Python交互模式主要有两种:CPython用>>>作为提示符,而IPython用In [序号]:作为提示符。
如果你是>>>,那么可以回ana黑色窗口控制台输入conda install ipython来使其变成in
如果虚拟环境中没有安装 ipython包,那么默认就是>>>模式
如果当前显示IN[序号],想换回>>>,则在File->Setting中取消勾选下列的框,Apply->OK
下列的框在“构建、执行、部署”(我使用了Pycharm里面的汉化插件)→“控制台”→使用IPython
在这里插入图片描述

PyCharm及Jupyter使用及对比

在这里插入图片描述

PyTorch加载数据初认识

在这里插入图片描述
其中train是训练数据集,val是验证数据集

Dataset??

一般数据和对应label有两种形式:

  • 如一个文件夹内存放多个同类的图片:文件夹的名称就是其label
  • 数据和label存放在不同的文件夹内

在这里插入图片描述
在这里插入图片描述

Dataset类代码实战

通常使用pycharm的python console进行一些小的测试!可以方便地查看过程属性

(pytorch0) C:\Users\25694>pip install opencv-python

把hymenoptera_data数据集拷贝到项目文件夹中并重命名
在这里插入图片描述
绝对路径 ctrl+shift+c。但是注意windows下路径要使用两个斜杠,来表示转义

在这里插入图片描述

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset): # 创建class继承自Dataset

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)

root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
tran_dataset = ants_dataset + bees_dataset

img, label = bees_dataset[10]

img.show()

os.listdir() 是 Python 中 os 模块的一个函数,主要用于列出指定目录中的所有文件和子目录的名称。它返回一个包含该目录下所有条目(文件或文件夹)名称的列表(但不会递归子目录),但不会包含文件的完整路径,只返回名称(仅列出名称,不会指明是文件还是目录。如果需要判断某个条目是文件还是目录,可以结合 os.path.isfile() 和 os.path.isdir() 一起使用。)。参数:path:需要列出内容的目录路径。可以是相对路径或绝对路径。如果不传递 path 参数,则默认列出当前工作目录(即 .)。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
Python中self用法详解

TensorBoard的使用

ctrl+点按SummaryWriter:
在这里插入图片描述

SummaryWriter:直接向log_dir文件夹写入事件文件,这个事件文件可以被TensorBoard解析。需要输入一个文件夹的名称,不输入的话默认文件夹为runs/CURRENT_DATETIME_HOSTNAME

log_dir:tensorboard文件的存放路径 flush_secs:表示写入tensorboard文件的时间间隔
其他的参数当前并不重要,需要的话可以自己看看。

标量只有大小概念,没有方向的概念。通过一个具体的数值就能表达完整。比如:重量、温度、长度、提及、时间、热量等都数据标量。

安装tensorboard:

(pytorch0) C:\Users\25694>
conda list
conda search numpy
conda install numpy=1.23.1
conda install tensorboard

add_scalar()方法的使用(常用来绘制train/val loss)

def add_scalar(
        self,
        tag,
        scalar_value,
        global_step=None,
        walltime=None,
        new_style=False,
        double_precision=False,
    ):

添加一个标量数据到 Summary 当中,需要参数

  • tag:Data指定方式,类似于图表的title
  • scalar_value:需要保存的数值(y轴)
  • global_step:训练到多少步(x轴)

在这里插入图片描述

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs")

# writer.add_imgae()
# y = 2x

for i in range(100):
    writer.add_scalar("y=2x", 2 * i, i)

writer.close()

在这里插入图片描述
如何打开生成的事件文件:
注意路径!

(pytorch0) E:\PyCharmProjects\learn_torch>tensorboard --logdir=logs --port=6007
TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.17.0 at http://localhost:6007/ (Press CTRL+C to quit)

这里不指定端口port 默认是6006
在这里插入图片描述

在writer中写入新事件,还有上个事件
解决方法:删除logs文件夹下的所有事件,重新运行程序,在terminal中按ctrl+c退出,再按上键打开端口

add_image()方法的使用(常用来观察训练结果)

准备:将练手数据集里的解压到项目目录下新建的data文件夹中

def add_image(self, tag, img_tensor, global_step=None):
  • tag:对应图像的title
  • img_tensor:图像的数据类型,只能是torch.Tensor、numpy.array、string/blobnaem
  • global_step:训练步骤,int 类型

在这里插入图片描述

# 打开控制台,其位置就是项目文件夹所在的位置
# 故只需复制相对地址
 
image_path = "data/train/ants_image/0013035.jpg"
 
from PIL import Image
img = Image.open(image_path)
print(type(img))

PIL.格式不符合要求。

在这里插入图片描述
因此,利用opencv(numpy.array())读取图片,对PIL图片进行转换,活动numpy型图片数据

import numpy as np
img_array=np.array(img)
print(type(img_array))   # numpy.ndarray

在Python控制台输出图片类型:
在这里插入图片描述
从PIL到numpy,需要在add_image()中指定shape中每一个数字/维表示的含义

img_tensor默认的图片尺寸格式为(3,H,W),但是一般我们的图片格式为(H,W,3),因此需要对图片格式进行调整

通过print(img_array.shape)以查看img是否为C(通道)H(高度)W(宽度)的形式

from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image

writer = SummaryWriter("logs")
image_path = "data/train/ants_image/0013035.jpg"
img_PIL = Image.open(image_path)
img_array = np.array(img_PIL)
print(type(img_array))
print(img_array.shape) # (512, 768, 3) (H, W, C)(高度,宽度,通道)

writer.add_image("test", img_array, 1, dataformats="HWC")
# y = 2x

for i in range(100):
    writer.add_scalar("y=5x", 5 * i, i)

writer.close()

在这里插入图片描述

在一个title下,通过滑块显示每一步的图形,可以直观地观察训练中给model提供了哪些数据,或者想对model进行测试时,可以看到每个阶段的输出结果
如果想要单独显示,重命名一下title即可,即 writer.add_image() 的第一个字符串类型的参数

Tensorforms的使用-主要是对图片进行变换

transforms是torchvision下的一个工具箱,用于格式转化,视觉处理工具,不用于文本
在这里插入图片描述
图片经过transforms工具的变换,得到我们想要的一个图像变换结果

解释:根据模具创造工具,使用具体工具根据说明进行输入和输出

按住ctrl,点击transforms

conda install torchvision

在这里插入图片描述
在这里插入图片描述

它里面有多个工具类:

  • Compose类:结合不同的transforms
  • ToTensor类:将PIL和numpy类型的图片转为Tensor(可用于训练)
  • ToPILImage类:把一个图片转换成PIL Image
  • Normalize类:归一化,标准化,用来对数据预处理
  • Resize类:尺寸变换
  • CenterCrop类:中心裁剪
  • Regularize类:正则化,防止模型过拟合的技术
  • RandomCrop:随机裁剪。

工具类都有__ call __()方法,具体作用看python中的 call()

在 Python 中,call 是一个特殊的方法,可以让对象像函数一样被调用。换句话说,如果一个类实现了 call 方法,那么它的实例就能像调用普通函数一样被调用。
例子:

class MyClass:
    def __init__(self, value):
        self.value = value

    def __call__(self, x):
        return self.value * x

# 创建类的实例
obj = MyClass(10)

# 调用实例,像调用函数一样
result = obj(5)  # 等价于 obj.__call__(5)

print(result)  # 输出 50

两个问题
python的用法 ——> tensor数据类型
通过 transforms.ToTensor去解决两个问题

  1. Transforms该如何使用
  2. Tensor数据类型与其他图片数据类型有什么区别?为什么需要Tensor数据类型
from PIL import Image
from torchvision import transforms
 
# 绝对路径 D:\PycharmProjects\pythonProject\pytorchlearn\data\train\ants_image\0013035.jpg
# 相对路径 data/train/ants_image/0013035.jpg
img_path="data/train/ants_image/0013035.jpg"   #用相对路径,绝对路径里的\在Windows系统下会被当做转义符
# img_path_abs="C:\Users\11842\Desktop\Learn_torch\data\train\ants_image\0013035.jpg",双引号前加r表示转义
 
img = Image.open(img_path)   #Image是Python中内置的图片的库
print(img)  # PIL类型

问题一:

transforms 该如何使用(python)

从transforms中选择一个class,对它进行创建,对创建的对象传入图片,即可返回出结果

ToTensor将一个 PIL Image 或 numpy.ndarray 转换为 tensor的数据类型

# 1、Transforms该如何使用
tensor_trans = transforms.ToTensor()  #从工具箱transforms里取出ToTensor类,返回tensor_trans对象
tensor_img = tensor_trans(img)   #创建出tensor_trans后,传入其需要的参数,即可返回结果。返回一个tensor类型的图片
print(tensor_img)

在这里插入图片描述

ctrl+p提示函数参数

问题二:

为什么我们需要 Tensor 数据类型

在Python Console输入:

from PIL import Image
from torchvision import transforms
 
img_path= "data/train/ants_image/0013035.jpg"  
img = Image.open(img_path)   
 
tensor_trans = transforms.ToTensor() 
tensor_img = tensor_trans(img)  

Tensor 数据类型包装了反向神经网络所需要的一些理论基础的参数,如:_backward_hooks、_grad等(先转换成Tensor数据类型,再训练)

在这里插入图片描述
下载opencv:

python版本要和opencv版本相对应,否则安装的时候会报错

查看链接:Links for opencv-python
在这里插入图片描述

pip install opencv-python==3.4.11.45

两种读取图片的方式

  1. PIL Image
from PIL import Image
img_path = "xxx"
img = Image.open(img_path)
img.show()
  1. numpy.ndarray(通过opencv)
import cv2
cv_img=cv2.imread(img_path)

上节课以 numpy.array 类型为例,这节课使用 torch.Tensor 类型:

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
 
# python的用法 ——> tensor数据类型
# 通过 transforms.ToTensor去解决两个问题
# 1、Transforms该如何使用
# 2、Tensor数据类型与其他图片数据类型有什么区别?为什么需要Tensor数据类型
 
# 绝对路径 C:\Users\11842\Desktop\Learn_torch\data\train\ants_image\0013035.jpg
# 相对路径 data/train/ants_image/0013035.jpg
img_path="data/train/ants_image/0013035.jpg"   #用相对路径,绝对路径里的\在Windows系统下会被当做转义符
# img_path_abs="C:\Users\11842\Desktop\Learn_torch\data\train\ants_image\0013035.jpg",双引号前加r表示转义
img = Image.open(img_path)   #Image是Python中内置的图片的库
#print(img)
 
writer = SummaryWriter("logs")
 
# 1、Transforms该如何使用
tensor_trans = transforms.ToTensor()  #从工具箱transforms里取出ToTensor类,返回tensor_trans对象
tensor_img = tensor_trans(img)   #创建出tensor_trans后,传入其需要的参数,即可返回结果
#print(tensor_img)
 
writer.add_image("Tensor_img",tensor_img)  # .add_image(tag, img_tensor, global_step)
# tag即名称
# img_tensor的类型为torch.Tensor/numpy.array/string/blobname
# global_step为int类型
 
writer.close()

常见的transforms

图片有不同的格式,打开方式也不同

图片格式打开方式
PILImage.open() ——Python自带的图片打开方式
tensorToTensor()
narrayscv.imread() ——Opencv

Compose的使用

把不同的 transforms 结合在一起,后面接一个数组,里面是不同的transforms

Example:图片首先要经过中心裁剪,再转换成Tensor数据类型
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

Python中 call 的用法

class Person:
    def __call__(self, name):
        print("call"+name)
    def hello(self,name):
        print("hi"+name)
person = Person()
# 如果定义了call方法可以对象名(传入参数)来调用
person("zhangsan")
person.hello("lisi")

ToTensor的使用

把 PIL Image 或 numpy.ndarray 类型转换为 tensor 类型(TensorBoard 必须是 tensor 的数据类型)(运行前要先把之前的logs进行删除)

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
 
writer = SummaryWriter("logs")
img = Image.open("images/pytorch.png")
print(img)  # 可以看到类型是PIL
 
# ToTensor的使用
trans_totensor = transforms.ToTensor()  # 将类型转换为tensor
img_tensor = trans_totensor(img)  # img变为tensor类型后,就可以放入TensorBoard当中
writer.add_image("ToTensor", img_tensor)
writer.close()

ToPILImage 的使用

把 tensor 数据类型或 ndarray 类型转换成 PIL Image

Normalize 的使用

用平均值/标准差归一化 tensor 类型的 image(输入)

图片RGB三个信道,将每个信道中的输入进行归一化

output[channel] = (input[channel] - mean[channel]) / std[channel]

设置 mean 和 std 都为0.5,则 output= 2*input -1。如果 input 图片像素值为0~1范围内,那么结果就是 -1 ~1之间

加入step值:
第一步

#Normalize的使用
print(img_tensor[0][0][0])  # 第0层第0行第0列
trans_norm = transforms.Normalize([4,6,7],[3,2,6])  # mean,std,因为图片是RGB三信道,故传入三个数
img_norm = trans_norm(img_tensor)  # 输入的类型要是tensor
print(img_norm[0][0][0])
writer.add_image("Normalize",img_norm,1)#第一步

第二步

#Normalize的使用
print(img_tensor[0][0][0])  # 第0层第0行第0列
trans_norm = transforms.Normalize([2,6,7],[1,2,2])  # mean,std,因为图片是RGB三信道,故传入三个数
img_norm = trans_norm(img_tensor)  # 输入的类型要是tensor
print(img_norm[0][0][0])
writer.add_image("Normalize",img_norm,2)#第二步

Resize 的使用

输入:PIL Image 将输入转变到给定尺寸

序列:(h,w)高度,宽度
一个整数:不改变高和宽的比例,只单纯改变最小边和最长边之间的大小关系。之前图里最小的边将会匹配这个数(等比缩放)

取消首字母匹配:
一般情况下,你需要输入R,才能提示出Resize
我们想设置,即便你输入的是r,也能提示出Resize,也就是忽略了大小写进行匹配提示
File—> Settings—> 搜索case—> Editor-General-Code Completion-去掉Match case前的√—>Apply—>OK
在这里插入图片描述

#Resize的使用
print(img.size)  # 输入是PIL.Image
 
trans_resize = transforms.Resize((512,512))
#img:PIL --> resize --> img_resize:PIL
img_resize = trans_resize(img)  #输出还是PIL Image
 
#img_resize:PIL --> totensor --> img_resize:tensor(同名,覆盖)
img_resize = trans_totensor(img_resize)
 
writer.add_image("Resize",img_resize,0)
print(img_resize)

Compose() 中的参数需要是一个列表,Python中列表的表示形式为[数据1,数据2,…]

在Compose中,数据需要是transforms类型,所以得到Compose([transforms参数1,transforms参数2,…])

#Compose的使用(将输出类型从PIL变为tensor类型,第二种方法)
 
trans_resize_2 = transforms.Resize(512)  # 将图片短边缩放至512,长宽比保持不变
 
# PIL --> resize --> PIL --> totensor --> tensor
#compose()就是把两个参数功能整合,第一个参数是改变图像大小,第二个参数是转换类型,前者的输出类型与后者的输入类型必须匹配
 
trans_compose = transforms.Compose([trans_resize_2,trans_totensor])
img_resize_2 = trans_compose(img)   # 输入需要是PIL Image
writer.add_image("Resize",img_resize_2,1)

RandomCrop的使用

随机裁剪,输入PIL Image

参数size:

  • sequence:(h,w) 高,宽
  • int:裁剪一个该整数×该整数的图像

(1)以 int 为例:

#RandomCrop()的使用
trans_random = transforms.RandomCrop(512)
trans_compose_2 = transforms.Compose([trans_random,trans_totensor])
for i in range(10):  #裁剪10个
    img_crop = trans_compose_2(img)  # 输入需要是PIL Image
    writer.add_image("RandomCrop",img_crop,i)

(2)以 sequence 为例:

#RandomCrop()的使用
trans_random = transforms.RandomCrop((200,500))
trans_compose_2 = transforms.Compose([trans_random,trans_totensor])
for i in range(10):  #裁剪10个
    img_crop = trans_compose_2(img)
    writer.add_image("RandomCropHW",img_crop,i)

touchvision中的数据集使用

pytorch数据集下载

需要学习知识:

  • 如何把数据集(多张图片)和 transforms 结合在一起。
  • 标准数据集如何下载、查看、使用。

各个模块作用

(1)torchvision.datasets

如:COCO 目标检测、语义分割;MNIST 手写文字;CIFAR 物体识别

(2)torchvision.io

输入输出模块,不常用

(3)torchvision.models

提供一些比较常见的神经网络,有的已经预训练好,比较重要,后面会使用到,如分类模型、语义分割模型、目标检测、视频分类等

(4)torchvision.ops

torchvision提供的一些比较少见的特殊的操作,基本不常用

(5)torchvision.transforms

之前讲解过

(6)torchvision.utils

提供一些常用的小工具,如TensorBoard

本节主要讲解torchvision.datasets,以及它如何跟transforms联合使用
在这里插入图片描述

1.数据集如何下载

#如何使用torchvision提供的标准数据集
import torchvision

train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True) #root使用相对路径,会在该.py所在位置创建一个叫dataset的文件夹,同时把数据保存进去。用Ctrl加P查看需要参数。
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)

运行后,控制台中就会显示正在下载数据集
在这里插入图片描述
数据集下载过慢时
获得下载链接后,把下载链接放到迅雷中,会首先下载压缩文件tar.gz,之后会对该压缩文件进行解压,里面会有相应的数据集。
采用迅雷下载完毕后,在PyCharm里新建directory,名字也叫dataset,再将下载好的压缩包复制进去,download依然为True,运行后,会自动解压该数据

注意dataset里面不要解压完,要放压缩包,然后在运行这个代码,不然他会重新下载
在这里插入图片描述
实际上首先下载的是下面这个压缩文件,然后会对其进行解压

2.数据集如何查看与使用

注意虽然运行代码时文件中还有上面两行,且download=True,但是会自动校验已下载,也就是说不会产生影响,所以可以放着不管

import torchvision
 
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
 
print(test_set[0])  # 查看测试集中的第一个数据,是一个元组:(img, target)
print(test_set.classes)  # 列表
 
img,target = test_set[0]
print(img)
print(target)  # 输出:3。输出为列表第几个类别。从0开始数,这里类别为cat列表第四个
print(test_set.classes[target])  # cat
img.show()

在 PyTorch 的 torchvision 库中,CIFAR10 数据集是一个继承自 torch.utils.data.Dataset 的类。这个类实现了 __ getitem__ 方法(定义了如何获取单个样本),因此当你访问 test_set[0] 时,它实际上调用的是 __ getitem__,并返回一个元组。

在 Python 中,元组(tuple) 是一种 不可变的 序列类型,它可以存储多个元素。元组与列表(list)非常相似,不同之处在于元组一旦创建就 不能修改,而列表是可变的。元组常用于存储一组不希望被修改的数据。元组中的元素可以是不同类型的(list也可以不同),比如整数、字符串、浮点数等。
元组的创建:元组使用圆括号 () 创建,元素之间用逗号 , 分隔。

# 创建一个包含多个元素的元组
my_tuple = (1, 2, 3, "apple", 3.14)

# 创建单个元素的元组时,需要加上一个逗号
single_element_tuple = (5,)

# 不加逗号,括号会被认为是表达式
not_a_tuple = (5)

元组的用途:元组常用于函数返回多个值的场景,因为可以一次性返回多个元素。

def get_position():
    return (10, 20)

position = get_position()
print(position)  # 输出: (10, 20)

函数返回值的类型

  • 圆括号 ():表示元组。
  • 方括号 []:表示列表。

在这里插入图片描述

这个 __ getitem __ 方法属于一个自定义数据集类,通常用在 PyTorch 或其他类似框架中,用来在索引处返回数据集中的样本(如图像和标签)。这是实现自定义数据集的重要方法之一。以下是代码的详细解析:

index: int:该方法接受一个整数 index,表示要获取的数据样本的索引。
返回类型 Tuple[Any, Any]:它返回一个 元组 (image, target),其中:
image 是图像数据,target 是与该图像对应的类别标签。
Any 在类型注解中表示可以是任意类型,通常 image 会是图像对象,而 target 是整数表示的类别标签。

img, target = self.data[index], self.targets[index]
self.data[index]:从数据集中提取索引为 index 的图像数据。
self.targets[index]:提取与该图像对应的目标标签。target 通常是一个表示类别的整数,类似于分类任务中图像的类别编号。
转换为 PIL 图像:
img = Image.fromarray(img)
Image.fromarray(img):将图像数据从数组格式(通常是 NumPy 数组)转换为 PIL 图像对象,这一步是为了保证返回的数据与其他图像处理流程(如数据增强)兼容。PIL 是 Python Imaging Library 的简称,常用于图像处理。
应用图像变换:
if self.transform is not None:
img = self.transform(img)
如果 self.transform 不为空,则将其应用于图像 img。这是常见的图像预处理步骤,transform 通常是数据增强操作,如旋转、裁剪、归一化等。
应用目标变换:
if self.target_transform is not None:
target = self.target_transform(target)
如果 self.target_transform 不为空,则将其应用于目标 target。目标变换通常用于修改标签的格式,比如将类别标签转换为独热编码,或进行其他处理。
返回值:
return img, target
最后返回一个元组 (img, target),其中 img 是经过可能的转换后的图像,target 是与图像对应的标签。
小结:
这个 __ getitem __ 方法的作用是:

根据给定的索引,从数据集中提取图像和标签。
将图像从数组格式转换为 PIL 图像,以与其他图像数据处理兼容。
根据需要对图像和标签应用预处理(transform 和 target_transform)。
返回处理后的图像和标签,作为一个元组 (image, target)。

3.CIFAR10数据集介绍

​ CIFAR10 数据集包含了6万张32×32像素的彩色图片,图片有10个类别,每个类别有6千张图像,其中有5万张图像为训练图片,1万张为测试图片。

如何把数据集(多张图片)和 transforms 结合在一起
CIFAR10数据集原始图片是PIL Image,如果要给pytorch使用,需要转为tensor数据类型(转成tensor后,就可以用tensorboard了)

transforms 更多地是用在 datasets 里 transform 的选项中

import torchvision
from torch.utils.tensorboard import SummaryWriter
 
#把dataset_transform运用到数据集中的每一张图片,都转为tensor数据类型
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
 
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True) #root使用相对路径,会在该.py所在位置创建一个叫dataset的文件夹,同时把数据保存进去
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
 
# print(test_set[0])
 
writer = SummaryWriter("logs")
#显示测试数据集中的前10张图片
for i in range(10):
    img,target = test_set[i]
    writer.add_image("test_set",img,i)  # img已经转成了tensor类型
 
writer.close()

Dataloader的使用

Dataloader每次从dataset中取数据
在这里插入图片描述

Dataloader

参数介绍
参数如下(大部分有默认值,实际中只需要设置少量的参数即可):

  • dataset:只有dataset没有默认值,只需要将之前自定义的dataset实例化,再放到dataloader中即可
  • batch_size:每次抓牌抓几张
  • shuffle:设置为True在每个 epoch 重新洗牌数据(默认值:False),但一般用True
  • num_workers:加载数据时采用单个进程还是多个进程,多进程的话速度相对较快,默认为0(主进程加载)。Windows系统下该值>0会有问题(报错提示:BrokenPipeError)
  • drop_last:100张牌每次取3张,最后会余下1张,这时剩下的这张牌是舍去还是不舍去。值为True代表舍去这张牌、不取出,False代表要取出该张牌
import torchvision
from torch.utils.data import DataLoader
 
#准备的测试数据集
test_data = torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor)
 
test_loader = DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
 
#测试数据集中第一张图片及target
img,target = test_data[0]
print(img.shape)
print(target)

输出结果:

torch.Size([3, 32, 32])   #三通道,32×32大小
3   #类别为3

dataset

  • __ getitem() __:return img,target
    dataloader(batch_size=4):从dataset中取4个数据

  • img0,target0 = dataset[0]

  • img1,target1 = dataset[1]

  • img2,target2 = dataset[2]

  • img3,target3 = dataset[3]

把 img 0-3 进行打包,记为imgs;target 0-3 进行打包,记为targets;作为dataloader中的返回

for data in test_loader:
    imgs,targets = data
    print(imgs.shape)
    print(targets)

输出:

torch.Size([4, 3, 32, 32])   #4张图片,三通道,32×32
tensor([0, 4, 4, 8])  #4个target进行一个打包

数据是随机取的(断点debug一下,可以看到采样器sampler是随机采样的),所以两次的 target 0 并不一样

batch_size

对于打包的图片展示,使用的方法是add_images()方法,单张图片展示使用add_image()方法

# 用上节课torchvision提供的自定义的数据集
# CIFAR10原本是PIL Image,需要转换成tensor
 
import torchvision.datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
 
# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor())
 
# 加载测试集
test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=False)
#batch_size=4,意味着每次从test_data中取4个数据进行打包
 
writer = SummaryWriter("dataloader")
step=0
for data in test_loader:
    imgs,targets = data  #imgs是tensor数据类型
    writer.add_images("test_data",imgs,step)
    step=step+1
 
writer.close()

由于 drop_last 设置为 False,所以最后16张图片(没有凑齐64张)显示如下:
在这里插入图片描述

drop_last

若将 drop_last 设置为 True,最后16张图片(step 156)会被舍去,结果如图:

在这里插入图片描述

shuffle

一个 for data in test_loader 循环,就意味着打完一轮牌(抓完一轮数据),在下一轮再进行抓取时,第二次数据是否与第一次数据一样。值为True的话,会重新洗牌(一般都设置为True)

在外面再套一层 for epoch in range(2) 的循环

shuffle为False的话两轮取的图片顺序是一样的
在这里插入图片描述

# shuffle为True
for epoch in range(2):
    step=0
    for data in test_loader:
        imgs,targets = data  #imgs是tensor数据类型
        writer.add_images("Epoch:{}".format(epoch),imgs,step)
        step=step+1

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2150290.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习-深度学习数据集之打架斗殴识别数据集

关于“打架识别数据集”,这是一个专门设计用于训练计算机视觉模型以识别打架、摔倒以及持械行为的数据集。此类数据集对于开发安全监控系统至关重要,可以帮助在公共场所如学校、酒吧或地铁站等地及时发现潜在的暴力事件,从而快速采取行动来防…

anaconda的windows新手安装及配置教程(适用于物联网工程、计算机专业)

第一步:点击免费下载 点击我直达anaconda官网">——>点击我直达anaconda官网 第二步:跳过注册 第三步:下载windows版本 第四步:安装步骤 1.Next (下一步) 2.I Agree (我同意) 3.默认即可,下一步 4.安装地址可以选到D盘,如果没有默认也行,只是一个…

上传富文本插入文件时报错:JSON parse error: Unexpected character解决办法

方式一(加密解密): 1.前端 (1)安装 crypto-js npm install crypto-js(2)util下创建asc.js asc.js import CryptoJS from crypto-js// 需要和后端一致 const KEY CryptoJS.enc.Utf8.parse(…

《Linux基础》练习操作

一、文件目录类操作 1. 创建新用户user,其中用户名为学生姓名首字小写(如:张三,用户名为zsan) 将/etc/passwd拷贝到/home/user下面。修改/home/user/passwd,在文件的第15行下添加“hello 学号姓名”,光标停留在 hello 学号姓名…

[漏洞复现]泛微e-mobile cdnfile文件读取漏洞分析复现

如果觉得该文章有帮助的,麻烦师傅们可以搜索下微信公众号:良月安全。点个关注,感谢师傅们的支持。 免责声明 本号所发布的所有内容,包括但不限于信息、工具、项目以及文章,均旨在提供学习与研究之用。所有工具安全性…

金属3D打印经济效益高吗?

在我国制造业迈向产业升级的重要阶段,3D打印技术如同一股强劲的新风,特别是在航空航天、汽车、生物医疗等领域,已成为复杂构件制造的“明星”技术。那么,对于众多生产厂家而言,金属3D打印的经济账到底怎么算&#xff1…

永磁同步电机谐波抑制算法(8)——基于神经网络的傻瓜式(无需知道谐波频率)谐波抑制

1.简介 前面的内容已经介绍了很多谐波抑制的方法:多同步、PIR、陷波器等等。也介绍了比较多的谐波来源:死区(5、7、11、13等次相电流谐波)、绕组不对称(基波不等幅值、3次相电流谐波)等等。 上述的方法都…

基于springboot+vue超市管理系统

基于springbootvue超市管理系统 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本无人超市管理系统就是在这样的大环境下诞生,其可以帮助使用者在…

进程间关系与进程守护

一、进程组 1、理解 每一个进程除了有一个进程 ID(PID)之外 还属于一个进程组, 进程组是一个或者多个进程的集合, 一个进程组可以包含多个进程。 每一个进程组也有一个唯一的进程组 ID(PGID), 并且这个 PGID 类似于进程 ID, 同样…

不只是模仿,伯克利新研究赋予机器人跨实体自主学习的能力,零样本时代已来

导读: 在当今科技飞速发展的时代,机器人技术正不断地给我们带来惊喜和变革。2024 年 9 月,一篇来自加州大学伯克利分校、丰田研究所和Physical Intelligence 的研究论文RoVi-Aug: Robot and Viewpoint Augmentation for Cross-Embodiment Rob…

2024/9/20 使用QT实现扫雷游戏

有三种难度初级6x6 中级10x10 高级16x16 完成游戏 游戏失败后&#xff0c;无法再次完成游戏&#xff0c;只能重新开始一局 对Qpushbutton进行重写 mybutton.h #ifndef MYBUTTON_H #define MYBUTTON_H #include <QObject> #include <QWidget> #include <QPus…

基于ACMEv2协议的免费SSL证书申请-支持Let‘s Encrypt/Google/ZeroSSL

项目&#xff1a;https://github.com/cook-code-jazor/acmex 非开源&#xff0c;使用webui管理证书的申请&#xff0c;所有文件本地化存储&#xff0c;支持windows/linux/osx。 证书申请直连ACMEv2服务商&#xff0c;没有任何中间接口&#xff0c;支持Lets Encrypt/Google/Ze…

【HTML5】html5开篇基础(1)

1.❤️❤️前言~&#x1f973;&#x1f389;&#x1f389;&#x1f389; Hello, Hello~ 亲爱的朋友们&#x1f44b;&#x1f44b;&#xff0c;这里是E绵绵呀✍️✍️。 如果你喜欢这篇文章&#xff0c;请别吝啬你的点赞❤️❤️和收藏&#x1f4d6;&#x1f4d6;。如果你对我的…

【UE5】将2D切片图渲染为体积纹理,最终实现使用RT实时绘制体积纹理【第二篇-着色器制作】

在上一篇文章中&#xff0c;我们已经理顺了实现流程。 接下来&#xff0c;我们将在UE5中&#xff0c;从头开始一步一步地构建一次流程。 通过这种方法&#xff0c;我们可以借助一个熟悉的开发环境&#xff0c;使那些对着色器不太熟悉的朋友们更好地理解着色器的工作原理。 这篇…

百望云生态伙伴大会在北京、深圳、昆明三地举办,携手共赢数字化未来!

伴随着金税四期数电票、乐企加速扩围&#xff0c;激发了企业大量的财税数字化转型的需求&#xff0c;为财税服务市场注入了前所有未有的活力。2024年7月9日&#xff0c;百望云成功登陆港交所&#xff0c;成为港股“电子发票第一股”&#xff0c;加码财税业务布局&#xff0c;纵…

Spring Boot利用dag加速Spring beans初始化

1.什么是Dag&#xff1f; 有向无环图(Directed Acyclic Graph)&#xff0c;简称DAG&#xff0c;是一种有向图&#xff0c;其中没有从节点出发经过若干条边后再回到该节点的路径。换句话说&#xff0c;DAG中不存在环路。这种数据结构常用于表示并解决具有依赖关系的问题。 DAG的…

生信初学者教程(一):欢迎

文章目录 配套数据R包版本安装包版权答疑在生物信息学(生信)领域,随着高通量测序技术的不断发展,大量数据涌现,为科研工作者提供了丰富的资源。然而,对于初学者而言,如何从海量的数据中挖掘有价值的信息,并开展一个完整的生信项目,仍然是一个挑战。目前,市面上针对初…

网络层协议 ——— IP协议

文章目录 概念协议头格式分片与组装网段划分IP地址的数量限制私有IP和公有IP路由 概念 IP协议&#xff08;Internet Protocol&#xff09;是互联网上使用的一种网络协议&#xff0c;也是互联网的基础协议之一。它属于TCP/IP体系中的网络层协议&#xff0c;主要负责将数据包从源…

OpenCV特征检测(5)检测图像中的角点函数cornerMinEigenVal()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 计算用于角点检测的梯度矩阵的最小特征值。 该函数类似于 cornerEigenValsAndVecs&#xff0c;但它计算并存储协方差矩阵导数的最小特征值&…

尚硅谷javaweb笔记

1、基本概念 1.1、前言 web开发&#xff1a; web&#xff0c;网页的意思&#xff0c;www.baidu.com 静态web html,css 提供给所有人看的数据始终不会发生变化&#xff01; 动态web 淘宝&#xff0c;几乎是所有的网站&#xff1b; 提供给所有人看的数据始终会发生变化&…