因果推断14--DRNet论文和代码学习

news2025/1/22 15:57:43

目录

论文介绍

代码实现

DRNet

ReadMe

因果森林


论文介绍

因果推断3--DRNet(个人笔记)_万三豹的博客-CSDN博客

摘要:估计个体在不同程度的治疗暴露下的潜在反应,对于医疗保健、经济学和公共政策等几个重要领域具有很高的实际意义。然而,现有的从观察数据中估计反事实结果的学习方法要么专注于估计平均剂量-反应曲线,要么局限于只有两种没有相关剂量参数的治疗方法。在这里,我们提出了一种新的机器学习方法,用于学习反事实表示,用于使用神经网络估计具有连续剂量参数的任意数量治疗的单个剂量-反应曲线。在已建立的潜在结果框架的基础上,我们引入了性能指标、模型选择标准、模型架构和用于估计单个剂量反应曲线的开放基准。我们的实验表明,在这项工作中开发的方法在估计个体剂量反应方面设置了一个新的最先进的方法。

代码实现

GitHub - d909b/drnet: 💉📈 Dose response networks (DRNets) are a method for learning to estimate individual dose-response curves for multiple parametric treatments from observational data using neural networks.

DRNet

def get_method_name_map():
        return {
            'knn': KNearestNeighbours,
            'ols1': OrdinaryLeastSquares1,
            'ols2': OrdinaryLeastSquares2,
            'cf': CausalForest,
            'rf': RandomForest,
            'bart': BayesianAdditiveRegressionTrees,
            'nn': TFNeuralNetwork,
            'nn+': NeuralNetwork,
            'xgb': GradientBoostedTrees,
            'gp': GaussianProcess,
            'psm': PSM,
            'psmpbm': PSM_PBM,
            'ganite': GANITE,
            'gps': GPS,
        }

    def _build_graph(self, input_dim, num_units,
                     num_representation_layers, num_regression_layers, weight_initialisation_std,
                     reweight_sample=False, loss_function="l2",
                     imbalance_penalty_function="wass", rbf_sigma=0.1,
                     wass_lambda=10.0, wass_iterations=10, wass_bpt=True):
        """
        Constructs a TensorFlow subgraph for counterfactual regression.
        Sets the following member variables (to TF nodes):
        self.output         The output prediction "y"
        self.tot_loss       The total objective to minimize
        self.imb_loss       The imbalance term of the objective
        self.pred_loss      The prediction term of the objective
        self.weights_in     The input/representation layer weights
        self.weights_out    The output/post-representation layer weights
        self.weights_pred   The (linear) prediction layer weights
        self.h_rep          The layer of the penalized representation
        """
        ''' Initialize input placeholders '''
        self.x = tf.placeholder("float", shape=[None, input_dim], name='x')
        self.t = tf.placeholder("float", shape=[None, 1], name='t')
        self.y_ = tf.placeholder("float", shape=[None, 1], name='y_')

        ''' Parameter placeholders '''
        self.imbalance_loss_weight = tf.placeholder("float", name='r_alpha')
        self.l2_weight = tf.placeholder("float", name='r_lambda')
        self.dropout_representation = tf.placeholder("float", name='dropout_in')
        self.dropout_regression = tf.placeholder("float", name='dropout_out')
        self.p_t = tf.placeholder("float", name='p_treated')

        dim_input = input_dim
        dim_in = num_units
        dim_out = num_units

        weights_in, biases_in = [], []

        if num_representation_layers == 0:
            dim_in = dim_input
        if num_regression_layers == 0:
            dim_out = dim_in

        ''' Construct input/representation layers '''
        h_rep, weights_in, biases_in = build_mlp(self.x, num_representation_layers, dim_in,
                          self.dropout_representation, self.nonlinearity,
                          weight_initialisation_std=weight_initialisation_std)

        # Normalize representation.
        h_rep_norm = h_rep / safe_sqrt(tf.reduce_sum(tf.square(h_rep), axis=1, keep_dims=True))

        ''' Construct ouput layers '''
        y, y_concat, weights_out, weights_pred = self._build_output_graph(h_rep_norm, self.t, dim_in, dim_out,
                                                                          self.dropout_regression,
                                                                          num_regression_layers,
                                                                          weight_initialisation_std)

        ''' Compute sample reweighting '''
        if reweight_sample:
            w_t = self.t/(2*self.p_t)
            w_c = (1-self.t)/(2*(1-self.p_t))
            sample_weight = w_t + w_c
        else:
            sample_weight = 1.0

        self.sample_weight = sample_weight

        ''' Construct factual loss function '''
        if self.with_pehe_loss:
            risk = pred_error = tf.reduce_mean(sample_weight*tf.square(self.y_ - y)) + \
                                pehe_loss(self.y_, y_concat, self.t, self.x, self.num_treatments) / 10.
        elif loss_function == 'log':
            y = 0.995/(1.0+tf.exp(-y)) + 0.0025
            res = self.y_*tf.log(y) + (1.0-self.y_)*tf.log(1.0-y)

            risk = -tf.reduce_mean(sample_weight*res)
            pred_error = -tf.reduce_mean(res)
        else:
            risk = tf.reduce_mean(sample_weight*tf.square(self.y_ - y))
            pred_error = tf.sqrt(tf.reduce_mean(tf.square(self.y_ - y)))

        ''' Regularization '''
        for i in range(0, num_representation_layers):
            self.weight_decay_loss += tf.nn.l2_loss(weights_in[i])

        p_ipm = 0.5

        if self.imbalance_loss_weight_param == 0.0:
            imb_dist = tf.reduce_mean(self.t)
            imb_error = 0
        elif imbalance_penalty_function == 'mmd2_rbf':
            imb_dist = mmd2_rbf(h_rep_norm, self.t, p_ipm, rbf_sigma)
            imb_error = self.imbalance_loss_weight * imb_dist
        elif imbalance_penalty_function == 'mmd2_lin':
            imb_dist = mmd2_lin(h_rep_norm, self.t, p_ipm)
            imb_error = self.imbalance_loss_weight * mmd2_lin(h_rep_norm, self.t, p_ipm)
        elif imbalance_penalty_function == 'mmd_rbf':
            imb_dist = tf.abs(mmd2_rbf(h_rep_norm, self.t, p_ipm, rbf_sigma))
            imb_error = safe_sqrt(tf.square(self.imbalance_loss_weight) * imb_dist)
        elif imbalance_penalty_function == 'mmd_lin':
            imb_dist = mmd2_lin(h_rep_norm, self.t, p_ipm)
            imb_error = safe_sqrt(tf.square(self.imbalance_loss_weight) * imb_dist)
        elif imbalance_penalty_function == 'wass':
            imb_dist, imb_mat = wasserstein(h_rep_norm, self.t, p_ipm, sq=True,
                                            its=wass_iterations, lam=wass_lambda, backpropT=wass_bpt)
            imb_error = self.imbalance_loss_weight * imb_dist
            self.imb_mat = imb_mat  # FOR DEBUG
        elif imbalance_penalty_function == 'wass2':
            imb_dist, imb_mat = wasserstein(h_rep_norm, self.t, p_ipm, sq=True,
                                            its=wass_iterations, lam=wass_lambda, backpropT=wass_bpt)
            imb_error = self.imbalance_loss_weight * imb_dist
            self.imb_mat = imb_mat  # FOR DEBUG
        else:
            imb_dist = lindisc(h_rep_norm, p_ipm, self.t)
            imb_error = self.imbalance_loss_weight * imb_dist

        ''' Total error '''
        tot_error = risk
        if self.imbalance_loss_weight_param != 0.0:
            tot_error = tot_error + imb_error
        tot_error = tot_error + self.l2_weight*self.weight_decay_loss

        self.output = y
        self.tot_loss = tot_error
        self.imb_loss = imb_error
        self.imb_dist = imb_dist
        self.pred_loss = pred_error
        self.weights_in = weights_in
        self.weights_out = weights_out
        self.weights_pred = weights_pred
        self.h_rep = h_rep
        self.h_rep_norm = h_rep_norm

