文章目录
1. 导入相关库 2. 加载数据集 3. 整理数据集 4. 图像增广 5. 读取数据 6. 微调预训练模型 7. 定义损失函数和评价损失函数 8. 定义训练函数 9. 训练模型
1. 导入相关库
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
2. 加载数据集
- 该数据集是完整数据集的小规模样本
# Register the small sample archive with d2l's download hub: the entry is
# (archive URL, hash string -- presumably the archive's SHA-1 checksum).
d2l.DATA_HUB['dog_tiny'] = (
    d2l.DATA_URL + 'kaggle_dog_tiny.zip',
    '0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d',
)

# demo=True downloads and extracts the small sample; set it to False to read
# the full dataset from ../data/dog-breed-identification instead.
demo = True
data_dir = (
    d2l.download_extract('dog_tiny')
    if demo
    else os.path.join('..', 'data', 'dog-breed-identification')
)
3. 整理数据集
def reorg_dog_data(data_dir, valid_ratio):
    """Reorganize the dataset stored under *data_dir*.

    Reads ``labels.csv`` from *data_dir* via ``d2l.read_csv_labels``, then
    hands the result to ``d2l.reorg_train_valid`` (together with
    *valid_ratio*) and finally calls ``d2l.reorg_test`` on the same directory.
    """
    labels_csv = os.path.join(data_dir, 'labels.csv')
    csv_labels = d2l.read_csv_labels(labels_csv)
    d2l.reorg_train_valid(data_dir, csv_labels, valid_ratio)
    d2l.reorg_test(data_dir)
# The demo sample uses a smaller batch than the full dataset.
if demo:
    batch_size = 32
else:
    batch_size = 128

# Ratio forwarded to the train/validation reorganization step.
valid_ratio = 0.1
reorg_dog_data(data_dir, valid_ratio)
4. 图像增广
# Training pipeline: random crop/flip/color augmentation, then tensor
# conversion and per-channel normalization.
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(
        size=224,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
    ),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
    ),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Evaluation pipeline: deterministic resize + center crop, followed by the
# same tensor conversion and normalization as the training pipeline.
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=256),
    torchvision.transforms.CenterCrop(size=224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
5. 读取数据
# Datasets used for fitting go through the random training augmentations.
train_ds, train_valid_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),
        transform=transform_train,
    )
    for folder in ['train', 'train_valid']
]

# Datasets used for evaluation/prediction go through the deterministic
# test-time preprocessing.
valid_ds, test_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),
        transform=transform_test,
    )
    for folder in ['valid', 'test']
]

# Training loaders shuffle and drop the last partial batch so every
# optimization step sees a full batch.
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(
        dataset, batch_size, shuffle=True, drop_last=True
    )
    for dataset in (train_ds, train_valid_ds)
]

# Evaluation and test loaders must visit every sample: with drop_last=True
# the final partial batch would be silently skipped, and a split smaller than
# batch_size would produce no batches at all.
valid_iter = torch.utils.data.DataLoader(
    valid_ds, batch_size, shuffle=False, drop_last=False
)
test_iter = torch.utils.data.DataLoader(
    test_ds, batch_size, shuffle=False, drop_last=False
)
6. 微调预训练模型
def get_net(devices):
    """Build the fine-tuning network and place it on ``devices[0]``.

    The network is an ``nn.Sequential`` with two named children:
    ``features`` is a ResNet-34 loaded with the ``IMAGENET1K_V1`` weights
    whose parameters are frozen (``requires_grad = False``), and
    ``output_new`` is a trainable ``Linear(1000, 256) -> ReLU ->
    Linear(256, 120)`` head.
    """
    backbone = torchvision.models.resnet34(
        weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1
    )
    # Freeze the pretrained backbone; only the new head stays trainable.
    for weight in backbone.parameters():
        weight.requires_grad = False

    finetune_net = nn.Sequential()
    finetune_net.features = backbone
    finetune_net.output_new = nn.Sequential(
        nn.Linear(1000, 256),
        nn.ReLU(),
        nn.Linear(256, 120),
    )
    return finetune_net.to(devices[0])
# Builds the network once on the devices returned by d2l.try_all_gpus(); the
# returned module is not bound to a name here.
# NOTE(review): presumably left over from a notebook cell that displays the
# model structure -- confirm before removing.
get_net( devices= d2l. try_all_gpus( ) )
7. 定义损失函数和评价损失函数
# Unreduced cross-entropy: one loss value per example, so callers decide how
# to sum and normalize.
loss = nn.CrossEntropyLoss(reduction='none')


