六天带你入门PyTorch深度学习(1/6)
之PyTorch初认识
Pytorch深度学习快速入门简易教程,适合所有新手学习打好框架基础
跟着我的节奏一步一步学,一周即可掌握
跟着我的节奏一步一步学,一周即可掌握
import torch #导入torch库,PyTorch是一个基于Torch的Python开源机器学习库
#查看当前机器的torch版本和cuda版本
print(torch.__version__)
#查看当前机器的gpu是否可用,可用即为True
torch.cuda.is_available()
dir(): 查看函数下的子函数
dir(torch.nn)
help() :查看函数的用法
help(torch.nn.Softmax)
from torch.utils.data import Dataset
import os #处理图像
from PIL import Image #打开图像
Dataset(): 获取数据及其标签
class Mydata(Dataset):
def __init__(self,root_dir,label_dir): #初始化:根据类创建实例时首先要运行的函数,为后面的函数提供全局变量self
self.root_dir=root_dir
self.label_dir=label_dir
self.path=os.path.join(root_dir,label_dir) #拼接地址
self.img_path=os.listdir(self.path) #该地址下的文件名转化为列表形式
def __getitem__(self, idx): #idx:图像的索引序号
img_name=self.img_path[idx]
img_item_path=os.path.join(self.path,img_name)
img=Image.open(img_item_path)
label=self.label_dir
return img,label
def __len__(self):
return (len(self.img_path))
实例化
root_dir="../datasets/hymenoptera_data/train"
label_dir="ants"
ants_dataset=Mydata(root_dir,label_dir)
img,label=ants_dataset[1]
img,label