ReadMe

剂量反应网络(DRNets)是一种用于学习使用神经网络从观测数据估计多参数治疗的个体剂量反应曲线的方法。该存储库包含用于评估DRNets的源代码以及用于评估个体治疗效果的最相关的现有最先进方法(有关结果,请参阅我们的手稿)。为了便于将来的研究,源代码被设计为易于使用(1)新方法和(2)新的基准数据集进行扩展。

作者:Patrick Schwab,苏黎世ETHpatrick.schwab@hest.ethz.ch苏黎世联邦理工学院Lorenz Linhardtllorenz@student.ethz.ch,Stefan Bauer,MPI for Intelligent Systemsstefan.bauer@tuebingen.mpg.de,苏黎世联邦理工学院Joachim M.Buhmannjbuhmann@inf.ethz.ch苏黎世联邦理工学院Walter Karlenwalter.karlen@hest.ethz.ch

许可证:MIT,请参阅License.txt

引用

如果您在工作中引用或使用我们的方法、代码或结果,请考虑引用:

@在过程中{schwab2020-剂量反应,

title={{学习用于估计个体剂量响应曲线的反事实表示}},

作者={施瓦布、帕特里克和林哈特、洛伦兹和鲍尔、斯特凡和布曼、约阿希姆·M和卡伦、沃尔特},

booktitle={{AAAI人工智能会议}},

年={2020}

}

