机器学习-自定义Loss函数

news2024/12/23 11:51:11

1、简介

机器学习框架中使用自定义的Loss函数,

2、应用

(1)sklearn

from sklearn.metrics import max_error
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import Ridge

def custom_loss(y_true, y_pred, **kwargs):
    # Define your custom loss calculation here
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    if y_true.ndim == 1 :
        y_true = y_true.reshape((-1, 1))

    if y_pred.ndim == 1:
        y_pred = y_pred.reshape((-1, 1))   
        
    loss = max(y_true-y_pred)
    return loss


data = pd.DataFrame(np.array([[i for i in range(0,300)],[i for i in range(100,400)],[i for i in range(200,500)]]).T,columns=['a','b','c'])

X_train ,y_train = data[['a','b']],data[['c']]
clf = Ridge()

custom_scorer = make_scorer(custom_loss, greater_is_better=False)

# Create and train a model using the custom loss function
# model = Ridge()
scores = cross_val_score(clf, X_train, y_train, cv=5, scoring=custom_scorer)

输出是cv=5,交叉验证的5个结果,评估模型

(2)pycaret

from pycaret.regression import *
from pycaret.datasets import get_data
import pandas as pd
import numpy as np
from sklearn.metrics import max_error
from sklearn.metrics import make_scorer


def custom_loss(y_true, y_pred, **kwargs):
    # Define your custom loss calculation here
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    if y_true.ndim == 1 :
        y_true = y_true.reshape((-1, 1))

    if y_pred.ndim == 1:
        y_pred = y_pred.reshape((-1, 1))   
        
    loss = max(y_true-y_pred)
    return loss

# # load sample dataset
# # data = get_data('insurance')
data = pd.DataFrame(np.array([[i for i in range(0,300)],[i for i in range(100,400)],[i for i in range(200,500)]]).T,columns=['a','b','c'])
s = setup(data, target='c')
# custom_loss = make_scorer(custom_loss)
add_metric('custom_loss', 'Custom Loss', custom_loss)
best = compare_models()
predict_model(best)

 

3、深度学习框架

(1)torch

import torch
import torch.nn as nn
import numpy as np
import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def get_x_y():
    np.random.seed(0)
    x = np.random.randint(0, 50, 300)
    y_values = 2 * x + 21
    x = np.array(x, dtype=np.float32)
    y = np.array(y_values, dtype=np.float32)
    x = x.reshape(-1, 1)
    y = y.reshape(-1, 1)
    return x, y


class LinearRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)  # 输入的个数,输出的个数

    def forward(self, x):
        out = self.linear(x)
        return out


class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, x, y):
        mse_loss = torch.mean(torch.pow((x - y), 2))
        return mse_loss


if __name__ == '__main__':
    input_dim = 1
    output_dim = 1
    x_train, y_train = get_x_y()

    model = LinearRegressionModel(input_dim, output_dim)
    epochs = 1000  # 迭代次数
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    # model_loss = nn.MSELoss() # 使用MSE作为loss
    model_loss = CustomLoss()  # 自定义loss
    # 开始训练模型
    for epoch in range(epochs):
        epoch += 1
        # 注意转行成tensor
        inputs = torch.from_numpy(x_train)
        labels = torch.from_numpy(y_train)
        # 梯度要清零每一次迭代
        optimizer.zero_grad()
        # 前向传播
        outputs: torch.Tensor = model(inputs)
        # 计算损失
        loss = model_loss(outputs, labels)
        # 返向传播
        loss.backward()
        # 更新权重参数
        optimizer.step()
        if epoch % 50 == 0:
            print('epoch {}, loss {}'.format(epoch, loss.item()))

参看:pytorch自定义loss损失函数_python_脚本之家

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/852232.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flowable-泳池泳道

目录 说明视频教程 说明 流程图描述一个过程的步骤,当这个过程涉及许多不同的人,部门或功能区域时,很难跟踪每 个步骤的负责人。解决此问题的一个有用方法是把流程图分栏,BPMN 中提供了泳池、泳道来支持 这种场景。泳池泳道在流程…

机器学习(十八):Bagging和随机森林

全文共10000余字,预计阅读时间约30~40分钟 | 满满干货(附数据及代码),建议收藏! 本文目标:理解什么是集成学习,明确Bagging算法的过程,熟悉随机森林算法的原理及其在Sklearn中的各参数定义和使用方法 代码…

国内什么牌子的ipad手写笔好用?适合绘画电容笔推荐

对于那些想要用ipad来学习的人来说,苹果Pencil是必不可少的。但是,Apple Pencil的价格真的太贵了,以至于很多人都买不起。所以,最好的办法就是选用一支平替的电容笔。本人从前几年就开始使用iPad,同时本身也是一位数码…

图 ML 中的去噪扩散生成模型

Denoising Diffusion Generative Models in Graph ML | by Michael Galkin | Towards Data Science (medium.com) 一、说明 AI DDPM 代表【"Adaptive Importance Density Power Mixture Model" 】即“自适应重要性密度幂混合模型”,是一种用于密度估计的机…

检测代理IP匿名程度的实用方法

在当今数字化的世界中,使用代理IP已成为保护个人隐私和增强网络安全的常见做法。然而,不同代理IP的匿名程度各异,有些可能具有较高的匿名性,而另一些则可能暴露了用户的真实身份和位置。 因此,了解如何检测代理IP的匿…

