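数学推导

设 $p_i$ 为第 $i$ 个样本的预测概率、$y_i \in \{0, 1\}$ 为其标签，二分类交叉熵为
$$L_{\text{BCE}} = -\frac{1}{N}\sum_{i=1}^{N}\left[y_i \log p_i + (1 - y_i)\log(1 - p_i)\right]$$
设 $z_i \in \mathbb{R}^{C}$ 为第 $i$ 个样本的 logits、$y_i$ 为其类别下标，多分类交叉熵为
$$L_{\text{CE}} = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{e^{z_{i,y_i}}}{\sum_{c=1}^{C} e^{z_{i,c}}}$$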
手撕代码
import torch
import torch. nn. functional as F
def binary_cross_entropy(predictions, targets, *, eps=1e-7):
    """Compute the mean binary cross-entropy between probabilities and labels.

    Args:
        predictions: Tensor of predicted probabilities (e.g. ``sigmoid``
            outputs), broadcast element-wise against ``targets``.
        targets: Tensor of labels, normally 0/1 (soft labels in ``[0, 1]``
            also work with this formula).
        eps: Probabilities are clamped to ``[eps, 1 - eps]`` before taking the
            log so inputs of exactly 0 or 1 yield a finite loss and finite
            gradients. Use a larger value for half-precision inputs, where
            ``1 - 1e-7`` rounds back to 1.

    Returns:
        Scalar tensor ``-mean(t * log(p) + (1 - t) * log(1 - p))`` over all
        elements.
    """
    # torch.log(0) is -inf and 0 * -inf is nan, so a saturated sigmoid output
    # (exactly 0.0 or 1.0 in float32) would poison the loss and its gradient.
    probs = torch.clamp(predictions, eps, 1 - eps)
    per_element = targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs)
    return -per_element.mean()
def categorical_cross_entropy(predictions, targets):
    """Compute the mean multi-class cross-entropy from raw logits.

    Args:
        predictions: Unnormalized logits whose class dimension is ``dim=1``
            (e.g. shape ``(N, C)``).
        targets: Integer (int64) class indices, one per sample (e.g. shape
            ``(N,)``), with values in ``[0, C)``.

    Returns:
        Scalar tensor: the mean over samples of
        ``-log(softmax(predictions)[target])``.
    """
    # log_softmax applies the log-sum-exp trick, so it stays exact for large
    # logits. A separate softmax followed by log underflows to log(0) and needs
    # an epsilon that biases the result.
    log_probs = F.log_softmax(predictions, dim=1)
    # Selecting each sample's target column is equivalent to multiplying by a
    # one-hot mask and summing over classes, without building the mask.
    picked = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    return -picked.mean()
if __name__ == "__main__":
    # Demo 1 -- binary case: random logits are squashed to probabilities with
    # a sigmoid before being passed to binary_cross_entropy.
    binary_logits = torch.randn(10, 1, requires_grad=True)
    binary_labels = torch.randint(0, 2, (10, 1)).float()
    binary_probs = torch.sigmoid(binary_logits)
    bce_value = binary_cross_entropy(binary_probs, binary_labels).item()
    print(f"Binary Cross Entropy Loss: {bce_value}")

    # Demo 2 -- multi-class case: raw logits for 3 classes plus integer labels;
    # categorical_cross_entropy normalizes the logits itself.
    class_logits = torch.randn(10, 3, requires_grad=True)
    class_labels = torch.randint(0, 3, (10,))
    ce_value = categorical_cross_entropy(class_logits, class_labels).item()
    print(f"Categorical Cross Entropy Loss: {ce_value}")