def evaluate_loss(data_iter, net, device):
    """Return the mean per-example loss of *net* over *data_iter*.

    The model is temporarily switched to evaluation mode and autograd is
    disabled, so no computation graph is retained across batches and layers
    such as BatchNorm/Dropout behave deterministically. The model's previous
    train/eval mode is restored before returning.

    Args:
        data_iter: Iterable yielding ``(features, labels)`` tensor batches.
        net: ``nn.Module`` called as ``net(features)``.
        device: Sequence of devices; each batch is moved to ``device[0]``.

    Returns:
        A CPU tensor holding the summed loss divided by the number of labels.

    Raises:
        ZeroDivisionError: If *data_iter* yields no samples.
    """
    l_sum, n = 0.0, 0
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for features, labels in data_iter:
                features = features.to(device[0])
                labels = labels.to(device[0])
                outputs = net(features)
                l = loss(outputs, labels)
                l_sum += l.sum()
                n += labels.numel()
    finally:
        # Put the model back in whatever mode the caller had it in.
        net.train(was_training)
    return (l_sum / n).to('cpu')
8. 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices,
          lr_period, lr_decay):
    """Train *net* with SGD and a step learning-rate schedule.

    Args:
        net: Model to train; it is wrapped in ``nn.DataParallel`` over
            *devices* and moved to ``devices[0]``.
        train_iter: Iterable of ``(features, labels)`` training batches; must
            support ``len()``.
        valid_iter: Validation batches, or ``None`` to skip validation.
        num_epochs: Number of passes over *train_iter*.
        lr: Initial learning rate.
        wd: Weight decay passed to the optimizer.
        devices: Sequence of devices; batches are moved to ``devices[0]``.
        lr_period: ``step_size`` of the ``StepLR`` scheduler (in epochs).
        lr_decay: ``gamma`` of the ``StepLR`` scheduler.

    Prints the final training loss (and validation loss when *valid_iter* is
    given) together with a throughput figure.
    """
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    # Only parameters that still require gradients are optimized.
    trainer = torch.optim.SGD(
        (param for param in net.parameters() if param.requires_grad),
        lr=lr, momentum=0.9, weight_decay=wd,
    )
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    # Plot roughly five points per epoch. Clamp to 1 so that fewer than five
    # batches does not turn the modulo below into a division by zero.
    plot_every = max(1, num_batches // 5)
    legend = ['train loss']
    if valid_iter is not None:
        legend.append('valid loss')
    animator = d2l.Animator(
        xlabel='epoch', xlim=[1, num_epochs], legend=legend
    )
    for epoch in range(num_epochs):
        # metric[0]: summed training loss, metric[1]: number of examples.
        metric = d2l.Accumulator(2)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            features = features.to(devices[0])
            labels = labels.to(devices[0])
            trainer.zero_grad()
            output = net(features)
            l = loss(output, labels).sum()
            l.backward()
            trainer.step()
            metric.add(l, labels.shape[0])
            timer.stop()
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(
                    epoch + (i + 1) / num_batches,
                    (metric[0] / metric[1], None),
                )
        measures = f'train loss {metric[0] / metric[1]:.3f}'
        if valid_iter is not None:
            valid_loss = evaluate_loss(valid_iter, net, devices)
            animator.add(epoch + 1, (None, valid_loss.detach().cpu()))
        scheduler.step()
    if valid_iter is not None:
        measures += f', valid loss {valid_loss:.3f}'
    # metric holds the last epoch's example count, so scale it by num_epochs
    # to estimate the total examples processed during the timed sections.
    print(measures + f'\n{metric[1] * num_epochs / timer.sum():.1f}'
          f' examples/sec on {str(devices)}')
9. 训练模型
import time

# Hyperparameters: 10 epochs, learning rate and weight decay of 1e-4, and a
# learning-rate decay of x0.9 every 2 epochs.
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 10, 1e-4, 1e-4
lr_period, lr_decay, net = 2, 0.9, get_net(devices)

# Time the whole training run with a monotonic high-resolution clock.
start = time.perf_counter()
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices,
      lr_period, lr_decay)
end = time.perf_counter()
print(f'运行耗时 {end - start:.4f}')