# 6 图片增广 (Image Augmentation)
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from d2l import torch as d2l
from torch import nn
from PIL import Image
import liliPytorch as lp
from torch.utils.data import Dataset, DataLoader
# Show the raw sample photo in a dedicated figure window titled 'cat'.
plt.figure(num='cat')
img = Image.open(
    '../limuPytorch/images/cat.jpg'  # resolved against the CWD, not this file
)
plt.imshow(img)
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    """Preview a random augmentation by drawing it several times on a grid.

    ``aug`` is called ``num_rows * num_cols`` times on the same input image,
    and the resulting images are laid out with ``d2l.show_images``.

    Args:
        img: The source image passed to ``aug`` on every call.
        aug: A callable that takes one image and returns an augmented image
            (for example a ``torchvision.transforms`` transform).
        num_rows: Number of rows in the preview grid (default 2).
        num_cols: Number of columns in the preview grid (default 4).
        scale: Display scale forwarded to ``d2l.show_images`` (default 1.5).
    """
    grid_size = num_rows * num_cols
    # Each call draws fresh random parameters, so the grid shows the spread.
    augmented = [aug(img) for _ in range(grid_size)]
    d2l.show_images(augmented, num_rows, num_cols, scale=scale)
# Random flips with torchvision's default probability: left-right first,
# then top-bottom. Each transform is constructed just before it is previewed.
for flip_cls in (torchvision.transforms.RandomHorizontalFlip,
                 torchvision.transforms.RandomVerticalFlip):
    apply(img, flip_cls())
# Random crop whose area is 10%-100% of the original, with aspect ratio in
# [0.5, 2], resized to a fixed 200x200 output.
shape_aug = torchvision.transforms.RandomResizedCrop(
    size=(200, 200), scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)
# Vary one colour property at a time: brightness only, then hue only.
# Any ColorJitter factor that is not passed defaults to 0 (left unchanged).
for jitter_kwargs in ({'brightness': 0.5}, {'hue': 0.5}):
    apply(img, torchvision.transforms.ColorJitter(**jitter_kwargs))
# Jitter brightness, contrast, saturation and hue together.
color_aug = torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
apply(img, color_aug)

# Chain flip -> colour -> shape; Compose applies the list in order.
augs = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomHorizontalFlip(),
        color_aug,
        shape_aug,
    ]
)
apply(img, augs)
# Preview the first 32 CIFAR-10 training images on a 4x8 grid.
# download=True fetches the dataset into ../data if it is not already there.
all_images = torchvision.datasets.CIFAR10(
    root="../data", train=True, download=True)
preview = [all_images[idx][0] for idx in range(4 * 8)]
d2l.show_images(preview, 4, 8, scale=0.8)
# Blocks until every figure opened so far is closed.
plt.show()
# Evaluation pipeline: tensor conversion only, no random augmentation.
test_augs = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
# Training pipeline: random left-right flip, then tensor conversion.
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
])
def load_cifar10(is_train, augs, batch_size, root="../data", num_workers=4):
    """Build a DataLoader over the CIFAR-10 training or test split.

    Args:
        is_train: If True, load the training split and shuffle it; otherwise
            load the test split in a fixed order.
        augs: Transform applied to every image (e.g. ``train_augs`` or
            ``test_augs``).
        batch_size: Number of samples per batch.
        root: Directory where CIFAR-10 is stored or downloaded to (default
            "../data", resolved against the current working directory).
        num_workers: Number of DataLoader worker processes (default 4). Pass 0
            to load in the main process. This helps when debugging, and on
            platforms that start workers with ``spawn`` (Windows/macOS) when
            the calling script is not guarded by ``if __name__ == "__main__":``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected split.
    """
    dataset = torchvision.datasets.CIFAR10(
        root=root, train=is_train, transform=augs, download=True)
    # Shuffle only the training split; evaluation order stays deterministic.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
    )
# Train ResNet-18 on flip-augmented CIFAR-10 and evaluate on the plain test split.
# NOTE(review): d2l.resnet18(10, 3) presumably means 10 output classes and
# 3 input channels (RGB); confirm against the d2l version in use.
batch_size, lr, num_epochs = 256, 0.001, 10
net = d2l.resnet18(10, 3)
train_iter, test_iter = (
    load_cifar10(True, train_augs, batch_size),
    load_cifar10(False, test_augs, batch_size),
)
# NOTE(review): lp.train_ch6 is project-local and its optimizer is not visible
# here. If it uses plain SGD like d2l.train_ch6, lr=0.001 may converge slowly;
# confirm.
# NOTE(review): if the loaders use worker processes (num_workers > 0), then on
# spawn-based platforms (Windows/macOS) this top-level script should run under
# `if __name__ == "__main__":`. Confirm how it is launched.
lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
plt.show()