Learning Overlapping Representations for the Estimation of Individualized Treatment Effects
2020·arXiv
Abstract

The choice of making an intervention depends on its potential benefit or harm in comparison to alternatives. Estimating the likely outcome of alternatives from observational data is a challenging problem, as the potential outcomes are never all observed and selection bias precludes the direct comparison of differently intervened groups. Despite their empirical success, we show that algorithms that learn domain-invariant representations of inputs (on which to make predictions) are often inappropriate, and we develop generalization bounds that demonstrate the dependence on domain overlap and highlight the need for invertible latent maps. Based on these results, we develop a deep kernel regression algorithm and posterior regularization framework that substantially outperforms the state of the art on a variety of benchmark data sets.

Counterfactual estimation poses the question of what the outcome would have been if a different intervention had been applied. To make decisions in complex domains, it is critical to predict the causal effects of different actions and how these effects vary across individuals. In this paper we focus on making these predictions from observational data, which is increasingly available in domains such as medicine, public policy and advertising. In this setting, past actions, outcomes and context are available, but the treatment assignment policy is unknown: we do not know why a given individual was or was not intervened on. The treatment assignment mechanism will often be causally affected by context variables that also causally influence the outcome. As an example, motivated unemployed individuals are more likely both to take advantage of government training programs and to find a new job soon.

Proceedings of the 23rd International Conference on Artificial Intelligence and Statistics (AISTATS) 2020, Palermo, Italy. PMLR: Volume 108. Copyright 2020 by the author(s).

Learning from observational data requires adjusting for the covariate shift between groups of individuals observed to have received different interventions. The challenge is to untangle confounding factors and make valid counterfactual predictions, that is, predictions of the response had a different treatment been applied. Recent methods have predominantly focused on learning representations regularized to balance these confounding factors by enforcing domain invariance with distributional distances [11, 12, 20]. In this paper we argue that domain invariance is often too strict a requirement: overlapping support is sufficient for identifiability of the causal effect, and equality in densities is not necessary. We interpret the loss in predictive power of domain-invariant representations as a loss of information in the input variables that causally influence the treatment assignment, which are also often highly predictive of the treatment effect. Consider the example above for illustration: it is precisely because motivation is predictive of job outcomes that it confounds the treatment effect.

We introduce an optimization framework based on regularizing posterior distributions of the treatment effect that includes existing representation learning algorithms as special cases for different choices of regularization terms. We use this framework to introduce novel regularization criteria for treatment effect estimation: the posterior counterfactual variance, to enforce domain overlap, and invertible representations, to preserve the information content of the underlying context. The resulting objective generalizes much better in small-sample regimes, yields representations that vary more smoothly with respect to the outcomes (Figure 1), and provides a Bayesian treatment of parameters that allows consistent uncertainty estimation in predictions.

Figure 1: t-SNE visualizations of the learned embeddings for the control potential outcomes Y(0) of the IHDP dataset. Each panel shows representations regularized by a different criterion, and the heatmap colors encode outcome magnitudes. The left panel shows representations regularized by the Wasserstein distance, which discriminate poorly between outcomes. The middle panel shows representations optimized only for the factual data with the Gaussian likelihood. The right panel shows representations regularized by the counterfactual variance, our proposed criterion. Regularizing for the predictive variance yields much better separation in outcomes than integral probability metrics such as the Wasserstein distance. Similar plots comparing the control and treated groups, and for the treated potential outcome Y(1), can be found in Appendix 5.3.

We summarize our contributions as follows:

1. We develop a theory that justifies regularizing for the posterior variance to improve generalization error and establish the limitations of distributional distances.

2. We propose to use deep kernels and posterior regularization as a general framework to learn individualized treatment effects for arbitrary regularization terms.

3. We provide an instantiation of this algorithm informed by novel generalization bounds that substantially improves upon the performance of state-of-the-art prediction algorithms.

1.1 Related work

Due to the ability of deep neural nets to learn rich representations, recent advances in predicting individualized treatment effects have focused on learning representations invariant to the treatment assignment policy that achieve a small error on the factual data. The hope is that the learnt representation and prediction function can generalize to prediction of counterfactual outcomes. Several methods follow this approach. [11] proposed learning a representation of the data that makes the treated and control distributions more similar, fitting a linear ridge-regression model on top of it. [17] built on their approach to derive a more flexible family of algorithms including non-linear hypotheses. However, both algorithms insist on quantifying divergence between treated and control groups with integral probability metrics.

In this work, we share the need for good representations but argue for enforcing support overlap rather than equality in densities. [20], inspired by nearest-neighbour methods, learn representations that preserve local similarity information in feature space and show a decrease in the generalization error of counterfactual estimation.

Adapting Bayesian algorithms to the problem of individualized treatment effects has attracted a lot of interest, in particular in medicine, where quantifying uncertainty is important. [1] regularize counterfactual predictions through their posterior variance and similarly stress the importance of providing confidence in their estimates using credible intervals, but they did not investigate the generalization properties of their algorithm, whose expressiveness is limited. Similarly, [10] used posterior variance regularization to learn from unlabeled data, improving performance in situations where labeled data is scarce.

Our work also has strong connections with domain adaptation. In particular, estimating individualized treatment effects (ITE) requires predicting outcomes over a distribution different from the observed one. Our ITE error upper bound resembles the domain adaptation generalization bounds given by [11, 12]. [12] and [21] have similarly argued against enforcing domain invariance, and related the loss of predictive power of such representations to the loss of information due to the non-invertibility of the learned representations.

To estimate what would happen under different interventions, we first frame the causal problem as a statistical one. We build upon the potential outcomes framework [16], which gives a formal mathematical representation of the potential outcome through the notation Y(a), read as the outcome Y that would be observed if we do a.

Setup. Consider a population of individuals, each individual $i$ described by a context $X_i \in \mathcal{X}$ (typically $\mathbb{R}^d$). An intervention $T$ is applied to subjects in the population. An important special case is binary interventions, $T \in \{0, 1\}$: individual $i$'s response to the intervention is a random variable denoted by $Y_i(1)$, whereas $i$'s natural response when no intervention is applied is denoted by $Y_i(0)$. The two random variables $Y_i(1), Y_i(0) \in \mathbb{R}$ are known as the potential outcomes. As discussed, in observational data intervention assignments generally depend on the subjects' features, i.e. $T_i \not\perp\!\!\!\perp X_i$. This dependence is quantified via the conditional distribution $p(T_i \mid X_i)$, also known as the propensity score of subject $i$.

Assumption 1 (Consistency, ignorability and overlap). For any individual $i$ assigned to intervention $t_i$, we observe $Y_i = Y(t_i)$. Further, $\{Y(t)\}_{t \in \mathcal{T}}$ and the data-generating process $p(X, T, Y)$ satisfy strong ignorability, $Y(0), Y(1) \perp\!\!\!\perp T \mid X$, and overlap, $\forall x,\ 0 < p(T \mid X = x) < 1$.

Assumption 1 is a sufficient condition for causal identifiability [16]. Ignorability is also known as the no-hidden-confounders assumption: all variables that cause both $T$ and $Y$ are assumed to be measured. Under ignorability, therefore, any domain shift in $p(X)$ cannot be due to variables that causally influence $T$ and $Y$ other than through $X$. Under Assumption 1, expected potential outcomes equal conditional expectations of the observed outcome, $\mathbb{E}[Y(t) \mid X = x] = \mathbb{E}[Y \mid X = x, T = t]$, so we may predict $Y(t)$ by regression.
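To make the identification argument concrete, the following minimal Python sketch simulates a hypothetical version of the job-training example: a single context variable x (think "motivation") raises both the probability of treatment and the outcome. The data-generating process, the effect tau(x) = 1 + 0.5x, and the use of one ordinary least-squares line per treatment arm are illustrative assumptions, not the method of this paper. The naive difference in group means is confounded, whereas regressing Y on X within each arm estimates E[Y | X = x, T = t] and hence the CATE, as Assumption 1 permits.

```python
import numpy as np

def simulate(n=20_000, seed=0):
    """Hypothetical confounded data: x drives both treatment uptake and the outcome."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=n)                              # context, e.g. "motivation"
    propensity = 1.0 / (1.0 + np.exp(-2.0 * x))         # p(T = 1 | X = x), strictly in (0, 1)
    t = rng.binomial(1, propensity)
    y0 = x + rng.normal(scale=0.5, size=n)              # Y(0)
    y1 = 1.5 * x + 1.0 + rng.normal(scale=0.5, size=n)  # Y(1), so tau(x) = 1 + 0.5 x
    y = np.where(t == 1, y1, y0)                        # consistency: only Y(t_i) is observed
    return x, t, y

def fit_line(x, y):
    """Least-squares fit y ~ a + b x; returns array([a, b])."""
    design = np.column_stack([np.ones_like(x), x])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef

x, t, y = simulate()
naive = y[t == 1].mean() - y[t == 0].mean()             # confounded difference in means

# Regression adjustment: E[Y | X = x, T = t] estimated separately in each arm.
a1, b1 = fit_line(x[t == 1], y[t == 1])
a0, b0 = fit_line(x[t == 0], y[t == 0])
tau_hat = (a1 - a0) + (b1 - b0) * x                     # estimated CATE at every x
adjusted = tau_hat.mean()
true_ate = np.mean(1.0 + 0.5 * x)                       # sample average of the true tau(x)
```

Because the propensity stays strictly between 0 and 1, overlap holds and both arm-specific fits are informative over the whole range of x; the naive contrast, by contrast, also absorbs the difference in average motivation between the groups.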

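Objective. We attempt to learn predictors $f_t : \mathcal{X} \to \mathcal{Y}$ such that $f_t(x)$ approximates $\mathbb{E}[Y \mid X = x, T = t]$. Predictors $f_t = w_t \circ \varphi$ are formed by composing a function $w_t : \mathcal{Z} \to \mathcal{Y}$, which operates in a feature space $\mathcal{Z}$, with a representation $\varphi : \mathcal{X} \to \mathcal{Z}$.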

The individual effect of an intervention $T = 1$ in context $X$ is measured by the conditional average treatment effect (CATE), $\tau(x) = \mathbb{E}[Y(1) - Y(0) \mid X = x]$, and the estimation error is given by the precision in estimating heterogeneous effects, $\epsilon_{\text{PEHE}}$, defined as the mean squared error in the estimation of the treatment effect $\tau(x)$:

$$\epsilon_{\text{PEHE}} = \int_{\mathcal{X}} \big(f_1(x) - f_0(x) - \tau(x)\big)^2 \, p(x) \, dx.$$

Predicting $\tau$ for unobserved units involves predicting both potential outcomes. Since we only observe the "factual" outcome for a specific treatment assignment, and never the corresponding "counterfactual" outcome, an observational dataset contains no samples of the true treatment effect. This makes causal inference fundamentally different from standard supervised learning.
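Because the true effect is never observed in real data, $\epsilon_{\text{PEHE}}$ can only be evaluated on (semi-)synthetic benchmarks where $\tau(x)$ is known. A minimal sketch of its Monte Carlo estimate, averaging the squared error of the estimated effect $f_1(x_i) - f_0(x_i)$ over sample points $x_i$ drawn from $p(x)$; the function name is hypothetical:

```python
import numpy as np

def empirical_pehe(f1_pred, f0_pred, tau_true):
    """Monte Carlo estimate of eps_PEHE over sample points x_i ~ p(x):
    mean of (f1(x_i) - f0(x_i) - tau(x_i))**2. Needs the true tau(x_i),
    which only (semi-)synthetic benchmarks provide."""
    f1_pred = np.asarray(f1_pred, dtype=float)
    f0_pred = np.asarray(f0_pred, dtype=float)
    tau_true = np.asarray(tau_true, dtype=float)
    return float(np.mean((f1_pred - f0_pred - tau_true) ** 2))

# Two units: estimated effects are 1 and 2, true effects are both 1.
example = empirical_pehe([2.0, 3.0], [1.0, 1.0], [1.0, 1.0])
```

Here `example` equals 0.5: the first unit's effect is estimated exactly and the second is off by one.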

Notation. In the next section we analyse the generalization properties of predictors $f_t$ and underlying representations $\varphi$ with respect to $\epsilon_{\text{PEHE}}$. To do so we use the following notation. Let $P_{x,t}$ denote the input data distribution $p(x, t)$, $P_t = p(y_t \mid x)\, p(x, t)$ the joint factual distribution of $x$ and $y_t$, and $P_{1-t} = p(y_t \mid x)\, p(x, 1-t)$ the joint counterfactual distribution of $x$ and $y_t$. Each instance in the observed (factual) dataset $\mathcal{D}_t = \{(x_{i,t}, y_{i,t})\}_{i=1}^{N_t}$ is assumed to be sampled i.i.d. from $P_t$. $\mathcal{D}_{1-t}$ denotes the unobserved (counterfactual) dataset that results from flipping the treatment assignment of each instance. While we assume $\varphi$ to be deterministic, in the following section we let $w_t$ be a vector of weights with prior distribution $\pi_t = \mathcal{N}(0, \lambda_t^{-1} I)$, $\lambda_t > 0$. In this sense, $\pi_t$ defines the hypothesis space $\mathcal{F}$ of $f_t$, and we write $\hat\rho_t$ for the posterior distribution of $f_t$, itself a random variable. We write $\mu(x_i \mid \mathcal{D}_t, \Theta_t)$ and $\sigma^2(x_i \mid \mathcal{D}_t, \Theta_t)$ for its posterior mean and variance given context $x_i$, where $\Theta_t$ includes all hyperparameters: both the parameters shared across treatment groups (in $\varphi$) and those specific to each group.

Inherent to learning representations for counterfactual inference is a trade-off: the representation must contain predictive information about factual outcomes while limiting the information content that drives the treatment selection policy, to ensure good generalization on counterfactuals. In this section, we make several observations about the deficiencies of enforcing domain invariance for this purpose and propose alternatives based on the posterior counterfactual variance. We start with a simple example that illustrates the inadequacy of distributional distances and the benefits of counterfactual variance minimization.

In the middle panels of Figure 2 we show two simulated datasets. The left-hand (red) dataset arises from truncated normal distributions with large overlap in the tails; the right-hand (green) dataset arises from ordinary normal distributions with small overlap in the tails. In both cases, treated and control outcomes are shown in different shades, and the outcome is $y = \operatorname{sinc}(4x)$. The red population satisfies the overlap required for identifiability of causal effects; the green population does not. However, as shown in the bottom panels, both the MMD and the Wasserstein distance are smaller in the green population than in the red population. In contrast, the top panels show that the predictive variance of the counterfactual outcomes describes this lack of overlap much better. The counterfactual variance adapts to the prediction problem of interest (since it is obtained from a fitted model), providing a data-dependent measure of distance in the underlying function class, which may be more precise when the underlying function to be estimated is unknown. IPMs, by contrast, are defined as worst-case distances over a function class that must be specified a priori. We also note that IPMs must be approximated in practice, and these approximations may be inaccurate for high-dimensional representations and small training samples [2].

Figure 2: Toy example illustrating the shortcomings of distributional distances, such as IPMs, for regularizing representations in causal inference. Although sufficient support holds in the red population and not in the green population, IPMs (bottom) give the opposite result, with a larger discrepancy between groups in the red population than in the green population. In contrast, the counterfactual variance (top) accurately reflects the lack of support in the green population.
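The two quantities contrasted in Figure 2 can be sketched in a few lines of NumPy. Below, the IPM is a biased (V-statistic) estimate of the squared maximum mean discrepancy (MMD) with a Gaussian kernel, and the counterfactual variance is the posterior variance of a zero-mean Gaussian process fitted to one group and evaluated at the other group's inputs. The kernel, lengthscale, noise level and the GP itself are illustrative assumptions standing in for the fitted models used in the paper, and the function names are hypothetical.

```python
import numpy as np

def rbf(a, b, lengthscale=0.5):
    """Gaussian (RBF) kernel matrix between 1-D samples a and b."""
    sq_dist = (a[:, None] - b[None, :]) ** 2
    return np.exp(-sq_dist / (2.0 * lengthscale ** 2))

def mmd2_biased(x, z, lengthscale=0.5):
    """Biased (V-statistic) estimate of the squared MMD between samples x and z."""
    return float(rbf(x, x, lengthscale).mean()
                 + rbf(z, z, lengthscale).mean()
                 - 2.0 * rbf(x, z, lengthscale).mean())

def gp_posterior_variance(x_train, x_test, lengthscale=0.5, noise=0.1):
    """Posterior variance of f at x_test for a zero-mean, unit-variance RBF Gaussian
    process fitted to inputs x_train (the outputs do not affect the variance)."""
    K = rbf(x_train, x_train, lengthscale) + noise ** 2 * np.eye(len(x_train))
    K_star = rbf(x_test, x_train, lengthscale)          # shape (m, n)
    solved = np.linalg.solve(K, K_star.T)               # K^{-1} k_*, shape (n, m)
    return 1.0 - np.sum(K_star * solved.T, axis=1)      # k(x, x) - k_*^T K^{-1} k_*

# Usage on two hypothetical groups of one-dimensional inputs.
rng = np.random.default_rng(0)
control = rng.normal(-1.0, 1.0, size=200)
treated = rng.normal(1.0, 1.0, size=200)
ipm = mmd2_biased(control, treated)
# Counterfactual variance of the treated-arm model, evaluated at the control inputs.
cf_var = gp_posterior_variance(treated, control).mean()
```

The design difference the text points to is visible in the signatures: the MMD depends only on the two input samples and a kernel fixed a priori, whereas the GP variance depends on where the fitted model has seen data, so it grows exactly where the other group lacks support.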

3.1 Generalization bounds

In this subsection, we develop a PAC-Bayes generalization bound [14] for the precision in estimating heterogeneous effects, $\epsilon_{\text{PEHE}}$, which shows specifically why minimizing the counterfactual variance can improve generalization performance.

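Theorem 1. Assume that the squared loss function $L : \mathcal{F} \times \mathcal{X} \times \mathcal{Y} \to \mathbb{R}$ is sub-Gaussian under $\pi_t$ and $P_t$, and use the notation introduced above. Then, with probability at least $1 - \delta$ and for any posterior distribution $\hat\rho_t$ on $\mathcal{F}$, $\epsilon_{\text{PEHE}}$ is upper-bounded by

[Equation (2)]

where $C_{t,1}(N_t, N_{1-t})$ and $C_{t,2}(N_t, N_{1-t})$ are linear functions of their arguments, $C_{t,3}$ is a constant, $D_\infty(P_{x,1-t} \,\|\, P_{x,t}) = \sup_x \frac{p(x, 1-t)}{p(x, t)}$, $\mathcal{L}_{\hat\rho_t}(X, Y) = \frac{1}{N} \sum_{i=1}^{N} \big(\mu(x_i \mid \mathcal{D}_t, \Theta_t) - y_i\big)^2$ is the posterior prediction loss, $\mathrm{Var}_{\hat\rho_t}(X)$ is the posterior variance on $X$,

$$\mathrm{Var}_{\hat\rho_t}(X) = \frac{1}{N} \sum_{i=1}^{N} \sigma^2(x_i \mid \mathcal{D}_t, \Theta_t),$$

and finally $\mathrm{KL}(\cdot \,\|\, \cdot)$ is the Kullback-Leibler divergence.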
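The bound of Theorem 1 contains the term $\mathrm{KL}(\hat\rho_t \,\|\, \pi_t)$ with prior $\pi_t = \mathcal{N}(0, \lambda_t^{-1} I)$. When the posterior over the weights $w_t$ is Gaussian, $\mathcal{N}(m, \Sigma)$, this term has a closed form. The sketch below computes it on the weights, which is an assumption of the sketch (it treats the KL over $f_t = w_t^\top \varphi$ as the KL over $w_t$ for a fixed $\varphi$); the function name is hypothetical.

```python
import numpy as np

def kl_to_isotropic_prior(m, cov, lam):
    """KL( N(m, cov) || N(0, lam^{-1} I) ) for a d-dimensional Gaussian (closed form)."""
    m = np.asarray(m, dtype=float)
    cov = np.asarray(cov, dtype=float)
    d = m.shape[0]
    sign, logdet_cov = np.linalg.slogdet(cov)
    if sign <= 0:
        raise ValueError("cov must be positive definite")
    return 0.5 * (lam * np.trace(cov)          # tr(Sigma_prior^{-1} Sigma_post)
                  + lam * m @ m                # Mahalanobis term of the posterior mean
                  - d
                  - d * np.log(lam)            # log det Sigma_prior = -d log lam
                  - logdet_cov)                # minus log det Sigma_post
```

The KL is zero when the posterior equals the prior, and it grows as the posterior mean moves away from zero or the posterior covariance shrinks, which is how this term penalizes overly confident posteriors in the bound.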

Observe that the bound depends on the posterior variance of the counterfactual outcomes, $\mathrm{Var}_{\hat\rho_t}(X_{1-t})$, evaluated on the observed inputs of the other treatment group. By minimizing the counterfactual variance we can expect the representations $\varphi$ of counterfactual data to encode a relatively smoother prediction function $w \circ \varphi(x)$, as can be seen in Figure 1. Because counterfactual data are never observed, treatment effect estimation is inherently a label-scarce problem, and representations that yield smooth prediction curves are especially important for generalizing beyond the factual data. This view of the problem emphasizes the need for regularization to achieve good generalization. We note that the sub-Gaussian assumption on the loss function may not hold in some cases, e.g. under large noise. In that case, the assumption may be relaxed to sub-exponential, with all results and proofs otherwise unchanged [8].

3.1.1 Why may distributional distances be inadequate?

While the toy example clearly illustrates the inability of distributional distances to capture domain overlap, we argue that by enforcing equality in full marginals, optimizing for distributional distances may also overly penalize the model’s ability to predict factual observations when sufficient data is available, that is, when domain overlap is satisfied.

In the following Theorem, we provide a bound on the generalization error on counterfactual data that illustrates the interplay between distribution mismatch and prediction loss on factual data. We show that distribution mismatch becomes decreasingly relevant with increasing data set size.

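Theorem 2. Using the notation introduced above, for any posterior distribution $\hat\rho_t$ on $\mathcal{F}$, the expected counterfactual Gibbs risk¹ $R_{P_{1-t}}(G_{\hat\rho_t})$ is bounded above by

[Equation (3)]

where $\mathcal{D}_{P_{x,1-t}}(\hat\rho_t) = \mathbb{E}_{x \sim P_{x,1-t}}\big[\sigma^2(x \mid \mathcal{D}_t, \Theta_t)\big]$ is the expected counterfactual variance, $\mathcal{L}_{P_t}(\hat\rho_t) = \mathbb{E}_{(x,y) \sim P_t}\big[\big(\mu(x \mid \mathcal{D}_t, \Theta_t) - y\big)^2 + \tfrac{1}{2}\sigma^2(x \mid \mathcal{D}_t, \Theta_t)\big]$ is the expected factual loss, and finally $D_\infty(P_{x,1-t} \,\|\, P_{x,t}) = \sup_x \frac{p(x, 1-t)}{p(x, t)}$.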

The bound in Equation (3) describes the interaction between the distribution mismatch and the prediction error on factual data, $\mathcal{L}_{P_t}(G_{\hat\rho_t})$. $D_\infty(P_{x,1-t} \,\|\, P_{x,t})$ is large if, for some $x$, $P_{x,t}$ has small density while $P_{x,1-t}$ has large density, that is, when there is poor overlap between them. Understandably, this makes minimizing the counterfactual risk $R_{P_{1-t}}(G_{\hat\rho_t})$ harder, because few factual examples are observed in such contexts. However, note that $D_\infty(P_{x,1-t} \,\|\, P_{x,t})$ is multiplied by the expected factual loss $\mathcal{L}_{P_t}(G_{\hat\rho_t})$, which decreases as the number of factual samples increases (see Equation (10) in Theorem 7 of Appendix 1.3). The distribution mismatch $D_\infty(P_{x,1-t} \,\|\, P_{x,t})$ thus becomes less important for generalization if we can make the expected factual loss $\mathcal{L}_{P_t}(G_{\hat\rho_t})$ arbitrarily small. This suggests that optimizing for distributional distances between treated and control data at the expense of prediction error on the factual data may be counterproductive. Representations regularized with distributional distances may thus shrink the hypothesis space and converge to solutions that, although balanced across treatment groups, lose their predictive power.
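To see how quickly $D_\infty$ reacts to poor overlap, here is a minimal sketch assuming (hypothetically) that $x$ is one-dimensional, each treatment group has a Gaussian conditional density with a common standard deviation, and both groups are equally likely. For equal-variance Gaussians with different means the density ratio is unbounded on the real line, so the sketch restricts the supremum to a bounded grid, which is an additional assumption; the function names are hypothetical.

```python
import numpy as np

def normal_pdf(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))

def d_inf_on_grid(mean_t, mean_cf, std=1.0, p_t=0.5, lo=-3.0, hi=3.0, n=2001):
    """Approximate D_inf(P_{x,1-t} || P_{x,t}) = sup_x p(x, 1-t) / p(x, t) when the
    group-conditional densities of x are Gaussian, taking the sup over a grid on [lo, hi]."""
    x = np.linspace(lo, hi, n)
    joint_t = p_t * normal_pdf(x, mean_t, std)              # p(x, t)
    joint_cf = (1.0 - p_t) * normal_pdf(x, mean_cf, std)    # p(x, 1 - t)
    return float(np.max(joint_cf / joint_t))

for shift in (0.0, 1.0, 2.0):
    print(shift, d_inf_on_grid(0.0, shift))
```

With unit variances, balanced groups and a mean shift of $\delta$, the ratio equals $\exp(\delta x - \delta^2/2)$, so on $[-3, 3]$ its maximum is $\exp(3\delta - \delta^2/2)$: 1 for $\delta = 0$, about 12.2 for $\delta = 1$ and about 54.6 for $\delta = 2$. Theorem 2 scales this factor by the expected factual loss rather than requiring it to be small.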

Here again, we note the presence of the expected counterfactual variance $\mathcal{D}_{P_{x,1-t}}(\hat\rho_t)$ in Equation (3), a term not considered by methods that optimize for the distributional difference only.

3.2 Why encourage preserving information content?

Assumption 1 gives sufficient conditions for unbiased estimation of the treatment effect from observational data, but identifiability need not hold with respect to the feature representation $Z = \varphi(X)$, even if it does with respect to $X$. For instance, consider $\varphi^{-1}(z) :=$

image

in general, with equality only if $\varphi$ is invertible. The conditional independence in the strong ignorability assumption, $Y(0), Y(1) \perp\!\!\!\perp T \mid \varphi(X)$, required for estimating the treatment effect need not hold for non-invertible transformations. In this sense, we hypothesize that the information lost in the map $\varphi$ may introduce unobserved confounders in representation space. Observe also that our objective, the conditional average treatment effect $\tau(x) = \mathbb{E}[Y(1) - Y(0) \mid X = x]$, is expressed in terms of expectations. Similarly, in feature space the quantity

$$\mathbb{E}\big[Y(1) - Y(0) \mid \varphi(X) = \varphi(x)\big]$$

will in general not be equal to our quantity of interest $\tau(x) = \mathbb{E}[Y(1) - Y(0) \mid X = x]$ unless $\varphi$ is invertible.
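A two-point toy case shows the gap between the effect in feature space and $\tau(x)$. In this hypothetical setting, $X$ is uniform on $\{-1, +1\}$, the true CATE is $\tau(x) = x$, and the representation $\varphi(x) = x^2$ sends both contexts to the same point. By the tower property, the feature-space effect is $\mathbb{E}[\tau(X) \mid \varphi(X) = z]$, which averages the two opposite effects away.

```python
import numpy as np

# Hypothetical toy setting: X is uniform on {-1, +1}, the true CATE is tau(x) = x,
# and the representation phi(x) = x**2 maps both contexts to the same point.
xs = np.array([-1.0, 1.0])
px = np.array([0.5, 0.5])

def tau(x):
    return x                      # E[Y(1) - Y(0) | X = x]

def phi(x):
    return x ** 2                 # non-invertible: phi(-1) == phi(1) == 1

def tau_in_feature_space(z):
    """E[Y(1) - Y(0) | phi(X) = z] = E[tau(X) | phi(X) = z] by the tower property."""
    mask = phi(xs) == z
    return float(np.sum(tau(xs[mask]) * px[mask]) / np.sum(px[mask]))

for x in xs:
    print(f"x={x:+.0f}  tau(x)={tau(x):+.1f}  feature-space effect={tau_in_feature_space(phi(x)):+.1f}")
```

Both contexts receive a feature-space effect of 0, while their true effects are -1 and +1; with an invertible map such as $\varphi(x) = x$ the two quantities would coincide.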

In this section, we describe a model for counterfactual estimation motivated by our analysis, called DKLITE² (Deep Kernel Learning for Individualized Treatment Effects). Our proposed prediction algorithm works in a feature space induced by $\varphi$ indirectly via a kernel function $K : \mathcal{X} \times \mathcal{X} \to \mathbb{R}$, which intuitively models the correlation between inputs $x, x' \in \mathcal{X}$ in a possibly high-dimensional feature space. We extend the expressiveness of $K$ by transforming the inputs through a non-linear mapping $\varphi : \mathcal{X} \to \mathbb{R}^{d_\varphi}$ to form $k(\varphi(x), \varphi(x'))$, a deep kernel [19]. The dimension $d_\varphi$ of $\varphi(x)$ can be chosen arbitrarily, and may therefore differ from the dimension of $x$. We let $\varphi_\theta$ be parameterized by a neural network with parameters $\theta$ that encodes the information content of the input variables.

Prediction model. The potential outcomes given an individual's covariates are assumed to take the form

$$Y_i(t) = f_t(x_i) + \epsilon_{i,t}, \qquad (5)$$

where $f_t(x_i) = w_t^\top \varphi(x_i)$ is a linear combination of feature representations $\varphi(\cdot)$, $w_t \in \mathbb{R}^{d_\varphi}$ is a weight vector, $\varphi(\cdot)$ is the representation layer, and $\epsilon_{i,t} \sim \mathcal{N}(0, \beta_t^{-1})$ is a noise variable, $t = 0, 1$. We express our uncertainty in parameter values through prior distributions. Combined with the likelihood of the observed data, these determine the posterior distribution of the parameters and the predictive distributions of treatment effects. For each component of the weight vector, independently, we assume $w_t \sim \mathcal{N}(0, \lambda_t^{-1})$. Given $\varphi$, the posterior of $w_t$ is

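$$p(w_t \mid \mathcal{D}_t, \Theta_t) = \mathcal{N}\big(w_t \mid m_{w_t}, K_{w_t}^{-1}\big),$$

where $\Theta_t = \{\varphi, \beta_t, \lambda_t\}$, the posterior mean is $m_{w_t} = \beta_t K_{w_t}^{-1} \Phi_t^\top Y_t$, the posterior covariance is $K_{w_t}^{-1}$ with precision matrix $K_{w_t} = \beta_t \Phi_t^\top \Phi_t + \lambda_t I_{d_\varphi \times d_\varphi}$, and $\Phi_t$ is the representation of $X_t$. However, it can be difficult to encode our knowledge about good representation spaces for counterfactual estimation in a Bayesian prior. Posterior regularization offers a more direct and flexible mechanism for guiding and controlling the posterior distribution.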
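The conjugate update above can be sketched directly. This minimal version assumes the features $\Phi_t$ are already computed and fixed (in DKLITE, $\varphi$ is a neural network learned jointly with the rest of the model), uses hypothetical function names, and also evaluates the posterior mean and variance $\mu(x) = m_{w_t}^\top \varphi(x)$ and $\sigma^2(x) = \varphi(x)^\top K_{w_t}^{-1} \varphi(x)$ at new inputs.

```python
import numpy as np

def blr_posterior(Phi, y, beta, lam):
    """Gaussian posterior over weights for y = Phi @ w + eps, with prior w ~ N(0, lam^{-1} I)
    and noise eps ~ N(0, beta^{-1}). Returns (m_w, K_w_inv)."""
    d = Phi.shape[1]
    K_w = beta * Phi.T @ Phi + lam * np.eye(d)   # posterior precision K_w
    K_w_inv = np.linalg.inv(K_w)                 # posterior covariance K_w^{-1}
    m_w = beta * K_w_inv @ Phi.T @ y             # posterior mean beta K_w^{-1} Phi^T y
    return m_w, K_w_inv

def posterior_mean_and_variance(Phi_new, m_w, K_w_inv):
    """mu(x) = m_w^T phi(x) and sigma^2(x) = phi(x)^T K_w^{-1} phi(x) for each row phi(x)."""
    mu = Phi_new @ m_w
    var = np.einsum("ij,jk,ik->i", Phi_new, K_w_inv, Phi_new)
    return mu, var

# Usage with hypothetical fixed features standing in for phi(X_t).
rng = np.random.default_rng(0)
Phi_t = rng.normal(size=(500, 3))
w_true = np.array([1.0, -2.0, 0.5])
y_t = Phi_t @ w_true + rng.normal(scale=0.1, size=500)
m_w, K_w_inv = blr_posterior(Phi_t, y_t, beta=100.0, lam=1e-3)
mu, var = posterior_mean_and_variance(np.array([[0.0, 0.0, 0.0], [10.0, 10.0, 10.0]]), m_w, K_w_inv)
```

An explicit inverse keeps the code close to the formulas; for larger $d_\varphi$ a Cholesky factorization of $K_{w_t}$ would be the more numerically robust choice.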

Regularized Bayes. The regularized posterior is the solution of the following optimization problem [22, 10]:

$$\min_{q(f \mid \mathcal{D})} \; \mathcal{L}\big(q(f \mid \mathcal{D})\big) + \Omega\big(q(f \mid \mathcal{D})\big), \qquad (7)$$

where $\mathcal{L}(q(f \mid \mathcal{D}))$ is the KL divergence between the approximate posterior $q(f \mid \mathcal{D})$ and the true posterior $p(f \mid \mathcal{D})$, and $\Omega(q(f \mid \mathcal{D}))$ is a regularizer of the approximate posterior $q(f \mid \mathcal{D})$. We therefore attempt to learn a posterior distribution as close as possible to the true posterior while also fulfilling the requirements imposed by the regularization term.

Inference. The predictive distribution for a test point $x$ is obtained by marginalising over the posterior of $w_t$. It is given by

$$p\big(f_t(x) \mid x, \mathcal{D}_t, \Theta_t\big) = \mathcal{N}\big(\mu(x \mid \mathcal{D}_t, \Theta_t),\ \sigma^2(x \mid \mathcal{D}_t, \Theta_t)\big),$$

where $\mu(x \mid \mathcal{D}_t, \Theta_t) = m_{w_t}^\top \varphi(x)$ and $\sigma^2(x \mid \mathcal{D}_t, \Theta_t) = \varphi(x)^\top K_{w_t}^{-1} \varphi(x)$. Point estimates are given, for example, by the posterior mean or median, which are optimal for the squared loss and the absolute loss, respectively. Knowledge of the full posterior also allows us to quantify uncertainty around point estimates through credible intervals, which is especially useful in domains such as medicine and public policy.
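Given the Gaussian posterior mean and variance of each arm, a central credible interval follows from the normal quantile. The sketch below uses Python's standard-library `statistics.NormalDist`; the extension to the effect $\hat\tau(x) = f_1(x) - f_0(x)$ assumes the posteriors of $f_1(x)$ and $f_0(x)$ are independent (separate weights $w_1$, $w_0$ fitted on separate groups given a fixed $\varphi$), so that their variances add. Function names are hypothetical.

```python
import math
from statistics import NormalDist

def gaussian_credible_interval(mean, var, level=0.95):
    """Central interval containing `level` of the mass of N(mean, var)."""
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    half_width = z * math.sqrt(var)
    return mean - half_width, mean + half_width

def ite_credible_interval(mu1, var1, mu0, var0, level=0.95):
    """Interval for tau_hat(x) = f_1(x) - f_0(x), assuming the posteriors of f_1(x) and
    f_0(x) are independent Gaussians, so their variances add."""
    return gaussian_credible_interval(mu1 - mu0, var1 + var0, level)

lo, hi = ite_credible_interval(mu1=3.0, var1=1.0, mu0=1.0, var0=3.0)
```

In this usage the effect estimate is 2 with variance 4, so the 95% interval is roughly 2 ± 3.92.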

4.1 Learning φ

The regularized Bayes framework described in (7) provides a flexible optimization objective; many previous algorithms for treatment effects can be reduced to specific instantiations of this problem. In this section we describe the likelihood and regularization terms that we use to encode the intuition and insights derived in Section 3. In Appendix 2, we show that the final loss function we derive can be reformulated as a specific instance of problem (7).

Likelihood of the factual data. Assuming the discriminative model introduced in Equation (5), we encourage good prediction of the factual data by minimizing the negative log-likelihood of the factual data,

image

Observe that we may rewrite the negative log-likelihood $\mathcal{L}_{\text{lik}}$ as

image

From Theorem 1, we see that the empirically estimable quantities to optimize in the upper bound (2) are $\mathrm{Var}_{\hat\rho_t}(X_{1-t})$, $\mathrm{Var}_{\hat\rho_t}(X_t)$, $\mathcal{L}_{\hat\rho_t}(X_t, Y_t)$ and $\mathrm{KL}(\hat\rho_t \,\|\, \pi_t)$. The latter three already appear in our objective $\mathcal{L}_{\text{lik}}$, as shown in Equation (10).

Posterior variance as regularization. Therefore, by including an empirical estimate of the counterfactual variance $\mathrm{Var}_{\hat\rho_t}(X_{1-t})$ as a regularizer in our objective, we effectively optimize for all terms of the PAC-Bayes upper bound in Equation (2). We write this regularization term as

image

Since the deep kernel parameters are learned jointly, the neural net $\varphi$ is encouraged to learn a feature representation in which the counterfactual examples are close to the factual examples, thereby reducing the variance of our predictions. That is, we optimize for representations in which counterfactual data tend to cluster around the representations of factual data. Intuitively, this encourages overlap in support in representation space without enforcing equality in densities (i.e. the sizes of the factual and counterfactual "clusters" need not coincide).
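The clustering intuition can be checked numerically with a hypothetical version of the regularizer. The exact form of the term is not reproduced in this text, so the sketch assumes it is the mean posterior variance $\varphi(x)^\top K_{w_t}^{-1} \varphi(x)$ of each arm's Bayesian linear model at the other arm's inputs, summed over the two arms, with fixed features standing in for $\varphi$. Shifting one group's features away from the other, i.e. destroying support overlap, sharply increases the regularizer.

```python
import numpy as np

def posterior_covariance(Phi, beta, lam):
    """K_w^{-1} = (beta Phi^T Phi + lam I)^{-1} for one treatment arm."""
    d = Phi.shape[1]
    return np.linalg.inv(beta * Phi.T @ Phi + lam * np.eye(d))

def counterfactual_variance(Phi0, Phi1, beta=1.0, lam=1.0):
    """Hypothetical empirical regularizer: for each arm t, the mean posterior variance
    phi(x)^T K_{w_t}^{-1} phi(x) of the arm-t model at the inputs of arm 1 - t, summed over t."""
    groups = (Phi0, Phi1)
    total = 0.0
    for t in (0, 1):
        cov = posterior_covariance(groups[t], beta, lam)
        Phi_cf = groups[1 - t]                  # counterfactual inputs for arm t
        total += float(np.einsum("ij,jk,ik->i", Phi_cf, cov, Phi_cf).mean())
    return total

rng = np.random.default_rng(1)
Phi0 = rng.normal(size=(100, 2))
Phi1_near = rng.normal(size=(100, 2))          # representations that overlap with arm 0
Phi1_far = Phi1_near + 10.0                    # same spread, but no overlap with arm 0
near = counterfactual_variance(Phi0, Phi1_near)
far = counterfactual_variance(Phi0, Phi1_far)
```

Note that the far configuration is penalized even though both groups keep the same spread, so the quantity reacts to missing support rather than to any mismatch in densities.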

Table 1: Mean performance (lower is better) and standard deviation of individualized treatment effect estimation.

Invertibility as regularization. While the loss due to non-invertible representations is not directly observable, we may associate it with the information content of $x$ lost in $\varphi(x)$, and we found preserving this information to be an important source of performance gain empirically. We can encourage information content preservation with an additional decoder $\psi : \mathbb{R}^{d_\varphi} \to \mathcal{X}$, a neural network with the reversed structure of the network $\varphi$ (the encoder in this sense), trained to reconstruct the input $x$ from $\varphi(x)$. The reconstruction loss is given as

image

We remark that some recent density estimation approaches, e.g. normalizing flows [15] and non-volume preserving transformations [7], can be used as alternative ways to achieve invertibility.
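The link between reconstruction error and lost information can be illustrated with the simplest encoder/decoder pair. In the sketch below, which is an illustrative stand-in rather than the neural networks used by DKLITE, $\varphi$ projects centred inputs onto their top $d_\varphi$ principal directions and $\psi$ maps back; the squared-error form of the loss is also an assumption. The loss vanishes only when $\varphi$ is invertible ($d_\varphi$ equal to the input dimension) and grows as more dimensions are discarded.

```python
import numpy as np

def linear_reconstruction_loss(X, d_phi):
    """Mean squared reconstruction error of a linear encoder/decoder pair: phi projects the
    centred inputs onto their top-d_phi principal directions and psi maps them back."""
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    W = Vt[:d_phi].T              # encoder weights, orthonormal columns (d x d_phi)
    Z = Xc @ W                    # phi(x): representation in R^{d_phi}
    X_rec = Z @ W.T               # psi(phi(x)): reconstruction in input space
    return float(np.mean(np.sum((Xc - X_rec) ** 2, axis=1)))

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
losses = {k: linear_reconstruction_loss(X, k) for k in (2, 4, 5)}
```

For this linear pair the loss equals the variance along the discarded principal directions, which makes "information content lost in $\varphi(x)$" directly measurable.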

Final loss function. Based on the objectives described above, our final loss trades off maximizing the likelihood of the observed (factual) data under our model, minimizing the predictive variance of the counterfactual outcomes, and minimizing the reconstruction loss of the representations. The loss is given as follows,

$$\mathcal{L}_{\text{fin}} = \mathcal{L}_{\text{lik}} + \alpha_1 \mathcal{L}_{\text{var}} + \alpha_2 \mathcal{L}_{\text{rec}},$$

where $\alpha_1 > 0$ and $\alpha_2 > 0$ are hyperparameters. Standard methods for hyperparameter selection are not generally applicable here because counterfactuals are never observed. As an approximation, we replace the missing counterfactuals with their nearest factual neighbour to compute treatment effects in cross-validation [17].
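The nearest-neighbour approximation for model selection can be sketched as follows. This minimal version assumes Euclidean distance in input space and a squared-error validation score; the exact matching scheme of [17] may differ, and the function names are hypothetical. A candidate setting of $(\alpha_1, \alpha_2)$ would then be scored by comparing its estimated effects on a validation fold against these imputed effects.

```python
import numpy as np

def nn_imputed_effects(X, t, y):
    """Proxy effects for cross-validation: each unit's missing counterfactual is replaced by
    the factual outcome of its nearest neighbour (Euclidean) in the other treatment group."""
    X = np.asarray(X, dtype=float)
    t = np.asarray(t)
    y = np.asarray(y, dtype=float)
    tau_nn = np.empty_like(y)
    for i in range(len(y)):
        others = np.flatnonzero(t != t[i])
        j = others[np.argmin(np.sum((X[others] - X[i]) ** 2, axis=1))]
        y1, y0 = (y[i], y[j]) if t[i] == 1 else (y[j], y[i])
        tau_nn[i] = y1 - y0
    return tau_nn

def pehe_nn(tau_hat, tau_nn):
    """Validation score: mean squared gap between estimated and imputed effects."""
    return float(np.mean((np.asarray(tau_hat, dtype=float) - tau_nn) ** 2))

X_val = [[0.0], [0.1], [5.0], [5.2]]
t_val = [1, 0, 0, 1]
y_val = [3.0, 1.0, 2.0, 6.0]
tau_nn = nn_imputed_effects(X_val, t_val, y_val)
```

In this small example the first pair of units is matched to each other (imputed effect 2) and so is the second pair (imputed effect 4).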

The primary focus of our experiments is the comparison of prediction performance on benchmark tasks, the use of the posterior variance for decision-making, and a deeper analysis of the learned representations and sources of performance gain. In the following, we introduce the competing prediction algorithms before briefly describing the data and the appropriate performance metrics.

Competing algorithms. We compare DKLITE with a total of 11 algorithms. First we evaluate least squares regression using treatment as an additional input feature (OLS/LR1); we consider separate least squares regressions for each treatment (OLS/LR2); and we evaluate balancing linear regression (BLR) [11], k-nearest neighbor (k-NN) [5], Bayesian additive regression trees (BART) [4], random forests (R-Forest) [3], causal forests (C-Forest) [18], balancing neural networks (BNN) [11], the treatment-agnostic representation network (TARNET), counterfactual regression with Wasserstein distance (CARWASS) [17], and multi-task Gaussian processes (CMGP) [1].

Data. Causal inference models are often impossible to validate reliably using real-world data because counterfactual outcomes are absent. Various established approaches for evaluating causal models have been proposed, and we use them for our analysis. We describe these briefly below and refer the reader to the accompanying references and Section 5 of the Appendix for further details. Additional results (e.g. average treatment effect estimation) are included in the same Appendix section. We consider IHDP (747 instances described by 25 covariates) [9, 1, 17, 20], in which counterfactual outcomes are randomly generated via a predefined probabilistic model; Twins (11300 instances described by 30 covariates), in which outcomes are observed but the treatment assignment is simulated; and finally Jobs (3212 instances described by 7 covariates) [17, 13, 6], which is constructed from a mixture of experimental and observational data with the treatment.

Metrics. The metrics used to evaluate each data set differ slightly depending on the available outcomes (real or simulated) and treatment assignment mechanisms (known or unknown). For IHDP, we use the empirical precision in estimating heterogeneous effects, $\hat\epsilon_{\text{PEHE}} = \frac{1}{N} \sum_{i=1}^{N} \big(\tau(x_i) - \hat\tau(x_i)\big)^2$. For Twins, we use the observed precision in estimating heterogeneous effects, $\tilde\epsilon_{\text{PEHE}} = \frac{1}{N} \sum_{i=1}^{N} \big(y_{i,1} - y_{i,0} - \hat\tau(x_i)\big)^2$. For Jobs, we use the policy risk, which measures the expected loss if treatment is assigned according to the ITE policy prescribed by the algorithm, $R_{\text{pol}} = 1 - \big(\mathbb{E}[Y(1) \mid \pi(x) = 1]\, P(\pi(x) = 1) + \mathbb{E}[Y(0) \mid \pi(x) = 0]\, P(\pi(x) = 0)\big)$, where $\pi(x) = 1$ if $\hat y(1) - \hat y(0) > 0$ and $\pi(x) = 0$ otherwise.
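The policy risk can be estimated from data in which treatment was randomized: the conditional means $\mathbb{E}[Y(1) \mid \pi(x) = 1]$ and $\mathbb{E}[Y(0) \mid \pi(x) = 0]$ are approximated by average outcomes among units whose observed treatment agrees with the policy. The minimal sketch below follows this common approach under the stated assumptions (randomized treatment, both agreement subsets non-empty); it is not necessarily the exact evaluation code behind Table 1, and the function name is hypothetical.

```python
import numpy as np

def policy_risk(tau_hat, t, y):
    """Estimate R_pol = 1 - (E[Y(1)|pi=1] P(pi=1) + E[Y(0)|pi=0] P(pi=0)) with
    pi(x) = 1 if tau_hat(x) > 0. Conditional means use units whose observed
    treatment agrees with the policy, which is only valid for randomized data;
    both agreement subsets are assumed non-empty."""
    tau_hat = np.asarray(tau_hat, dtype=float)
    t = np.asarray(t)
    y = np.asarray(y, dtype=float)
    pi = (tau_hat > 0).astype(int)
    p_treat = pi.mean()
    value = 0.0
    if p_treat > 0:
        value += y[(pi == 1) & (t == 1)].mean() * p_treat
    if p_treat < 1:
        value += y[(pi == 0) & (t == 0)].mean() * (1.0 - p_treat)
    return float(1.0 - value)
```

With binary outcomes such as employment, a policy whose recommended treatments always coincide with good outcomes attains a risk of 0, and one that always coincides with bad outcomes attains a risk of 1.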

5.1 Prediction performance results

We report in-sample and out-of-sample performance in Table 1. DKLITE targets aspects of the treatment effect estimation problem that have not been considered before, and it significantly outperforms all competing algorithms on all benchmark data sets. The most relevant comparison is perhaps with BNN [11] and CARWASS [17], neural network models that enforce domain invariance through distributional distances. The performance gain highlights the predictive power of our representations.

5.2 Source of gain


Table 2: Source of performance gain.

In this section, we analyze in more depth the contribution of each component of our loss to the performance gain. We evaluate DKLITE optimized for different components of $\mathcal{L}_{\text{fin}} = \mathcal{L}_{\text{lik}} + \alpha_1 \mathcal{L}_{\text{var}} + \alpha_2 \mathcal{L}_{\text{rec}}$. As can be seen in Table 2, regularization based on the counterfactual variance ($\mathcal{L}_{\text{lik}} + \alpha_1 \mathcal{L}_{\text{var}}$) and on the reconstruction loss ($\mathcal{L}_{\text{lik}} + \alpha_2 \mathcal{L}_{\text{rec}}$), each evaluated separately, already provides a significant gain in performance over optimization on the factual data only ($\mathcal{L}_{\text{lik}}$). Importantly, combining them ($\mathcal{L}_{\text{fin}}$) improves performance further by an order of magnitude (see DKLITE in Table 1), which suggests that $\mathcal{L}_{\text{var}}$ and $\mathcal{L}_{\text{rec}}$ capture, to some extent, "orthogonal" sources of gain. The gain is especially large on relatively small data sets, such as IHDP with 747 individuals, and smaller on bigger data sets. These results illustrate the behaviour suggested by Equation (3) in Theorem 2: distribution mismatch between groups becomes decreasingly relevant with increasing data set size. In this setting, the error on the factual data ($\mathcal{L}_{\text{lik}}$) drives generalization performance.

5.3 Leveraging the predicted uncertainty

Data-driven solutions for decision support have most often been proposed without methods to quantify and control their uncertainty in a decision. In medicine, by contrast, a physician knows when she is uncertain about a case and will consult more experienced colleagues if needed. We use this idea to show that uncertainty-informed treatment effect estimation can improve performance. An instantiation of this approach, termed DKLITE-U, refers the 10% most uncertain predictions for further scrutiny. Performance in comparison to DKLITE is given in Table 3. Note that predictions can be improved significantly, especially on small data sets such as IHDP.
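The referral rule can be sketched as a mask over units ranked by predictive variance. Which variance is ranked (for instance $\sigma_1^2(x) + \sigma_0^2(x)$ for the effect, under an independence assumption) and how referred units enter the scores in Table 3 are not specified in this text, so the sketch below is hypothetical: it simply drops the referred units and computes the empirical PEHE on the remainder.

```python
import numpy as np

def refer_most_uncertain(pred_var, frac=0.10):
    """Return a boolean mask that is False for the `frac` fraction of units with the largest
    predictive variance (referred for further scrutiny) and True for the rest."""
    pred_var = np.asarray(pred_var, dtype=float)
    n_refer = int(round(frac * pred_var.size))
    keep = np.ones(pred_var.size, dtype=bool)
    if n_refer > 0:
        keep[np.argsort(pred_var)[-n_refer:]] = False
    return keep

def pehe_on_kept(tau_hat, tau_true, keep):
    """Empirical PEHE restricted to the units that were not referred."""
    err = (np.asarray(tau_hat, dtype=float) - np.asarray(tau_true, dtype=float)) ** 2
    return float(err[keep].mean())

keep = refer_most_uncertain(np.arange(20.0))   # the two largest variances are referred
```

Ranking by variance rather than thresholding it keeps the referral rate fixed at 10% regardless of the scale of the uncertainty estimates.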


Table 3: Performance of DKLITE-U on all data sets.

In many domains understanding the effect of interventions at an individual level is crucial, but predicting those potential outcomes is challenging. Despite the empirical success of algorithms that enforce domain-invariant representations, we find that domain invariance is often too strong a requirement for causal predictions. This stems from the fact that overlapping support is sufficient for identifiability of the causal effect, and equality in densities is not necessary.

We have derived a bound on generalization performance that depends on domain overlap through the counterfactual variance, which we interpret as a proxy for overlap, and highlighted the need for invertible latent maps. These results motivated novel regularization criteria that we incorporated in a deep kernel learning framework through posterior regularization. We hypothesize that many existing models, both frequentist (through representer theorems if optimized with empirical risk minimization) and Bayesian, can be reduced to specific instances of our framework with different regularization terms. We leave a formal derivation of this unifying theory for future work.

We thank the anonymous reviewers for valuable feedback. This work was supported by GlaxoSmithKline (GSK), the Alan Turing Institute under the EPSRC grant EP/N510129/1, the US Office of Naval Research (ONR), and the National Science Foundation (NSF): grant numbers ECCS1462245, ECCS1533983, and ECCS1407712.

[1] A. M. Alaa and M. van der Schaar. Bayesian inference of individualized treatment effects using multi-task gaussian processes. In Advances in Neural Information Processing Systems, pages 3424–3432, 2017.

[2] E. Boissard and T. Le Gouic. On the mean speed of convergence of empirical and occupation measures in wasserstein distance. In  Annales de l’IHPProbabilités et statistiques, volume 50, pages 539– 563, 2014.

[3] L. Breiman. Random forests. Machine learning, 45(1):5–32, 2001.

[4] H. A. Chipman, E. I. George, R. E. McCulloch, et al. Bart: Bayesian additive regression trees. The Annals of Applied Statistics, 4(1):266–298, 2010.

[5] R. K. Crump, V. J. Hotz, G. W. Imbens, and O. A. Mitnik. Nonparametric tests for treatment effect heterogeneity. The Review of Economics and Statistics, 90(3):389–405, 2008.

[6] R. H. Dehejia and S. Wahba. Propensity scorematching methods for nonexperimental causal studies. Review of Economics and statistics, 84(1):151–161, 2002.

[7] L. Dinh, J. Sohl-Dickstein, and S. Bengio. Density estimation using real nvp. arXiv preprint arXiv:1605.08803, 2016.

[8] P. Germain, F. Bach, A. Lacoste, and S. LacosteJulien. Pac-bayesian theory meets bayesian inference. In Advances in Neural Information Processing Systems, pages 1884–1892, 2016.

[9] J. L. Hill. Bayesian nonparametric modeling for causal inference. Journal of Computational and Graphical Statistics, 20(1):217–240, 2011.

[10] N. Jean, S. M. Xie, and S. Ermon. Semi-supervised deep kernel learning: Regression with unlabeled data by minimizing predictive variance. In Advances in Neural Information Processing Systems, pages 5322–5333, 2018.

[11] F. Johansson, U. Shalit, and D. Sontag. Learning representations for counterfactual inference. In International conference on machine learning, pages 3020–3029, 2016.

[12] F. D. Johansson, N. Kallus, U. Shalit, and D. Sontag. Learning weighted representations for generalization across designs. arXiv preprint arXiv:1802.08598, 2018.

[13] R. J. LaLonde. Evaluating the econometric evaluations of training programs with experimental data. The American economic review, pages 604–620, 1986.

[14] D. A. McAllester. Some pac-bayesian theorems. Machine Learning, 37(3):355–363, 1999.

[15] D. J. Rezende and S. Mohamed. Variational inference with normalizing flows. arXiv preprint arXiv:1505.05770, 2015.

[16] D. B. Rubin. Causal inference using potential outcomes: Design, modeling, decisions. Journal of the American Statistical Association, 100(469):322– 331, 2005.

[17] U. Shalit, F. D. Johansson, and D. Sontag. Estimating individual treatment effect: generalization bounds and algorithms. In Proceedings of the 34th International Conference on Machine LearningVolume 70, pages 3076–3085. JMLR. org, 2017.

[18] S. Wager and S. Athey. Estimation and inference of heterogeneous treatment effects using random forests. Journal of the American Statistical Association, 113(523):1228–1242, 2018.

[19] A. G. Wilson, Z. Hu, R. Salakhutdinov, and E. P. Xing. Deep kernel learning. In Artificial Intelligence and Statistics, pages 370–378, 2016.

[20] L. Yao, S. Li, Y. Li, M. Huai, J. Gao, and A. Zhang. Representation learning for treatment effect estimation from observational data. In Advances in Neural Information Processing Systems, pages 2633–2643, 2018.

[21] H. Zhao, R. T. d. Combes, K. Zhang, and G. J. Gordon. On learning invariant representation for domain adaptation. arXiv preprint arXiv:1901.09453, 2019.

[22] J. Zhu, N. Chen, and E. P. Xing. Bayesian inference with posterior regularization and applications to infinite latent svms. The Journal of Machine Learning Research, 15(1):1799–1847, 2014.

