因果推断6--多任务学习(个人笔记)

news2024/11/16 17:52:13

目录

1多任务学习

1.1问题描述

1.2数据集

1.3网络结构

1.4结果

2因果推断使用多任务方式

2.1DRNet

2.2Dragonet

2.3Deep counterfactual networks with propensity-dropout

2.4VCNet

3思考


1多任务学习

keras-mmoe/census_income_demo.py at master · drawbridge/keras-mmoe · GitHub

推荐系统-(16)多任务学习:谷歌MMOE原理与实践 - 知乎

1.1问题描述

近年来,深度神经网络的应用越来越广,如推荐系统。推荐系统通常需要同时优化多个目标,如电影推荐中不仅需要预测用户是否会购买,还需要预测用户对于电影的评分,在比如电商领域同时需要预测物品的点击率CTR和转化率CVR。因此,多任务学习模型成为研究领域的一大热点。

1.2数据集

  • Example demo of running the model with the census-income dataset from UCI
    • This dataset is the same one in Section 6.3 of the paper

1.3网络结构

1.4结果

2因果推断使用多任务方式

采用多任务学习方式学习因果关系,尤其是多调研推荐系统的多任务学习模式,进行相应的补充。

2.1DRNet

Learning Counterfactual Representations for Estimating Individual Dose-Response Curves

  1. L1 base layers的参数参与所有数据集的训练,L2 treatment layers的参数只参与Treatment组样本的训练
  2. 能够应用于更加复杂的干预场景下,离散状态干预+连续状态干预,对于每一种干预组合,分别使用head网络进行学习
  3. 我们举个通俗易懂的case,我们想试验不同药剂对不同病人的影响。t=0~k-1分别代表不同组别病人,t=0是正常组,t=1~k-1 分别代表糖尿病人组,高血压病人组以及其它病人组,药剂量级m分为a,b,c分别代表低剂量/中剂量/高剂量,分别对t和m的不同组合采用head网络学习。每个处理层进一步细分为E个头部层(上面只显示了t = 0处理的一组E = 3个头部层)。

2.2Dragonet

Adapting Neural Network for the Estimation of Treatment Effects

  • dragonNet(学习非线性关系):两阶段方法,先学习表示模型,在学习推断模型

如果倾向分的网络丢掉之后, 这个网络结构就是和TARNET的结构相同,后面做了和这种方法的试验对比。 这个loss 中有倾向分部分这个部分会导致网络权重对于g(x) 相关性差的特征自动权重降低,有利于进行特征选择。 下面引入target regularizaiton 进行loss的改进。

2.3Deep counterfactual networks with propensity-dropout

摘要: 我们提出了一种从观察数据推断治疗(干预)的个体化因果效应的新方法。我们的方法将因果推断概念化为一个多任务学习问题;我们使用一个深度多任务网络,在事实和反事实结果之间有一组共享层,以及一组特定于结果的层,为受试者的潜在结果建模。通过倾向-退出正则化方案缓解了观察数据中选择偏差的影响,其中网络通过依赖于相关倾向分数的退出概率对每个训练示例进行减薄。该网络在交替阶段进行训练,在每个阶段中,我们使用两个潜在结果之一(处理过的和控制过的人群)的训练示例来更新共享层和各自特定结果层的权重。基于真实世界观察研究的数据进行的实验表明,我们的算法优于最先进的算法。