用法:

可运行的脚本位于drnet/apps/子目录中。

drnet/apps/main.py是运行实验的主要可运行脚本。

drnet/apps/parameters.py中描述了可运行脚本的可用命令行参数

您可以通过将drnet/models/baseline/baseline.py子类化,将新的基线方法添加到评估中

有关如何实现自己的基线方法的示例,请参见drnet/models/baselines/neural_network.py。

通过向drnet/apps/main.py中的get_method_name_map方法添加新条目,可以从命令行注册新方法以供使用

您可以通过实现基准接口来添加新的基准,有关如何将自己的基准添加到基准套件的示例,请参见drnet/models/benchmarks。

通过向drnet/apps/evaluate.py中的get_benchmark_name_map方法添加新条目,可以从命令行注册新的基准测试以供使用

要求和相关性

该项目设计用于Python 2.7。我们不能保证,也没有测试过与Python 3的兼容性。

要运行TCGA和News基准,需要下载包含这些基准的原始数据样本的SQLite数据库(News.db和TCGA.db)。

您可以使用以下链接下载原始数据:tcga.db和news.db。

请注意,您需要大约10GB的可用磁盘空间来存储数据库。

将数据库文件保存到/数据目录,以便与下面的分步指南兼容或相应地调整命令。

要运行MVICU基准测试,您需要访问MIMIC-III数据库,由于数据集的敏感性,这需要经过审批过程。

注意,您需要大约75GB的可用磁盘空间来存储带有索引的MIMIC-III数据库。

访问数据集并将MIMIC-III数据加载到SQLite数据库(保存为例如/your/path/to/mimic3.db)后,可以使用drnet/apps/load_db_icu.py脚本将MVICU基准数据从MIMIC-IIII数据库提取到中的单独数据库中/数据文件夹,通过运行:

python drnet/apps/load_db_icu.py/your/path/to/mimic3.db./data

一旦建立,基准数据库将使用大约43MB的磁盘空间。

要运行BART、因果森林和GPS,并再现需要安装R的数字。看见https://www.r-project.org/安装说明。

要运行BART,需要安装R包rJava和bartMachine。看见https://github.com/kapelner/bartMachine安装说明。注意,rJava也需要一个工作的Java安装。

要运行因果森林,需要安装R包grf。看见https://github.com/grf-labs/grf安装说明。

要运行GPS,您需要安装R包causaldrf,例如在R-shell中运行install.packages(“causaldrv”)。

要复制论文的数字,您需要安装R-packagelatex2exp。看见https://cran.r-project.org/web/packages/latex2exp/vignettes/using-latex2exp.html安装说明。

有关python依赖关系,请参阅setup.py。您可以使用pipinstall。安装drnet包及其python依赖项。请注意,如果您的系统上没有正常的R安装,rpy2的安装将失败(请参见上文)。

再现实验

确保您具备上面列出的必要要求,包括/与此文件相关的数据目录以及所需的数据库(参见上文)。

您可以使用脚本drnet/apps/run_all_experiments.py获取main.py使用的精确参数,以重现论文中的实验结果。

drnet/apps/run_all_experiments.py脚本打印。

因果森林

https://grf-labs.github.io/grf/articles/grf.h

return self.grf.causal_forest(x,
                                      FloatVector([float(yy) for yy in y]),
                                      FloatVector([float(tt) for tt in t]), seed=909)

W
The treatment assignment (must be a binary or real numeric vector with no NAs).

W

治疗分配(必须是没有NA的二进制或实数矢量)。

