权重衰减从零开始实现
% matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l
n_train, n_test, num_inputs, batch_size = 20 , 100 , 200 , 5
true_w, true_b = torch. ones( ( num_inputs, 1 ) ) * 0.01 , 0.05
train_data = d2l. synthetic_data( true_w, true_b, n_train)
train_iter = d2l. load_array( train_data, batch_size)
test_data = d2l. synthetic_data( true_w, true_b, n_test)
test_iter = d2l. load_array( test_data, batch_size, is_train= False )
def init_params ( ) :
w = torch. normal( 0 , 1 , size= ( num_inputs, 1 ) , requires_grad= True )
b = torch. zeros( 1 , requires_grad= True )
return [ w, b]
def l2_penalty ( w) :
return torch. sum ( w. pow ( 2 ) ) / 2
def train ( lambd) :
w, b = init_params( )
net, loss = lambda X: d2l. linreg( X, w, b) , d2l. squared_loss
num_epochs, lr = 100 , 0.003
animator = d2l. Animator( xlabel= 'epochs' , ylabel= 'loss' , yscale= 'log' ,
xlim= [ 5 , num_epochs] , legend= [ 'train' , 'test' ] )
for epoch in range ( num_epochs) :
for X, y in train_iter:
l = loss( net( X) , y) + lambd * l2_penalty( w)
l. sum ( ) . backward( )
d2l. sgd( [ w, b] , lr, batch_size)
if ( epoch + 1 ) % 5 == 0 :
animator. add( epoch + 1 , ( d2l. evaluate_loss( net, train_iter, loss) ,
d2l. evaluate_loss( net, test_iter, loss) ) )
print ( 'w的L2范数是:' , torch. norm( w) . item( ) )
train( lambd= 0 )
d2l. plt. show( )
train( lambd= 5 )
d2l. plt. show( )
权重衰减的简洁实现
def train_concise ( wd) :
net = nn. Sequential( nn. Linear( num_inputs, 1 ) )
for param in net. parameters( ) :
param. data. normal_( )
loss = nn. MSELoss( reduction= 'none' )
num_epochs, lr = 100 , 0.003
trainer = torch. optim. SGD( [
{ "params" : net[ 0 ] . weight, 'weight_decay' : wd} ,
{ "params" : net[ 0 ] . bias} ] , lr= lr)
animator = d2l. Animator( xlabel= 'epochs' , ylabel= 'loss' , yscale= 'log' ,
xlim= [ 5 , num_epochs] , legend= [ 'train' , 'test' ] )
for epoch in range ( num_epochs) :
for X, y in train_iter:
trainer. zero_grad( )
l = loss( net( X) , y)
l. mean( ) . backward( )
trainer. step( )
if ( epoch + 1 ) % 5 == 0 :
animator. add( epoch + 1 ,
( d2l. evaluate_loss( net, train_iter, loss) ,
d2l. evaluate_loss( net, test_iter, loss) ) )
print ( 'w的L2范数:' , net[ 0 ] . weight. norm( ) . item( ) )
train_concise( 0 )
train_concise( 5 )