代码:GitHub - Shantanu48114860/Deep-Counterfactual-Networks-with-Propensity-Dropout: Implementation of the paper "Deep Counterfactual Networks with Propensity-Dropout"(https://arxiv.org/pdf/1706.05966.pdf) in pytorch framework

  1. 模型采用多目标建模思想,将Treatment组和Control组样本放在同一个模型中,降低模型冗余
  2. 左边部分是多目标框架,Treatment组和Control组的样本有共享层和各自独立的网络层,从而来学习Treatment模型和Control模型
  3. 右边Propensity Network网络主要控制左边模型的复杂度,如果数据好分,通过生成Dropout-Propensity控制左边模型,让其简单些;如果数据不好分,则控制左边模型复杂些
  4. 训练的时候Treatment组和Control组的样本分开训练,迭代次数是奇数时,训练Treatment组样本;偶数时,训练Control组样本

网络:

如果一个参数requires_grad=False,并且这个参数在optimizer里面,则不对它进行更新,并且程序不会报错

network.hidden1_Y1.weight.requires_grad = False
                        
import torch
import torch.nn as nn
import torch.optim as optim

from DCN import DCN


class DCN_network:
    def train(self, train_parameters, device):
        epochs = train_parameters["epochs"]
        treated_batch_size = train_parameters["treated_batch_size"]
        control_batch_size = train_parameters["control_batch_size"]
        lr = train_parameters["lr"]
        shuffle = train_parameters["shuffle"]
        model_save_path = train_parameters["model_save_path"].format(epochs, lr)
        treated_set = train_parameters["treated_set"]
        control_set = train_parameters["control_set"]

        print("Saved model path: {0}".format(model_save_path))

        treated_data_loader = torch.utils.data.DataLoader(treated_set,
                                                          batch_size=treated_batch_size,
                                                          shuffle=shuffle,
                                                          num_workers=1)

        control_data_loader = torch.utils.data.DataLoader(control_set,
                                                          batch_size=control_batch_size,
                                                          shuffle=shuffle,
                                                          num_workers=1)
        network = DCN(training_flag=True).to(device)
        optimizer = optim.Adam(network.parameters(), lr=lr)
        lossF = nn.MSELoss()
        min_loss = 100000.0
        dataset_loss = 0.0
        print(".. Training started ..")
        print(device)
        for epoch in range(epochs):
            network.train()
            total_loss = 0
            train_set_size = 0

            if epoch % 2 == 0:
                dataset_loss = 0
                # train treated
                network.hidden1_Y1.weight.requires_grad = True
                network.hidden1_Y1.bias.requires_grad = True
                network.hidden2_Y1.weight.requires_grad = True
                network.hidden2_Y1.bias.requires_grad = True
                network.out_Y1.weight.requires_grad = True
                network.out_Y1.bias.requires_grad = True

                network.hidden1_Y0.weight.requires_grad = False
                network.hidden1_Y0.bias.requires_grad = False
                network.hidden2_Y0.weight.requires_grad = False
                network.hidden2_Y0.bias.requires_grad = False
                network.out_Y0.weight.requires_grad = False
                network.out_Y0.bias.requires_grad = False

                for batch in treated_data_loader:
                    covariates_X, ps_score, y_f, y_cf = batch
                    covariates_X = covariates_X.to(device)
                    ps_score = ps_score.squeeze().to(device)

                    train_set_size += covariates_X.size(0)
                    treatment_pred = network(covariates_X, ps_score)
                    # treatment_pred[0] -> y1
                    # treatment_pred[1] -> y0
                    predicted_ITE = treatment_pred[0] - treatment_pred[1]
                    true_ITE = y_f - y_cf
                    if torch.cuda.is_available():
                        loss = lossF(predicted_ITE.float().cuda(),
                                     true_ITE.float().cuda()).to(device)
                    else:
                        loss = lossF(predicted_ITE.float(),
                                     true_ITE.float()).to(device)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                dataset_loss = total_loss

            elif epoch % 2 == 1:
                # train controlled
                network.hidden1_Y1.weight.requires_grad = False
                network.hidden1_Y1.bias.requires_grad = False
                network.hidden2_Y1.weight.requires_grad = False
                network.hidden2_Y1.bias.requires_grad = False
                network.out_Y1.weight.requires_grad = False
                network.out_Y1.bias.requires_grad = False

                network.hidden1_Y0.weight.requires_grad = True
                network.hidden1_Y0.bias.requires_grad = True
                network.hidden2_Y0.weight.requires_grad = True
                network.hidden2_Y0.bias.requires_grad = True
                network.out_Y0.weight.requires_grad = True
                network.out_Y0.bias.requires_grad = True

                for batch in control_data_loader:
                    covariates_X, ps_score, y_f, y_cf = batch
                    covariates_X = covariates_X.to(device)
                    ps_score = ps_score.squeeze().to(device)

                    train_set_size += covariates_X.size(0)
                    treatment_pred = network(covariates_X, ps_score)
                    # treatment_pred[0] -> y1
                    # treatment_pred[1] -> y0
                    predicted_ITE = treatment_pred[0] - treatment_pred[1]
                    true_ITE = y_cf - y_f
                    if torch.cuda.is_available():
                        loss = lossF(predicted_ITE.float().cuda(),
                                     true_ITE.float().cuda()).to(device)
                    else:
                        loss = lossF(predicted_ITE.float(),
                                     true_ITE.float()).to(device)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                dataset_loss = dataset_loss + total_loss

            print("epoch: {0}, train_set_size: {1} loss: {2}".
                  format(epoch, train_set_size, total_loss))

            if epoch % 2 == 1:
                print("Treated + Control loss: {0}".format(dataset_loss))
                # if dataset_loss < min_loss:
                #     print("Current loss: {0}, over previous: {1}, Saving model".
                #           format(dataset_loss, min_loss))
                #     min_loss = dataset_loss
                #     torch.save(network.state_dict(), model_save_path)

        torch.save(network.state_dict(), model_save_path)

    @staticmethod
    def eval(eval_parameters, device):
        print(".. Evaluation started ..")
        treated_set = eval_parameters["treated_set"]
        control_set = eval_parameters["control_set"]
        model_path = eval_parameters["model_save_path"]
        network = DCN(training_flag=False).to(device)
        network.load_state_dict(torch.load(model_path, map_location=device))
        network.eval()
        treated_data_loader = torch.utils.data.DataLoader(treated_set,
                                                          shuffle=False, num_workers=1)
        control_data_loader = torch.utils.data.DataLoader(control_set,
                                                          shuffle=False, num_workers=1)

        err_treated_list = []
        err_control_list = []

        for batch in treated_data_loader:
            covariates_X, ps_score, y_f, y_cf = batch
            covariates_X = covariates_X.to(device)
            ps_score = ps_score.squeeze().to(device)
            treatment_pred = network(covariates_X, ps_score)

            predicted_ITE = treatment_pred[0] - treatment_pred[1]
            true_ITE = y_f - y_cf
            if torch.cuda.is_available():
                diff = true_ITE.float().cuda() - predicted_ITE.float().cuda()
            else:
                diff = true_ITE.float() - predicted_ITE.float()

            err_treated_list.append(diff.item())

        for batch in control_data_loader:
            covariates_X, ps_score, y_f, y_cf = batch
            covariates_X = covariates_X.to(device)
            ps_score = ps_score.squeeze().to(device)
            treatment_pred = network(covariates_X, ps_score)

            predicted_ITE = treatment_pred[0] - treatment_pred[1]
            true_ITE = y_cf - y_f
            if torch.cuda.is_available():
                diff = true_ITE.float().cuda() - predicted_ITE.float().cuda()
            else:
                diff = true_ITE.float() - predicted_ITE.float()
            err_control_list.append(diff.item())

        # print(err_treated_list)
        # print(err_control_list)
        return {
            "treated_err": err_treated_list,
            "control_err": err_control_list,
        }

 我们将我们的潜在结果模型称为深度反事实网络(DCN),我们使用首字母缩写DCN- pd来指代具有倾向-退出正则化的DCN。由于我们的模型同时捕捉了倾向得分和结果,因此它是一个双稳健模型(doubly-robust model)。

2.4VCNet

@article{LizhenNie2021VCNetAF,  title={VCNet and Functional Targeted Regularization For Learning Causal Effects of Continuous Treatments},  author={Lizhen Nie and Mao Ye and Qiang Liu and Dan L. Nicolae},  journal={arXiv: Learning},  year={2021}}

参考:

  1. dcn(deep cross network)三部曲 - 知乎
  2. 因果推理实战(1)——借助因果关系从示教中学习任务规则 - 知乎
  3. 通俗解释因果推理 causal inference - 知乎
  4. AB实验的高端玩法系列1 - 走看看
  5. 收藏|浅谈多任务学习(Multi-task Learning) - 知乎
  6. 多任务学习在风控场景的应用探索及案例分享 - 知乎
  7. keras-mmoe/census_income_demo.py at master · drawbridge/keras-mmoe · GitHub
  8. keras-mmoe/census_income_demo.py at master · drawbridge/keras-mmoe · GitHub
  9. 多目标建模(一) - 知乎
  10. 推荐系统(8)—— 多目标优化应用总结_1 - 深度机器学习 - 博客园
  11. 多任务学习在因果建模上应用 - 知乎
  12. 深度学习【22】Mxnet多任务(multi-task)训练_DCD_Lin的博客-CSDN博客_多任务训练
  13. 因果推断在多任务优化场景有什么好的实践? - 知乎
  14. https://huaweicloud.csdn.net/63802f23dacf622b8df8639e.html
  15. ​​​​​​​神经网络训练多任务学习(MTL)时,多个loss怎么分配权重(附代码)_神经网络多任务训练_Ciao112的博客-CSDN博客

问题:

1、多目标训练,X_T_Y(唯一的),特效x [y1,y2],行形式是x [y1,null]?

答:采用参数不更新的方式训练。

3思考

1、使用因果推断给相关性模型纠偏,因果推断和机器学习是什么关系呢?

2、因果推断和机器学习都是解决什么问题?

答:机器学习解决预测问题,不需要知道原因;因果推断知道原因下,预测结果。

3、机器学习解决不了什么问题?

答:不去干预,不去探索原因。

答:回答不了因果问题。

4、因果推断解决不了什么问题?

5、没干预的问题还是因果问题吧?

答:理解不是了。

6、偏差解决方式?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/164557.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一种穷人式的内存泄露检测方式

对于检测程序代码中的资源泄露问题&#xff0c;市面上已经有很多工具了&#xff0c;但是今天我再来介绍一种新的方式&#xff0c;这种方式不需要安装任何工具或者特定的编译器开关&#xff0c;也不需要第三方库。 那就是&#xff1a;一直保持程序运行&#xff0c;直到泄露的原因…

【牛客网】HJ99 自守数、OR86 返回小于 N 的质数个数

作者&#xff1a;一个喜欢猫咪的的程序员 专栏&#xff1a;《Leetcode》 喜欢的话&#xff1a;世间因为少年的挺身而出&#xff0c;而更加瑰丽。 ——《人民日报》 目录 HJ99 自守数 OR86 返回小于 N 的质数个数 HJ99 自守数 自守数_牛客…

Linux系统之安装Linux管理工具inpanel

Linux系统之安装Linux管理工具inpanel一、inpanel介绍1.inpanel简介2.inpanel特点二、检查本地系统环境1.检查系统版本2.检查系统内核版本三、下载inpanel软件包1.创建下载目录2.下载inpanel软件3.查看源码inpanel文件四、部署inpanel应用1.一键安装inpanel2.查看服务端口五、关…

【双U-Net残差网络:超分】

Dual U-Net residual networks for cardiac magnetic resonance images super-resolution &#xff08;心脏磁共振图像超分辨率的双U-Net残差网络&#xff09; 目前&#xff0c;心脏磁共振&#xff08;CMR&#xff09;成像能够提供心脏全方位的结构和功能信息&#xff0c;已成…

难道你也不能放烟花嘛?那就来看看这个吧!

又到了一年一度的春节时期啦&#xff01;昨天呢是北方的小年&#xff0c;今天是南方的小年&#xff0c;看到大家可以愉快的放烟花&#xff0c;过大年很是羡慕呀&#xff01;辞旧岁&#xff0c;贺新春&#xff0c;今年我呀要放烟花&#xff0c;过春节&#xff01;&#x1f9e8;。…

农产品商城简单demo-Android

项目概述 随着科学技术的不断提高和社会经济的不断发展&#xff0c;一些农产品的销售逐渐的落后于社会信息化的潮流之中&#xff0c;尤其是一些年龄较大的中老年人来说是极为不便的&#xff0c;国家大力倡导并十分重视三农问题&#xff0c;倡导推动农村农业的发展&#xff0c;为…

第二章 搜索求解

人工智能中的搜索&#xff1a; 搜索算法的形式化描述&#xff1a;<状态、动作、状态转移、路径、测试目标> 状态&#xff1a;从原问题转化出的问题描述。 动作&#xff1a;从当前时刻所处状态转移到下一时刻所处状态。 状态转移&#xff1a;对某一时刻对应状态进行某一…

泛型的学习

这里写目录标题一、泛型的使用自定义泛型类泛型方法说明泛型在继承方面的体现通配符的使用有限制条件的通配符的的使用每日一考一、泛型的使用 1、jdk5.0新增特性 2、在集合中使用泛型 ①集合接口或集合类在jdk5.0时都修改为带泛型的结构 ②实例化集合时&#xff0c;可以指明具…

是Spring啊!

一.概念spring概念一个包含了众多工具方法的 IoC 容器okk~~分析一下这句话意思,众多方法,IoC 是形容词,容器是名词 -> 众多方法:比如一个类里有许多方法, 容器:存储的东西 重点就是IoC是什么?Ioc2.1解释IoC -> Inversion of Control 控制反转 -> 对象的生命周期 ->…

Git版本控制工具详解

1、版本控制 1.1、认识版本控制&#xff08;版本控制&#xff09; 什么是版本控制&#xff1f; 版本控制的英文是Version control&#xff1b;是维护工程蓝图的标准作法&#xff0c;能追踪工程蓝图从诞生一直到定案的过程&#xff1b;版本控制也是一种软件工程技巧&#xff…

红米 12C earth 秒解锁 跳过168小时 红米note12 note12pro note12pro+系列机型解锁bl root教程步骤Fastboot

最近上手体验了Redmi 12C/红米12C&#xff0c;这是红米新推出的百元机&#xff0c;起售价699元&#xff0c;464G版本&#xff0c;具有不错的性能&#xff0c;具有5000mAh大电池&#xff0c;具有双频wifi&#xff0c;支持双卡双待&#xff0c;支持SD卡扩展等。 如果你近期想要给…

UTF-8和Unicode

文章目录Unicode与网络传输Unicode网络传输UTF&#xff1a;Unicode Transformation Format UTF-8是在网络上传输Unicode的一个转换标准&#xff0c;发送时将字符串Unicode转为UTF-8&#xff0c;接收时将字节转为Unicode&#xff0c;就完成来字符串的传输 Unicode与网络传输 U…

移动端 - 搜索组件(search-list篇)

移动端 - 搜索组件(search-input篇) 移动端 - 搜索组件(suggest篇) 这里我们需要去封装搜索历史组件 这一个组件还是很简单的, 但是逻辑部分需要根据实际的需求来进行书写; 所以这里我不太好去写实际的代码, 不过可以提供我的思路(主要的就是去实现增, 删, 改, 查) 第一步: 首…

【STL】string的常见接口使用

目录 1、string类的基础概念 2、string类的常见接口说明及应用 2.1、string类的成员函数 constructor&#xff08;构造函数&#xff09; destructor&#xff08;析构函数&#xff09; operator&#xff08;赋值&#xff09; string类对象的容量操作 迭代器 string类…

【vue2】组件基础与组件传值(父子组件传值)

&#x1f973;博 主&#xff1a;初映CY的前说(前端领域) &#x1f31e;个人信条&#xff1a;想要变成得到&#xff0c;中间还有做到&#xff01; &#x1f918;本文核心&#xff1a;组件基础概念与全局|局部组件的写法、组件之间传值&#xff08;父传子、子传父&#xff…

rcfile和orcfile

一、数据存储要考虑哪些方面 数据加载时间 Facebook数仓每天存储的数据量超过20TB&#xff0c;数据加载既有磁盘I/O又有网络传输&#xff0c;时间占用大 快速的数据查询 低的空间占用 数据压缩/数据编码 适合多种查询模式 如果所有人都查相同的字段&#xff0c;那么就可以针…

QT添加使用图片与UI资源

QT添加使用图片与UI资源1 QT添加使用图片资源1.1 添加新文件1.2 添加QT - QT Resources File 【UI资源文件】1.3 命名资源包名称 并 添加到项目文件1.4 .pro 文件发生变化 art.qrc1.5 点击qrc文件&#xff0c;添加现有文件 - 添加进去的图片文件可以进行正常引用。1.6 修改样式…

分布式任务处理xxljob

7.1 分布式任务处理 7.1.1 什么是分布式任务调度 视频上传成功需要对视频的格式进行处理&#xff0c;如何用Java程序对视频进行处理呢&#xff1f;这里有一个关键的需求就是当视频比较多的时候我们如何可以高效处理。 如何去高效处理一批任务呢&#xff1f; 1、多线程 多线…

通过Docker启动DB2,并在Spring Boot整合DB2

1 简介 DB2是IBM的一款优秀的关系型数据库&#xff0c;简单学习一下。 2 Docker安装DB2 为了快速启动&#xff0c;直接使用Docker来安装DB2。先下载镜像如下&#xff1a; docker pull ibmcom/db2:11.5.0.0 启动数据库如下&#xff1a; docker run -itd \--name mydb2 \--…

Allegro如何导入和导出Pin Delay操作指导

Allegro如何导入和导出Pin Delay操作指导 在做PCB设计等长设计的时候,Pin Delay是个非常重要的数据,关系到信号的长度,Allegro支持把Pin Delay数据导入到PCB中,并且还支持导出,如下图 具体操作如下 导入Pin Delay,选择File选择Import