class CausalForest(PickleableMixin, Baseline):
    def __init__(self):
        super(CausalForest, self).__init__()
        self.bart = None

    def install_grf(self):
        from rpy2.robjects.packages import importr
        import rpy2.robjects.packages as rpackages
        from rpy2.robjects.vectors import StrVector
        import rpy2.robjects as robjects

        # robjects.r.options(download_file_method='curl')

        # package_names = ["grf"]
        # utils = rpackages.importr('utils')
        # utils.chooseCRANmirror(ind=0)
        # utils.chooseCRANmirror(ind=0)
        #
        # names_to_install = [x for x in package_names if not rpackages.isinstalled(x)]
        # if len(names_to_install) > 0:
        #     utils.install_packages(StrVector(names_to_install))

        return importr("grf")

    def _build(self, **kwargs):
        from rpy2.robjects import numpy2ri
        from sklearn import linear_model
        grf = self.install_grf()

        self.grf = grf
        numpy2ri.activate()
        num_treatments = kwargs["num_treatments"]
        self.with_exposure = kwargs["with_exposure"]

        return [linear_model.Ridge(alpha=.5)] +\
               [None for _ in range(num_treatments)]

    def predict_for_model(self, model, x):
        base_y = Baseline.predict_for_model(self, self.model[0], x)
        if model == self.model[0]:
            return base_y
        else:
            import rpy2.robjects as robjects
            r = robjects.r
            result = r.predict(model, self.preprocess(x))
            y = np.array(result[0])
            return y[:, -1] + base_y

    def fit_grf_model(self, x, t, y):
        from rpy2.robjects.vectors import StrVector, FactorVector, FloatVector, IntVector
        return self.grf.causal_forest(x,
                                      FloatVector([float(yy) for yy in y]),
                                      FloatVector([float(tt) for tt in t]), seed=909)

参考:

  1. 《因果学习周刊》第8期:因果反事实预测 - 知乎
  2. 因果效应估计:用数据和模型指导决策 - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/414122.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ERP系统有什么用?主要是这三方面

ERP 是Enterprise Resource Planning 的缩写,即企业资源计划系统,是建立在信息技术基础上,以系统化的管理思想,为企业决策层及员工提供决策运行手段的管理平台。它实现了企业内部资源和企业相关的外部资源的整合。通过软件把企业的人、财、物、产、供、销及相应的物…

socket 到底是个啥

哈喽大家好,我是咸鱼 我相信大家在面试过程中或多或少都会被问到这样一个问题:你能解释一下什么是 socket 吗 我记得我当初的回答很是浅显:socket 也叫套接字,用来负责不同主机程序之间的网络通信连接,socket 的表现…

整柜海运到美国的规格和收费标准是什么

整柜海运是指将所有货物安装在一个整箱内,由发货人和收货人共同操作,而目的港的收货人一般只有一个,方便操作。整柜海运到美国的主要流程有以下几个步骤:订舱、装柜、报关、海运、清关、提柜和送货。实际上,国际物流出…

IDEA 2023.1 最新变化

IntelliJ IDEA 2023.1 最新变化 在 IntelliJ IDEA 2023.1 中,我们根据用户的宝贵反馈对新 UI 做出了大量改进。 我们还实现了性能增强,从而更快导入 Maven,以及在打开项目时更早提供 IDE 功能。 新版本通过后台提交检查提供了简化的提交流程…

写博客8年与人生第一个502万

题记:我们并非生来强大,但依然可以不负青春。 原本想好好写一下如何制定一个目标并通过一点一滴的努力去实现,这三年反思发现其实写自己的经历并不重要。 很多人都听过一句话:榜样的力量是无穷的。 更现实和实际的情况是&#x…

【LeetCode】每日一题:移除元素

目录 题目: 思想1:暴力解法 思想2:创建一个temp数组 思想3:双指针 👻内容专栏:《LeetCode刷题专栏》 🐨本文概括:27.移除元素 🐼本文作者:花 碟 &#x1…

电商评论数据爬取--R语言

1.网络爬虫 1.1 什么是网络爬虫 网络爬虫(web crawler),也被称为网络蜘蛛(web spider),是在万维网浏览网页并按照一定规则提取信息的脚本或者程序。 浏览网页时,一般流程如下: 利…

C++语法(17)---- 二叉搜索树

