PS:基于小土堆视频学习https://www.bilibili.com/video/BV1hE411t7RN?p=6&vd_source=22926f91481026cd10af799bb45e448b
1、Dateset
Dateset就是我们的目标数据,告诉我们如何获取数据,距离:从多种类型的数据中,提取某一类数据,并且可对数据定义编号;
(提供一种数据获取方式及其label)
2、DateLoader
DateLoader:可以对一堆数据进行打包,为网络提供不同的数据形式
神经网络会对数据迭代多次,通常情况的下,数据集分:验证数据集和训练数据集;
from torch.utils.data import Dataset
import cv2
from PIL import Image #需要注意区分大小写
import os #获取到所有图片的地址
class MyDate(Dataset): #继承Dataset类别
def __init__(self,root_dir,label_dir): #初始化,为整个class提供全局变量
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir)
self.img_path = os.listdir(self.path)
#获取蚂蚁这个文件夹中的所有图片地址
def __getitem__(self, idx): #获取其中的每一个
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name) #每一个图片的位置
img = Image.open(img_item_path)
label = self.label_dir
return img,label
def __len__(self): #确定这个数据集到底有多长
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyDate(root_dir,ants_label_dir)
bees_dataset = MyDate(root_dir,bees_label_dir)
train_dataset = ants_dataset+bees_dataset
#将2个数据集合并为一个,在真实训练,可以用与数据集不足的补充
基于控制台检验
C:\Anaconda3\envs\pytorch_test\python.exe "C:/Program Files/JetBrains/PyCharm Community Edition 2023.1/plugins/python-ce/helpers/pydev/pydevconsole.py" --mode=client --host=127.0.0.1 --port=4406
import sys; print('Python %s on %s' % (sys.version, sys.platform))
sys.path.extend(['H:\\Python\\Test'])
Python 3.10.14 | packaged by Anaconda, Inc. | (main, May 6 2024, 19:44:50) [MSC v.1916 64 bit (AMD64)]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.25.0 -- An enhanced Interactive Python. Type '?' for help.
PyDev console: using IPython 8.25.0
Python 3.10.14 | packaged by Anaconda, Inc. | (main, May 6 2024, 19:44:50) [MSC v.1916 64 bit (AMD64)] on win32
from torch.utils.data import Dataset
import cv2
from PIL import Image #需要注意区分大小写
import os #获取到所有图片的地址
class MyDate(Dataset): #继承Dataset类别
def __init__(self,root_dir,label_dir): #初始化,为整个class提供全局变量
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir)
self.img_path = os.listdir(self.path)
#获取蚂蚁这个文件夹中的所有图片地址
def __getitem__(self, idx): #获取其中的每一个
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name) #每一个图片的位置
img = Image.open(img_item_path)
label = self.label_dir
return img,label
def __len__(self): #确定这个数据集到底有多长
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants"
ants_dataset = MyDate(root_dir,ants_label_dir)
ants_dataset[0]
Out[3]: (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=768x512>, 'ants')
img,label = ants_dataset[0]
img.show()
img,label = ants_dataset[6]
img.show()
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyDate(root_dir,ants_label_dir)
bees_dataset = MyDate(root_dir,bees_label_dir)
img,label = bees_dataset[6]
img.show()
train_dataset = ants_dataset+bees_dataset
len(ants_dataset)
Out[12]: 124
len(bees_dataset)
Out[13]: 121
len(train_dataset)
Out[14]: 245
img,label = train_dataset[124]
img.show()
img,label = train_dataset[123]
img.show()
可以发现展示不同的图片,124是蜜蜂,123是蚂蚁
transform在dataset中很常用,他主要用于图像的变换,对图像同一个尺寸
3、tensorBoard的使用,在训练模型中很有用
可以通过他,看他的loss的降低情况,模型是不是符合预期,他的训练结果是不是可以。
使用tensorBoard对模型训练很有用,通过他可以有效的看到模型的输出
from torch.utils.tensorboard import SummaryWriter
writer =SummaryWriter("logs")
# writer.add_image()
for i in range(100):
writer.add_scalar("y=x",i,i)
writer.close()
运行后会出现一个logs的文件夹,里面是tensorboard的事件文件
然后再pycharm的local里面打开,就可以查看事假:
tensorboard --logdir logs
显示当前的端口是6006;
但是为了避免端口冲突,可以指定端口
tensorboard --logdir logs --port=6007
单击上述6007的网址,即可生成该图
修改代码为y=2x,2i
再次修改y=2x,3i
则会出现以下现象
为了避免这种问题的出现,可以删除logs下面的事件;如下所示:
删除事件文件夹后,再次在本地中运行
tensorboard --logdir logs --port=6007
单击网址,即可恢复正常