动态内存空间管理

欢迎来到我的 世界 ^ _ ^希望作者的文章对你有所帮助,有不足的地方还请指正,大家一起学习交流 ! 文章目录 前言:动态内存是什么一、动态内存介绍:动态内存有关函数介绍1.malloc和free2.calloc函数3.realloc函数 二、一些常见的动态…

虹科案例 | 台积电为保证光罩运输质量选择MSR冲击振动记录仪!内含台积电工程师专访

晶圆运输需要注意什么? 晶圆运输是半导体制造过程中极为关键和敏感的一环。在晶圆运输过程中,需要注意以下几点: 1、静电防护 晶圆非常容易受到静电的干扰,因此在运输过程中需要遵守严格的静电防护措施。使用适当的静电防护包装…

数据库|同城双中心 DR Auto-Sync 主中心意外故障恢复

一、前言 最近,我一直在各个地方进行 TiDB 的 Poc 测试。在这些测试中,客户特别关注同城双中心或者两地三中心的架构体系,经常会找我了解 TiDB 灾备架构的实现方案和底层逻辑。基于客户对 RPO 0 的要求,我一般会向他们介绍 DR Au…

Flume原理剖析

一、介绍 Flume是一个高可用、高可靠,分布式的海量日志采集、聚合和传输的系统。Flume支持在日志系统中定制各类数据发送方,用于收集数据;同时,Flume提供对数据进行简单处理,并写到各种数据接受方(可定制&…

软件安全测试包含哪些内容和方法?安全测试报告的必要性

软件安全测试是一种通过模拟真实攻击的方式,对软件系统进行全面的安全性评估和测试,以发现潜在的安全漏洞和弱点,是确保软件系统安全性的重要措施。在进行软件安全测试时,我们需要了解测试的内容和方法,以及为什么进行…

《高性能MySQL》——查询性能优化(笔记)

文章目录 六、查询性能优化6.1 查询为什么会慢6.2 慢查询基础:优化数据访问6.2.1 是否向数据库请求了不需要的数据查询不需要的记录多表关联时返回全部列总是取出全部列重复查询相同的数据 6.2.2 MySQL 是否在扫描额外的记录响应时间扫描的行数与返回的行数扫描的行…

项目经理和PMO如何穿越低谷,激活自己与团队——WOOP给你答案

2023年,已经还剩下不到5个月了。因为今年整体大环境不好,很多人会因为遇到各种问题,让自己掉入低谷,也有可能让自己带的团队毫无生气。我期待这篇文章能够给你带来向上的力量,在困境中看到希望与可能性。 相信有很多人…

如何在轻量级RTSP服务支持H.264扩展SEI发送接收自定义数据?

为什么开发轻量级RTSP服务? 开发轻量级RTSP服务的目的是为了解决在某些场景下用户或开发者需要单独部署RTSP或RTMP服务的问题。这种服务的优势主要有以下几点: 便利性:通过轻量级RTSP服务,用户无需配置单独的服务器,…

无涯教程-Perl - formline函数

描述 格式功能和相关的运算符使用此功能。它根据PICTURE的内容将LIST格式化为输出累加器变量$^ A。写入完成后,该值将写出到文件句柄中。 语法 以下是此函数的简单语法- formline PICTURE, LIST返回值 该函数总是返回1。 Perl 中的 formline函数 - 无涯教程网无涯教程网提…

中小企业在数字化转型上所面对的问题都有哪些?_光点科技

随着科技的飞速发展,数字化转型已经成为企业持续发展的必由之路。尤其是中小企业,数字化转型不仅可以提高效率,降低成本,还可以拓展市场,增强竞争力。然而,数字化转型并非一帆风顺,中小企业在这…

arcgis更改图层字段名脚本

话不多说,上脚本源码,复制黏贴即可 #-*- coding:utf-8 -*- __author__ lumen import arcpy #输入图层 InputFeature arcpy.GetParameterAsText(0) #原始字段 oldField arcpy.GetParameterAsText(1) # 获取原始字段类型 oldFieldType desc arcpy.…

电机基础知识::(1、电磁力;2力与运动)

永磁同步电机基础知识(一)_哔哩哔哩_bilibili

led台灯哪些牌子性价比高?推荐几款性价比高的护眼台灯

作为学龄期儿童的家长,最担心的就是孩子长时间学习影响视力健康。无论是上网课、写作业、玩桌游还是陪伴孩子读绘本,都需要一个足够明亮的照明环境,因此选购一款为孩子视力发展保驾护航的台灯非常重要。为大家推荐几款性价比高的护眼台灯。 …

influxDB

文章目录 版本2.0 数据结构Organization 组织Bucket 存储桶Measurementtagfieldtimestamp retention policy (RP) 保留策略Point 一条数据Series 一组数据 写入gzip压缩 查询FluxInfluxQL 官网 https://docs.influxdata.com/v1.8 中文翻译文档 https://influxdb-v1-docs-cn.cno…

iPhone手机怎么恢复出厂设置(详解)

如果您的iPhone遇到了手机卡顿、软件崩溃、内存不足或者忘记手机解锁密码等问题,恢复出厂设置似乎是万能的解决方法。 什么是恢复出厂设置?简单来说,就是让手机重新变成一张白纸,将手机所有数据都进行格式化,只保留原…