1.概念 1.父节点的左子树全小于本身 2.父节点的右子树全大于本身 3.左右子树也是二叉搜索树 时间复杂度:O(N),有可能只有左数,这样就遍历了所有,所有复杂度为N 平衡二叉树的时间复杂度才是:O(logN) 2.模拟 1.数据元素…

Postman抓包教程

目录 什么是抓包? 如何使用 Postman 进行抓包 查看历史抓包数据 使用抓包数据进行接口测试和开发 抓包技巧和注意事项 什么是抓包? 在计算机网络中,抓包是指捕获网络流量的过程。抓包工具可以截获进出计算机网络的数据流,并将…

反向传播推导+numpy实现

很久没有看深度学习了,忘了好多东西。本来想着推导一下,后来发现自己不会了。 再看看以前写的代码,又避开了最终的东西,于是决定重新推导一下。 数据的说明 首先,我们要做一个回归的任务,我们使用numpy随…

5.Java循环控制语句

Java循环控制语句 循环是Java中应用最为广泛的一个知识点,所以也是很需要掌握的。所谓循环,即通过判断条件,重复执行一段代码,根据条件的变化,来确定代码是否执行,执行次数。 一、循环结构 1、while循环…

Java IO常用操作详解(代码示例)

概览 Java I/O操作指的是数据的输入/输出操作。 Java的I/O操作类在java.io包中,主要分以下几种: 基于字节操作的I/O接口: InputStream和OutputStream基于字符操作的I/O接口: Writer和Reader基于磁盘操作的I/O接口: …

5个令人惊艳的AI项目,开源了。。

大家好,我是 Jack。 今天清明,小伙伴们都去哪里玩了? 上个月我已经出去浪过了,清明就老实在家歇着了。 翻看了一些最近热点的开源项目,发现还是 AIGC 的天下。 今天,我将继续着重挑选几个近期的优质开源…

2023第十四届蓝桥杯C++B组菜鸡的落幕

时隔几天,终于还是忍不住来复盘一下蓝桥杯了,还记得去年参加做下填空,再做对个把编程,后面不会的大题打打表混混分,最后就能混个省奖, 这回估计凉透了,填空没对似乎,编程也没对几个,…

Kettle8.2.0连接Hive3.1.2(踩坑,亲测有效)

这是目前遇到的最简单但最头疼的安装,因为是在公司之前用过的服务器上进行安装测试,加上又使用比较新的版本,结果踩了不少坑。Kettle连接Hive这个坑,从2023年4月11日下午开始,一致到2023年4月12日中午才弄好&#xff0…

uni-app常用配置

保存自动格式化 工具》设置》编辑器设置》保存时自动格式化 JS语法检查 安装eslint-js插件eslint-js - DCloud 插件市场 用于校验js和html中的js代码https://ext.dcloud.net.cn/plugin?id2037工具》设置》插件配置》eslint-js 启用实时校检 Vue语法检查 安装eslint-vue插…

【星界探索——通信卫星】铱星:从“星光坠落”到“涅槃重生”,万字长文分析铱星卫星系统市场

【星界探索——通信卫星】铱星:从“星光坠落”到“涅槃重生”一、铱星简介二、铱星系统设计思路2.1 工作原理2.2 铱星布局三、铱星优势四、发展历程五、第一代铱星公司的破产原因分析5.1 终端和资费价格高昂,市场用户群体小5.2 财务危机5.3 市场分析不足…

一文吃透低代码平台源代码交付的重要性(避坑指南)

一、前言 作为这两年IT界的风口,低代码在众人眼里已经不是什么陌生的概念。 对标于传统的纯代码开发,低代码是一种快速开发软件(应用程序)的方法,平台通过对大量功能与场景做提前封装,使得用户可以在可视…

MySQL开发05-MySQL开发规范

文章目录1、命名规范2、表设计规范3、索引规范4、SQL语句规范5、SQL脚本规范6、数据架构规范7、配置文件建议8、其他规范9、总结1、命名规范 命名应有意义,包括库名、表名、用户名等,以使用方便记忆、描述性强的可读性名称为第一准则,尽量避…

Docker网络案例

bridge 是什么 Docker 服务默认会创建一个 docker0 网桥(其上有一个 docker0 内部接口),该桥接网络的名称为docker0,它在内核层连通了其他的物理或虚拟网卡,这就将所有容器和本地主机都放到同一个物理网络。Docker 默认指定了 docker0 接口 的 IP 地址和子网掩码,让主机…