Variational Learning of Individual Survival Distributions

2020 · arXiv

ABSTRACT

The abundance of modern health data provides many opportunities for the use of machine learning techniques to build better statistical models to improve clinical decision making. Predicting time-to-event distributions, also known as survival analysis, plays a key role in many clinical applications. We introduce a variational time-to-event prediction model, named Variational Survival Inference (VSI), which builds upon recent advances in distribution learning techniques and deep neural networks. VSI addresses the challenges of non-parametric distribution estimation by (i) relaxing the restrictive modeling assumptions made in classical models, and (ii) efficiently handling censored observations, i.e., events that occur outside the observation window, all within the variational framework. To validate the effectiveness of our approach, an extensive set of experiments on both synthetic and real-world datasets is carried out, showing improved performance relative to competing solutions.

KEYWORDS

Variational Inference, Survival Analysis, Neural Networks, Individual Personal Distribution, Time-to-event modeling, Black-box inference, Latent Variable Models

1 INTRODUCTION

Prediction of event times, also known as survival analysis in the clinical context, is one of the most extensively studied topics in the statistical literature, largely due to its significance in a wide range of clinical and population health applications. It provides a fundamental set of tools to statistically analyze the future behavior of a system, or an individual. In the classical setup, the primary goal of time-to-event modeling is to either characterize the distribution of the occurrence of an event of interest on a population level [20, 21], or more specifically, to estimate a risk score on a subject level [11]. In recent years, there has been a surge of interest in the prediction of individualized event time distributions [46].

A characteristic feature in the study of time-to-event distributions is the presence of censored instances, which refer to an event that is not reported during the follow-up period of a subject. This can happen, for instance, when a subject drops out during the study (right censoring), including when the study terminates before the event happens (administrative censoring). Unlike many conventional predictive models, where incomplete observations are usually safely ignored, censored observations contain crucial information that should be adequately considered. To efficiently leverage the censored observations together with the complete observations, a classical treatment is to work with the notion of a hazard function, formally defined as the instantaneous event risk at time $t$, which can be computed by contrasting the event population to the population at risk at a specific time. Estimates can be derived, for instance, by optimizing the partial likelihood defined by the relative hazards in the case of the Cox Proportional Hazard model (CoxPH) [11]. Alternatively, other work follows the standard Maximum Likelihood Estimation (MLE) framework, where the individual event distribution is a deformed version of some baseline distribution. For example, in the Accelerated Failure Time model (AFT) [20], covariate effects are assumed to rescale the temporal index of event-time distributions, i.e., they either accelerate or delay event progression. For censored events, the likelihood is given by the probability mass after the censoring time, i.e., the survival function evaluated at the censoring time [1].

While vastly popular among practitioners, these models have been criticized for a number of reasons, in particular for the assumptions they make, which consequently render them unfit for many modern applications [45]. For instance, most survival models, including CoxPH and the proportional odds model [31], work under the premise of fixed covariate effects, overlooking individual uncertainty. However, it has been widely recognized that individual heterogeneity and other sources of variation are common and often time-dependent [2]. In real-world scenarios, these random factors are typically costly to measure, if not impossible to observe. Unfortunately, many models are known to be sensitive to the violation of this fixed-effect assumption, raising serious concerns when deployed in actual practice [18].

Alternatively, machine learning techniques have been leveraged to overcome the limitations of standard statistical survival modeling schemes, especially in terms of model flexibility to address the complexity of data. For example, survival trees employ special node-splitting strategies to stratify the population and derive covariate-based survival curves [6]; support vector machines [24] and neural networks [13] have been used for more expressive predictors; and LASSO-type variants [47] simultaneously execute variable selection to boost statistical efficiency. Bayesian statistics has also been explored in the context of model selection [28], averaging [33] and imposing prior beliefs [14]. Recent advances in modern machine learning bring extra traction to the concept of data-driven survival models, an important step toward precision medicine. Prominent examples include direct deep learning extensions of CoxPH [22, 27], accelerated failure time [8] and Bayesian exponential family models [37]. Other efforts include the use of Gaussian Processes to capture complex interactions between covariates in relation to event times [15] and competing risks [3]. It has been argued that direct modeling of the event distribution might be beneficial [46], and more recently, adversarial distribution matching has also been considered for survival applications [8], with promising results reported.

In this work we present a principled approach to address the challenges of nonparametric modeling of time-to-event distributions in the presence of censored instances. Our approach, named Variational Survival Inference (VSI), builds upon recent developments in black-box variational inference [36]. It directly targets the estimation of individualized event-time distributions, rather than a risk score that correlates with event ordering. By explicitly accounting for latent variables in its formulation, VSI better accommodates individual uncertainty. The proposed VSI is a highly scalable and flexible framework without strong assumptions, featuring easy implementation and stable learning; importantly, it does not rely on ad-hoc regularizers. Our key contributions include: (i) a variational formulation of nonparametric time-to-event distribution modeling conditioned on explanatory variables; (ii) a cost-effective treatment of censored observations; (iii) a thorough discussion of how our modeling choices impact VSI performance; and (iv) an empirical validation confirming that the proposed VSI compares favorably to its counterparts on an extensive set of tasks, covering representative synthetic and real-world datasets.

2 BACKGROUND

A dataset for survival analysis is typically composed of a collection of triplets $\mathcal{D} = \{(x_i, t_i, \delta_i)\}_{i=1}^N$, where $i$ indexes the subjects involved in the study. For each triplet, $x_i$ denotes the set of explanatory variables, $t_i$ is the observation time and $\delta_i$ is the event indicator. To simplify our discussion, we only consider the standard survival setup. This means $\delta_i$ is binary, with 1 indicating that the event of interest happened at $t_i$; otherwise, 0 corresponds to a censoring event, i.e., no event occurs until $t_i$ and the subject is unobserved thereafter. This distinction creates a natural partition of the dataset, $\mathcal{D} = \mathcal{D}_c \cup \mathcal{D}_e$, representing the censored and event groups, respectively.

2.1 Statistical survival analysis

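In survival analysis, one is interested in characterizing the survival function $S(t) = P(T > t)$, defined as the probability that any given subject survives beyond time $t$, where $T$ denotes the random event time. The basic descriptors involved in the discussion of survival analysis are: the cumulative survival density $F(t) = P(T \le t) = 1 - S(t)$, the survival density $f(t) = dF(t)/dt$, the hazard function $h(t) = \lim_{\Delta t \to 0} P(t \le T < t + \Delta t \mid T \ge t)/\Delta t$ and the cumulative hazard function $H(t) = \int_0^t h(s)\,ds$. The following expressions are fundamental to survival analysis [1]: $S(t) = \exp(-H(t))$ and $h(t) = f(t)/S(t)$. Further, we use $S(t \mid x)$, $f(t \mid x)$, $h(t \mid x)$ and $H(t \mid x)$ to denote their individualized (subject-level) counterparts given explanatory variables $x$. All survival models leverage these definitions to derive population-level estimators or subject-level predictive functions, e.g., of risk, $h(t \mid x)$ or $S(t \mid x)$, or of the event time, $f(t \mid x)$.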
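As a quick numerical sanity check of these identities, the sketch below uses a Weibull distribution (an illustrative choice, not a component of VSI) whose hazard, cumulative hazard, survival function and density all have closed forms. It confirms $h = f/S$ and $S = \exp(-H)$, and recovers $F(t) = \int_0^t f(s)\,ds$ by trapezoidal integration. The function names and parameter values are hypothetical.

```python
import math


def weibull_quantities(t, lam=2.0, k=2.0):
    """Return (f, F, S, h, H) at time t for a Weibull(scale=lam, shape=k) law.

    lam and k are illustrative values only.
    """
    H = (t / lam) ** k                    # cumulative hazard H(t)
    S = math.exp(-H)                      # survival S(t) = exp(-H(t))
    F = 1.0 - S                           # cumulative density F(t) = 1 - S(t)
    h = (k / lam) * (t / lam) ** (k - 1)  # hazard h(t) = dH/dt
    f = h * S                             # density f(t) = h(t) S(t)
    return f, F, S, h, H


def integrate_density(t, n=20000, **params):
    """Trapezoidal approximation of F(t), the integral of f(s) over [0, t]."""
    step = t / n
    vals = [weibull_quantities(i * step, **params)[0] for i in range(n + 1)]
    return step * (sum(vals) - 0.5 * (vals[0] + vals[-1]))


f, F, S, h, H = weibull_quantities(3.0)
print(abs(h - f / S) < 1e-12)                  # hazard equals density over survival
print(abs(S - math.exp(-H)) < 1e-12)           # survival equals exp(-cumulative hazard)
print(abs(F - integrate_density(3.0)) < 1e-6)  # F is the integral of the density
```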

2.2 Variational inference

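For a latent variable model $p_\theta(x, z) = p_\theta(x \mid z)\,p(z)$, we consider $x$ as an observation, i.e., data, and $z$ as a latent variable. The marginal likelihood, given by $p_\theta(x) = \int p_\theta(x \mid z)\,p(z)\,dz$, typically does not enjoy a closed-form expression. To avoid direct numerical estimation of $p_\theta(x)$, Variational Inference (VI) optimizes a variational bound to the marginal log-likelihood. The most popular choice is known as the Evidence Lower Bound (ELBO) [44], given by

$$\log p_\theta(x) \geq \mathbb{E}_{q_\phi(z \mid x)}\left[\log \frac{p_\theta(x \mid z)\,p(z)}{q_\phi(z \mid x)}\right] = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - \mathrm{KL}(q_\phi(z \mid x) \,\|\, p(z)) \triangleq \mathrm{ELBO}(x), \tag{1}$$

where $q_\phi(z \mid x)$ is an approximation to the true (unknown) posterior $p_\theta(z \mid x)$, and the inequality is a direct result of Jensen's inequality. The variational gap between the ELBO and the true log-likelihood is the KL-divergence between posteriors, i.e., $\log p_\theta(x) - \mathrm{ELBO}(x) = \mathrm{KL}(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x))$, which implies the ELBO tightens as $q_\phi(z \mid x)$ approaches the true posterior $p_\theta(z \mid x)$. For estimation, we seek parameters $\theta$ and $\phi$ that maximize the ELBO. At test time, $q_\phi(z \mid x)$ is used for subsequent inference tasks on new data. Given a set of observations $\{x_i\}_{i=1}^N$ sampled from the data distribution $p_d(x)$, maximizing the expected ELBO amounts to minimizing an upper bound on the KL-divergence $\mathrm{KL}(p_d(x) \,\|\, p_\theta(x))$ between the empirical and model distributions. When $p_\theta(x \mid z)$ and $q_\phi(z \mid x)$ are specified as neural networks, the resulting architecture is more commonly known as the Variational Auto-Encoder (VAE) [25] in the context of computer vision and natural language processing.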
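To make the variational gap concrete, the sketch below uses a toy conjugate model (our own assumption, chosen because everything is available in closed form): $z \sim \mathcal{N}(0, 1)$ and $x \mid z \sim \mathcal{N}(z, 1)$, so that $p(x) = \mathcal{N}(x; 0, 2)$ and the exact posterior is $\mathcal{N}(x/2, 1/2)$. For a Gaussian variational family $q(z) = \mathcal{N}(m, s^2)$ the ELBO is analytic; it equals $\log p(x)$ when $q$ is the exact posterior and falls short by exactly the KL-divergence to the posterior otherwise. The names `elbo` and `log_marginal` are hypothetical.

```python
import math

# Toy model (illustrative assumption): z ~ N(0, 1), x | z ~ N(z, 1).
# Then p(x) = N(x; 0, 2) and the exact posterior is p(z | x) = N(x / 2, 1 / 2).


def elbo(x, m, s2):
    """Closed-form ELBO for the variational family q(z) = N(m, s2)."""
    # E_q[log N(x; z, 1)] = -0.5 log(2 pi) - 0.5 ((x - m)^2 + s2)
    expected_loglik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)
    # KL(N(m, s2) || N(0, 1)) against the prior p(z)
    kl = 0.5 * (m ** 2 + s2 - 1.0 - math.log(s2))
    return expected_loglik - kl


def log_marginal(x):
    """Exact log p(x) = log N(x; 0, 2)."""
    return -0.5 * math.log(2 * math.pi * 2.0) - x ** 2 / 4.0


x = 1.0
print(log_marginal(x), elbo(x, x / 2, 0.5))  # equal: q is the exact posterior
print(elbo(x, 0.0, 1.0) < log_marginal(x))   # a mismatched q gives a strict lower bound
```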

3 VARIATIONAL SURVIVAL INFERENCE

Below we detail the construction of the Variational Survival Inference (VSI) model, which results in predictions of the time-to-event distribution $p_\theta(t \mid x)$ given attributes $x$, with the individual uncertainty accounted for in the form of a latent variable $z$ whose distribution is estimated under the VI framework. Unlike classical survival models, we do not need to specify a parametric form for the baseline distribution, e.g., the base hazard in CoxPH [11] or the base density in AFT [20]. Instead, we leverage the power of deep neural networks to amortize the learning of the event time and survival distributions, allowing arbitrary (high-order) interactions between the predictors and survival time to be captured. This overcomes the limitations caused by the restrictive assumptions made in classical statistical survival analysis frameworks, thus allowing flexible inference of time-to-event distributions.

3.1 Variational bound of observed events

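We start the discussion with the simplest scenario, in which there are no censoring events. Our goal is to maximize the expected log-likelihood $\frac{1}{N}\sum_{i=1}^N \log p_\theta(t_i \mid x_i)$. To model the conditional likelihood, we consider a latent variable model of the form $p_\theta(t \mid x) = \int p_\theta(t, z \mid x)\,dz$. The unconditional formulation of the ELBO in (1) can be readily generalized to the conditional likelihood of event times as

$$\log p_\theta(t \mid x) \geq \mathbb{E}_{q_\phi(z \mid x, t)}\left[\log \frac{p_\theta(t, z \mid x)}{q_\phi(z \mid x, t)}\right] \triangleq \mathrm{ELBO}(t \mid x), \tag{2}$$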

where $q_\phi(z \mid x, t)$ denotes the conditional posterior approximation to the true (unknown) posterior $p_\theta(z \mid x, t)$.

In particular, we assume a model distribution with the following decomposition

$$p_\theta(t, z \mid x) = p_\theta(t \mid z)\,p_\theta(z \mid x), \tag{3}$$

which posits that $z$ is a sufficient statistic of $x$ with respect to survival time $t$. Another key assumption we make is that, unlike in the standard variational inference model, we use a learnable inhomogeneous prior $p_\theta(z \mid x)$ for the latent $z$ in place of the standard fixed homogeneous prior $p(z)$. Such a covariate-dependent prior allows the model to account for individual variation, thus further helping to close the variational gap [43]. Substituting (3) into the ELBO expression in (2) results in the usual likelihood and KL decomposition pair

$$\mathrm{ELBO}(t \mid x) = \mathbb{E}_{q_\phi(z \mid x, t)}[\log p_\theta(t \mid z)] - \mathrm{KL}(q_\phi(z \mid x, t) \,\|\, p_\theta(z \mid x)),$$

from which we can see that maximizing the ELBO amounts to estimating the parameters of a probabilistic time-to-event model by maximum likelihood, such that the inhomogeneous prior $p_\theta(z \mid x)$ matches, as well as possible, a conditional posterior $q_\phi(z \mid x, t)$ that explicitly accounts for event times. At test time, only $p_\theta(z \mid x)$ will be used to make predictions, since $t$ is not available during inference.

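More specifically, $p_\theta(t \mid z)$, $p_\theta(z \mid x)$ and $q_\phi(z \mid x, t)$ are defined as neural networks

$$p_\theta(t \mid z) = \mathrm{Softmax}(g_\theta(z)), \quad p_\theta(z \mid x) = \mathcal{N}(\mu_\theta(x), \Sigma_\theta(x)), \quad q_\phi(z \mid x, t) = \mathcal{N}(\mu_\phi(x, t), \Sigma_\phi(x, t)),$$

where $t$ is represented on a discretized time line (see below for details), $g_\theta$, $\mu_\theta$, $\Sigma_\theta$, $\mu_\phi$ and $\Sigma_\phi$ are deep neural nets parameterized by model parameters $\theta$ and variational parameters $\phi$, and $\mathcal{N}(\mu, \Sigma)$ denotes the multivariate Gaussian with mean $\mu$ and (diagonal) covariance $\Sigma$. For standard tabular data, we use Multi-Layer Perceptrons (MLPs) to specify these functions.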

3.2 Variational bound of censored events

Addressing censoring in the formulation is more challenging, as this type of partial observation is not subsumed in the conventional VI framework. To address this difficulty, we recall that in likelihood-based survival analysis, the log-likelihood for censored observations is given by $\log S_\theta(t_c \mid x)$, where $S_\theta(t \mid x)$ is the survival function and $t_c$ is the censoring time. For censored observations with $\delta = 0$, we do not have the exact event time $t$. This means that we only have partial information on the event, in that it should happen only after the censoring time $t_c$.

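To derive a tractable objective for censored observations, we first expand $S_\theta(t_c \mid x)$ based on its definition, and then apply Fubini's theorem [38] and Jensen's inequality, i.e.,

$$\begin{aligned}
\log S_\theta(t_c \mid x) &= \log \int_{t_c}^{\infty} \int p_\theta(t \mid z)\,p_\theta(z \mid x)\,dz\,dt = \log \int S_\theta(t_c \mid z)\,p_\theta(z \mid x)\,dz \\
&= \log \mathbb{E}_{q_\phi(z \mid x, t_c)}\left[\frac{S_\theta(t_c \mid z)\,p_\theta(z \mid x)}{q_\phi(z \mid x, t_c)}\right] \\
&\geq \mathbb{E}_{q_\phi(z \mid x, t_c)}[\log S_\theta(t_c \mid z)] - \mathrm{KL}(q_\phi(z \mid x, t_c) \,\|\, p_\theta(z \mid x)) \triangleq \mathrm{ELBO}_c(t_c \mid x),
\end{aligned}$$

where $S_\theta(t_c \mid z) = \int_{t_c}^{\infty} p_\theta(t \mid z)\,dt$, and the censored log-likelihood bound $\mathrm{ELBO}_c$ is only evaluated on $\mathcal{D}_c$, the subset of censored observations. See Supplementary Materials for the full derivation of $\mathrm{ELBO}_c$.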

3.3 Implementing VSI

In the current instantiation of the model, we discretize time into $K$ bins spanning the time horizon of the (training) data. This means that $t$ (at inference) is only known up to the time bin it falls into. We note this is not a restrictive assumption, as many survival datasets are only recorded up to a certain temporal accuracy. That said, generalization to continuous observations is fairly straightforward. For datasets that do have a natural discretization, we leave the choice to the user. In this study, we partition the temporal index based on the percentiles of observed event times, while also allowing for an artificial $(K+1)$-th bin to account for event times beyond the full observation window, i.e., events happening after the end of study as observed in the training cohort.
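The binning described above can be sketched as follows, assuming $K$ bins whose interior cut points are equally spaced percentiles of the training event times, plus one overflow bin (index $K$) for times beyond the training window. The helper names and the use of Python's `statistics.quantiles` with the inclusive method are our own illustrative choices, not the authors' implementation.

```python
import bisect
import statistics


def percentile_bin_edges(event_times, n_bins):
    """Interior cut points at equally spaced percentiles of observed event times."""
    return statistics.quantiles(event_times, n=n_bins, method="inclusive")


def time_to_bin(t, edges, t_max):
    """Map a time to a bin index in {0, ..., K}.

    Indices 0..K-1 are the percentile bins; index K (the extra bin) collects
    times beyond t_max, the end of the training observation window.
    """
    if t > t_max:
        return len(edges) + 1
    return bisect.bisect_right(edges, t)


event_times = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
edges = percentile_bin_edges(event_times, n_bins=4)              # [3.25, 5.5, 7.75]
print([time_to_bin(t, edges, t_max=10) for t in (1, 6, 9, 12)])  # [0, 2, 3, 4]
```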

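Since both $p_\theta(z \mid x)$ and $q_\phi(z \mid x, t)$ are assumed to be Gaussian, the following closed-form expression can be used in the computation of the KL terms above:

$$\mathrm{KL}(\mathcal{N}(\mu_1, \Sigma_1) \,\|\, \mathcal{N}(\mu_2, \Sigma_2)) = \frac{1}{2}\left[\log\frac{|\Sigma_2|}{|\Sigma_1|} - d + \mathrm{tr}(\Sigma_2^{-1}\Sigma_1) + (\mu_2 - \mu_1)^\top \Sigma_2^{-1} (\mu_2 - \mu_1)\right],$$

where $d$ is the dimension of $z$.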

Following Ranganath et al. [36], we use diagonal covariance matrices and apply the reparameterization trick to facilitate stable differentiable learning.
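With diagonal covariances, the Gaussian KL term reduces to a per-dimension sum, and the reparameterization trick writes a latent sample as $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. The plain-Python sketch below shows both pieces; the function names are hypothetical, and in practice these operations would live inside an autodiff framework so that gradients can flow through $\mu$ and $\sigma$.

```python
import math
import random


def kl_diag_gaussians(mu_q, var_q, mu_p, var_p):
    """KL(N(mu_q, diag(var_q)) || N(mu_p, diag(var_p))), summed over dimensions."""
    return 0.5 * sum(
        math.log(vp / vq) - 1.0 + vq / vp + (mp - mq) ** 2 / vp
        for mq, vq, mp, vp in zip(mu_q, var_q, mu_p, var_p)
    )


def reparameterized_sample(mu, var, rng):
    """Draw z = mu + sqrt(var) * eps with eps ~ N(0, I).

    Writing the sample as a deterministic function of (mu, var) and independent
    noise is what lets an autodiff framework backpropagate through it; this
    plain-Python version only computes the sample.
    """
    return [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(mu, var)]


# Posterior q(z | x, t) vs. prior p(z | x) for a 2-D latent (illustrative numbers).
print(kl_diag_gaussians([1.0, 0.0], [1.0, 0.5], [0.0, 0.0], [1.0, 1.0]))
print(reparameterized_sample([1.0, 0.0], [1.0, 0.5], random.Random(0)))
```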

In order to compute the term $S_\theta(t_c \mid z)$, we use the discretized time scheme described above and sum up all predicted probabilities subsequent to the bin containing $t_c$. Note that this can be readily generalized to continuous-time models: so long as the cumulative distribution of $p_\theta(t \mid z)$ enjoys a closed-form expression, a numerical integration scheme is not necessary to implement VSI.
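A minimal sketch of the resulting per-subject objective on the discretized time line, assuming the decoder output for one reparameterized sample $z$ is a list of bin probabilities (the last entry being the overflow bin) and that the KL term has already been computed. The event term reads off the probability of the observed bin, while the censored term sums the probabilities of all bins after the censoring bin. The function names and the numerical floor are our own assumptions.

```python
import math


def event_log_lik(probs, bin_idx):
    """log p(t | z) for an observed event falling in bin `bin_idx`."""
    return math.log(probs[bin_idx])


def censored_log_lik(probs, bin_idx, floor=1e-12):
    """log S(t_c | z): log of the total probability of bins after the censoring bin."""
    tail = sum(probs[bin_idx + 1:])
    return math.log(max(tail, floor))  # floor guards against log(0)


def negative_elbo(probs, bin_idx, delta, kl):
    """Single-sample negative ELBO: event term if delta == 1, censored term otherwise."""
    if delta == 1:
        rec = event_log_lik(probs, bin_idx)
    else:
        rec = censored_log_lik(probs, bin_idx)
    return -(rec - kl)


probs = [0.1, 0.2, 0.3, 0.4]  # hypothetical p(t | z) over 3 bins + overflow bin
print(negative_elbo(probs, bin_idx=2, delta=1, kl=0.5))  # event observed in bin 2
print(negative_elbo(probs, bin_idx=1, delta=0, kl=0.5))  # censored in bin 1
```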

3.4 Importance-Weighted estimator for likelihood evaluation

For evaluation purposes, we need to be able to compute the model's log-likelihood for an observation $(x, t, \delta)$.

In this study, we use the importance-weighted (IW) estimator [7], which provides a tighter bound to the log-likelihood. While more sophisticated alternatives might provide sharper estimates [32], we deem the IW estimator sufficient for the scope of this study. Additionally, while the tighter bound can be repurposed for training, it does not necessarily result in improved performance [35], which is also what we observe in this study.

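To obtain a more accurate value of the likelihood, we use the approximate posterior $q_\phi(z \mid x, t)$ as our proposal, and use the following finite-sample estimate

$$\log p_\theta(t \mid x) \approx \log \frac{1}{L}\sum_{l=1}^{L} \frac{p_\theta(t \mid z_l)\,p_\theta(z_l \mid x)}{q_\phi(z_l \mid x, t)}, \quad z_l \sim q_\phi(z \mid x, t),$$

where $L$ is the number of samples. The log-likelihood for the corresponding conditional survival function is

$$\log S_\theta(t_c \mid x) \approx \log \frac{1}{L}\sum_{l=1}^{L} \frac{S_\theta(t_c \mid z_l)\,p_\theta(z_l \mid x)}{q_\phi(z_l \mid x, t_c)}, \quad z_l \sim q_\phi(z \mid x, t_c).$$

Note that, by Jensen's inequality, the resulting estimate is, in expectation, an under-estimate of the true log-likelihood. As $L \to \infty$, the approximate lower bound converges to the true log-likelihood.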
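The IW estimator can be sketched as a log-mean-exp of importance log-weights. To keep the example checkable, it uses a toy conjugate model (our assumption, not the VSI networks): a prior $z \sim \mathcal{N}(0, 1)$ and a Gaussian likelihood $\mathcal{N}(t; z, 1)$ for a continuous $t$, so that the exact posterior $\mathcal{N}(t/2, 1/2)$ and the exact $\log p(t) = \log \mathcal{N}(t; 0, 2)$ are known. With the exact posterior as proposal, every importance weight equals $p(t)$, so the estimate is exact for any number of samples; with a poorer proposal it is only a lower bound in expectation. All names are hypothetical.

```python
import math
import random


def log_normal(y, mean, var):
    """Log density of N(mean, var) at y."""
    return -0.5 * (math.log(2 * math.pi * var) + (y - mean) ** 2 / var)


def log_mean_exp(values):
    """Numerically stable log((1/L) * sum_l exp(values[l]))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def iw_log_likelihood(log_lik, log_prior, log_q, sample_q, num_samples, rng):
    """Importance-weighted estimate of log p(t | x) with proposal q(z | x, t).

    log_lik(z) = log p(t | z), log_prior(z) = log p(z | x), log_q(z) = log q(z | x, t).
    """
    log_w = []
    for _ in range(num_samples):
        z = sample_q(rng)
        log_w.append(log_lik(z) + log_prior(z) - log_q(z))
    return log_mean_exp(log_w)


t = 1.0
estimate = iw_log_likelihood(
    log_lik=lambda z: log_normal(t, z, 1.0),
    log_prior=lambda z: log_normal(z, 0.0, 1.0),
    log_q=lambda z: log_normal(z, t / 2, 0.5),
    sample_q=lambda r: r.gauss(t / 2, math.sqrt(0.5)),
    num_samples=100,
    rng=random.Random(0),
)
print(estimate, log_normal(t, 0.0, 2.0))  # identical up to rounding
```

In the censored case the same routine applies with $\log p(t \mid z)$ replaced by $\log S_\theta(t_c \mid z)$.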

3.5 Making Predictions

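Predictive time-to-event distribution During inference, given a new data point with covariates $x$, the predictive distribution follows from the generative model,

$$p_\theta(t \mid x) = \int p_\theta(t \mid z)\,p_\theta(z \mid x)\,dz \approx \frac{1}{L}\sum_{l=1}^{L} p_\theta(t \mid z_l), \quad z_l \sim p_\theta(z \mid x),$$

where the integration is conducted numerically by Monte Carlo sampling.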
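A sketch of this Monte Carlo approximation, assuming a 1-D Gaussian prior $p_\theta(z \mid x)$ with a given mean and variance, and a hypothetical linear-softmax `toy_decoder` standing in for the trained network $p_\theta(t \mid z)$ over four time bins. Averaging the per-sample bin probabilities yields a valid distribution over bins.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_decoder(z):
    """Hypothetical stand-in for p_theta(t | z): bin logits linear in a 1-D latent z."""
    slopes = [1.0, 0.5, 0.0, -0.5]  # illustrative values, one per time bin
    return softmax([s * z for s in slopes])


def predictive_distribution(prior_mean, prior_var, decoder, num_samples, rng):
    """Monte Carlo estimate of p(t | x) = E_{z ~ p(z | x)}[p(t | z)] over time bins."""
    total = None
    for _ in range(num_samples):
        z = rng.gauss(prior_mean, math.sqrt(prior_var))  # z_l ~ p(z | x)
        probs = decoder(z)
        total = probs if total is None else [a + p for a, p in zip(total, probs)]
    return [a / num_samples for a in total]


pred = predictive_distribution(0.5, 1.0, toy_decoder, 1000, random.Random(0))
print([round(p, 3) for p in pred], sum(pred))
```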

Point estimation of time-to-event To better exploit the learned approximated posterior , we generalize the importance sampling idea and provide a weighted average as time-to-event summary, rather than for instance using a summary statistic such as median or mean. Specifically, consider multiple samples of , then calculate a weighted average as

In the Supplementary Materials we show that (8) gives better model performance for point-estimate-based evaluation metrics, the Concordance Index in particular, compared to other popular summary statistics such as the median of empirical samples.

4 DISSECTING VSI

In the experiments, we show the effectiveness of the proposed VSI model in recovering underlying time-to-event distributions. To provide additional insight into the differentiating components of the VSI model, we consider two baseline models that partially adopt a VSI design, as detailed below.

In VSI, we use the variational lower bound to maximize the likelihood in survival studies by implicitly forcing the unknown, intractable model posterior $p_\theta(z \mid x, t)$ to be close to the tractable posterior approximation $q_\phi(z \mid x, t)$. Through the KL divergence minimization, such matching allows the model to better account for interactions between covariates $x$ and event times $t$ captured by $q_\phi(z \mid x, t)$, which better informs the construction of the latent representation by isolating the individual uncertainty encoded by $p_\theta(z \mid x)$. If we exclude the interaction term in $q_\phi(z \mid x, t)$ and only make the prediction with $x$, with the approximate posterior given by $q_\phi(z \mid x)$, through the same stochastic latent representation $z$, then naturally the optimal solution is to equate $q_\phi(z \mid x)$ with the prior $p_\theta(z \mid x)$. This basically eliminates $q_\phi$ from our formulation, and therefore we call this variant VSI-NoQ.

More specifically, without a $q_\phi$ arm, the model described in Section 3 essentially becomes a feed-forward model with a special stochastic hidden layer $z$. In this case, the model likelihood is given by $p_\theta(t \mid x) = \int p_\theta(t \mid z)\,p_\theta(z \mid x)\,dz$, where $p_\theta(t \mid z)$ and $p_\theta(z \mid x)$ are defined as in (3). Note that the only difference with VSI is the lack of the KL divergence term matching $p_\theta(z \mid x)$ to $q_\phi(z \mid x, t)$. This baseline model (VSI-NoQ) is considered to dissect the impact of excluding complex interactions between covariates and event time when constructing the individualized priors.

Deterministic feed-forward model (MLP) To understand the importance of the stochastic latent representations $z$, we consider a straightforward baseline which directly predicts the event time distribution $p_\theta(t \mid x)$ based on the input $x$; this is essentially a standard multinomial regression with censored observations. In our study, we use an MLP to implement $p_\theta(t \mid x)$, and as such, hereafter we refer to this model as MLP. Additionally, we also considered standard randomization schemes, such as dropout [42], in the construction of a stochastic neural net, which may improve performance. Such a strategy also incorporates randomness, but differs in principle from the modeled uncertainty exploited by our VSI scheme. In our experiment section, we report the best results from MLP with or without dropout.

These baseline approaches use feed-forward deep learning networks to learn $p_\theta(t \mid x)$ without invoking the notion of variational inference. In the experiments, we will show that variational inference is crucial to the accurate learning of time-to-event distributions, resulting in better performance relative to these baselines, especially when the proportion of censoring events is high.

5 RELATED WORK

Machine learning and survival analysis Early attempts at combining machine learning techniques with statistical survival analysis, such as the Faraggi-Simon network (FS-network) [13], often failed to demonstrate a clear advantage over classical baselines [40]. Recent progress in machine learning allows researchers to overcome the difficulties suffered by prior studies. For example, Katzman et al. [23] showed that weight decay, batch normalization and dropout significantly improved the performance of FS-network. Li et al. [27] analyzed survival curves based on clinical images using a deep convolutional neural net (CNN). In addition to deep nets, Fernández et al. [15] showed that Gaussian Processes can be used to effectively capture the non-linear variations in CoxPH models, and Alaa and van der Schaar [3] further proposed a variant that handles competing risks. Similar to these works, our VSI also draws power from recent advances in machine learning to define a flexible learner.

Bayesian survival analysis Bayesian treatment of survival models has a long history. Raftery et al. [34] first considered modeling uncertainties for survival data, and Zupan et al. [49] reported probabilistic analysis under a Bayesian setup. More recently, Fard et al. [14] exploited the Bayesian framework to extrapolate priors, and Zhang and Zhou [48] described a Bayesian treatment of competing risks. Closest to VSI is the deep exponential family (DEF) survival model [37], where the authors introduced a Bayesian latent variable model to model both predictors $x$ and survival time $t$. Unlike our VSI, DEF still imposes strong parametric assumptions on the survival distribution, and it is not clear how the censored observations are handled in DEF's actual implementation. Another key difference between DEF and VSI is the factorization of the joint likelihood: the VSI encoder only seeks to capture the latent components that are predictive of the survival time distribution, while the DEF encoder also needs to summarize the information required to reconstruct covariates $x$. We argue that the VSI factorization of the joint probability is more sensible for survival time modeling, because modeling $x$ not only adds model complexity but also introduces nuisance to the prediction of survival time $t$. For datasets with high-dimensional covariates and noisy observations, the DEF features can be dominated by the ones predictive of $x$ rather than $t$, compromising the main goal of modeling the survival distribution.

Individual uncertainties and randomization The seminal work of Aalen [2] first identified the importance of accounting for individual uncertainties, the main culprit for the failure of classical survival models, which can be remedied by explicitly modeling the random effects [18]. Alternatively, Ishwaran et al. [19] presented the Random Survival Tree (RST) to predict cumulative hazards using a tree ensemble, demonstrating the effectiveness of a randomization scheme for statistical survival models. Our approach differs from the above schemes by systematically accounting for individual uncertainty using the randomness of latent variables.

Direct modeling of survival distribution The pioneering work of Yu et al. [46] advocated the prediction of individual survival distributions, learned using a generalized logistic regression scheme. This idea is further generalized in the works of Luck et al. [29] and Fotso [16]. Recently, Chapfuwa et al. [8] explored the use of a deep Generative Adversarial Network (GAN) to capture the individual survival distribution, which is closest to our goal. Compared to the proposed VSI, the adversarial learning of survival distributions is largely unstable, and its success crucially relies on the use of ad-hoc regularizers.

6 EXPERIMENTS

To validate the effectiveness of the proposed VSI, we benchmarked its performance against the following representative examples from both statistical and machine learning survival analysis schemes: AFT-Weibull, CoxPH, LASSO-based CoxNet [41], Random Survival Forest (RSF) [19] and deep learning based DeepSurv [23]. To fully appreciate the gains from using a variational setup, we further compared the results with the baselines discussed in Section 4, namely, the feed-forward model (MLP) and VSI model without the backward encoding arm (VSI-NoQ).

For data preparation, we randomly partition the data into three non-overlapping sets for training (60%), validation (20%) and evaluation (20%). All models are trained on the training set, and we tune the model hyper-parameters with respect to the out-of-sample performance on the validation set. The results reported in the paper are based on the evaluation set, using the best-performing hyper-parameters determined on the validation set. We apply the ADAM optimizer with a learning rate of 5 during training, with minibatches of size 100. Early stopping is enforced when there is no improvement on the validation set.
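The 60/20/20 partition can be sketched as a seeded shuffle of subject indices followed by slicing; the function name, the seed and the use of `round` for set sizes are illustrative assumptions rather than the released code.

```python
import random


def train_val_test_split(n, seed=0, fractions=(0.6, 0.2, 0.2)):
    """Randomly partition indices 0..n-1 into non-overlapping train/validation/test sets."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_train = round(fractions[0] * n)
    n_val = round(fractions[1] * n)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]


train_idx, val_idx, test_idx = train_val_test_split(1000)
print(len(train_idx), len(val_idx), len(test_idx))  # 600 200 200
```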

To ensure fair comparisons, all deep-learning-based solutions are matched for the number of parameters and use similar model architectures and hyper-parameter settings. TensorFlow code to replicate our experiments can be found at https://github.com/ZidiXiu/VSI/. The details of the VSI model setups are deferred to the Supplementary Materials (SM).

6.1 Evaluation Metrics

To objectively evaluate these competing survival models, we report a comprehensive set of distribution-based and point-estimate based scores to assess model performance, as detailed below.

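Concordance Index (C-Index) is commonly used to evaluate the consistency between the model-predicted risk scores and observed event rankings [17]. Formally, it is defined as

$$C = \frac{1}{|\mathcal{P}|}\sum_{(i, j) \in \mathcal{P}} \mathbb{1}(\hat{r}_i > \hat{r}_j),$$

where $\mathcal{P}$ is the set of all valid ordered pairs $(i, j)$ (event $i$ occurs before event $j$) and $\hat{r}_i$ is a scalar risk prediction made by the model for subject $i$. Higher is better.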
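A direct $O(n^2)$ sketch of this definition, assuming a higher $\hat r$ means a higher risk (an earlier expected event) and that a pair is valid when the earlier subject has an observed event. It uses a strict inequality, so tied risk scores count as discordant; many implementations instead give ties half credit. The function name is hypothetical.

```python
def harrell_c_index(times, events, risks):
    """Fraction of valid ordered pairs whose predicted risks are correctly ordered.

    A pair (i, j) is valid when subject i has an observed event (events[i] == 1)
    and times[i] < times[j]; it is concordant when risks[i] > risks[j].
    """
    concordant = 0
    valid = 0
    for i in range(len(times)):
        if events[i] != 1:
            continue
        for j in range(len(times)):
            if times[i] < times[j]:
                valid += 1
                concordant += int(risks[i] > risks[j])
    return concordant / valid if valid else float("nan")


# Subject 0 fails first and gets the highest risk: all valid pairs are concordant.
print(harrell_c_index([1.0, 2.0, 3.0], [1, 1, 0], [0.9, 0.5, 0.1]))  # 1.0
```

When a model outputs a predicted event time instead of a risk, its negative can serve as the risk score.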

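Time-dependent Concordance Index is a distributional generalization of the scalar risk-score-based C-Index [4], which is computed from the predicted survival distribution. Formally, it is given by

$$C^{td} = P\big(\hat{F}(t_i \mid x_i) > \hat{F}(t_i \mid x_j) \,\big|\, \delta_i = 1,\, t_i < t_j\big),$$

where $\hat{F}(t \mid x)$ denotes the model-predicted cumulative distribution of the event time. We report the results using the following empirical estimator

$$\hat{C}^{td} = \frac{\sum_{i, j} \delta_i\,\mathbb{1}(t_i < t_j)\,\mathbb{1}\big(\hat{F}(t_i \mid x_i) > \hat{F}(t_i \mid x_j)\big)}{\sum_{i, j} \delta_i\,\mathbb{1}(t_i < t_j)}.$$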
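A sketch of the empirical time-dependent C-Index for discretized predictions, assuming each subject's time is given as a bin index and each row of `cdfs` holds that subject's predicted $\hat F$ evaluated at every bin. The data layout and the function name are our assumptions.

```python
def time_dependent_c_index(time_bins, events, cdfs):
    """Empirical C^td: for valid pairs (delta_i = 1, t_i < t_j), count how often
    the predicted CDF at t_i is larger for subject i than for subject j."""
    numerator = 0
    denominator = 0
    n = len(time_bins)
    for i in range(n):
        if events[i] != 1:
            continue
        ti = time_bins[i]
        for j in range(n):
            if ti < time_bins[j]:
                denominator += 1
                numerator += int(cdfs[i][ti] > cdfs[j][ti])
    return numerator / denominator if denominator else float("nan")


# Subject 0 fails in bin 0 and the model assigns it more early mass than subject 1.
cdfs = [[0.8, 0.9, 1.0], [0.1, 0.5, 1.0]]
print(time_dependent_c_index([0, 2], [1, 0], cdfs))  # 1.0
```

Unlike the scalar C-Index, each comparison is made at the event time of the earlier subject, so the score depends on the whole predicted distribution rather than a single summary.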

Kolmogorov-Smirnov (KS) distance For synthetic datasets, we also report the KS distance [30] between the predicted distribution and the ground truth. KS computes the maximal discrepancy between two cumulative densities, i.e.,

$$\mathrm{KS} = \sup_t \big|\hat{F}(t \mid x) - F(t \mid x)\big|,$$

and a lower KS indicates a better match between the two distributions.
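On a finite time grid, the supremum is approximated by the maximum absolute difference between the two CDFs at the grid points. The sketch below checks this against two exponential CDFs with rates 1 and 2, whose exact KS distance is $1/2 - 1/4 = 0.25$, attained at $t = \ln 2$. The grid and the function name are illustrative assumptions.

```python
import math


def ks_distance(cdf_a, cdf_b):
    """Grid approximation of sup_t |F_a(t) - F_b(t)|, with both CDFs
    evaluated on the same time grid."""
    return max(abs(a - b) for a, b in zip(cdf_a, cdf_b))


# Exponential CDFs with rates 1 and 2: the exact KS distance is 0.25 (at t = ln 2).
grid = [i * 0.001 for i in range(10001)]
cdf_rate1 = [1.0 - math.exp(-t) for t in grid]
cdf_rate2 = [1.0 - math.exp(-2.0 * t) for t in grid]
print(ks_distance(cdf_rate1, cdf_rate2))  # close to 0.25
```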

Test log-likelihood We also report the average log-likelihood on the held-out test set. A higher score indicates the model is better aligned with the ground-truth distribution in the sense of KL-divergence. Additionally, we evaluate the spread of the empirical log-likelihoods with respect to each model. In the case of a tie in expected log-likelihood, models with more concentrated log-likelihoods are considered better under the maximum entropy principle [9] (i.e., when observed instances receive more uniform/similar likelihoods, better generalization of the model is implied).

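Coverage Rate To quantify the proportion of observed times covered by the predicted personalized time-to-event distributions, we calculate the coverage rate for different percentile ranges. For subjects with event observations, the coverage rate is defined as the proportion of observations that fall in the percentile range $[l, u]$ of the predicted distributions, where $l$ and $u$ respectively denote the lower and upper quantiles of the percentile range, i.e.,

$$\mathrm{CR}_e(l, u) = \frac{1}{N_e}\sum_{i:\,\delta_i = 1} \mathbb{1}\big(\hat{t}_i^{\,l} \le t_i \le \hat{t}_i^{\,u}\big),$$

where $N_e$ is the number of subjects with observed events and $\hat{t}_i^{\,q}$ denotes the $q$-th quantile of the predicted distribution $p_\theta(t \mid x_i)$.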

In our experiments, we report coverage rates of events at percentile ranges $[l, u] \in \{\ldots, [0.25, 0.75], [0.3, 0.7], [0.35, 0.65], [0.4, 0.6], [0.45, 0.55]\}$ of the predicted personalized distributions. For censoring, since the true event time of a censored subject occurs after its censoring time, we calculate the proportion of censoring times that occur before the given percentile of the predicted distribution,

$$\mathrm{CR}_c(q) = \frac{1}{N_c}\sum_{i:\,\delta_i = 0} \mathbb{1}\big(t_i \le \hat{t}_i^{\,q}\big),$$

where $N_c$ is the number of censored subjects. We evaluated the coverage rate for censoring at several percentiles of the predicted distributions.

Figure 1: Two simulated time-to-event distributions with 30% event rate showing that VSI successfully predicts the underlying distributions from covariates. (left: events, right: censoring)

For all coverage rates, a higher score implies better performance. Coverage rates for events and censoring should be considered together to evaluate model performance.
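A sketch of both coverage rates for discretized predicted distributions, assuming each prediction is a list of bin probabilities and each observed or censoring time is a bin index. The $q$-th predicted quantile is taken as the first bin whose cumulative probability reaches $q$. These conventions and names are our assumptions, not a specification of the paper's evaluation code.

```python
def quantile_bin(pmf, q):
    """Smallest bin index whose predicted cumulative probability reaches q."""
    cumulative = 0.0
    for b, p in enumerate(pmf):
        cumulative += p
        if cumulative >= q:
            return b
    return len(pmf) - 1  # guard against round-off when q is close to 1


def event_coverage(pmfs, event_bins, lower, upper):
    """Share of observed event times inside the predicted [lower, upper] quantile range."""
    hits = sum(
        quantile_bin(p, lower) <= t <= quantile_bin(p, upper)
        for p, t in zip(pmfs, event_bins)
    )
    return hits / len(event_bins)


def censoring_coverage(pmfs, censor_bins, q):
    """Share of censoring times at or before the predicted q-th quantile."""
    hits = sum(t <= quantile_bin(p, q) for p, t in zip(pmfs, censor_bins))
    return hits / len(censor_bins)


uniform = [0.25, 0.25, 0.25, 0.25]
print(event_coverage([uniform, uniform], [1, 3], lower=0.25, upper=0.75))  # 0.5
print(censoring_coverage([uniform, uniform], [0, 2], q=0.5))               # 0.5
```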

6.2 Synthetic datasets

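Following Bender et al. [5], we simulate realistic survival data based on the German Uranium Miners Cohort Study in accordance with the Cox-Gompertz model

$$T = \frac{1}{\alpha}\log\left(1 - \frac{\alpha \log U}{\lambda \exp(\beta_{\text{age}} x_{\text{age}} + \beta_{\text{radon}} x_{\text{radon}})}\right),$$

with $U \sim \mathrm{Uniform}(0, 1)$. This model simulates cancer mortality associated with radon exposure and age. The model parameters ($\alpha$, $\lambda$, $\beta_{\text{age}}$ and $\beta_{\text{radon}}$) are derived from real data. The covariates $x_{\text{age}}$ and $x_{\text{radon}}$ are generated from normal distributions $\mathcal{N}(\mu, \sigma^2)$, where $\mathcal{N}(\mu, \sigma^2)$ denotes a normal distribution with mean $\mu$ and variance $\sigma^2$. We simulate uniform censoring within a fixed time horizon. By setting different upper bounds for censoring, we achieve different observed event rates: 100%, 50% and 30%. For each simulation, we randomly draw i.i.d. samples.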
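A sketch of this simulation scheme, assuming standard normal covariates and the hypothetical parameter values below (they are not the values derived from the miners cohort). Event times are drawn by inverse-transform sampling of the Gompertz proportional hazards model, $T = \frac{1}{\alpha}\log\big(1 - \frac{\alpha \log U}{\lambda e^{\beta^\top x}}\big)$, censoring times are drawn uniformly on $[0, c_{\max}]$, and the observed time is the earlier of the two. Lowering $c_{\max}$ lowers the observed event rate.

```python
import math
import random


def simulate_cox_gompertz(covariates, beta, alpha, lam, censor_max, rng):
    """Return (observed_time, delta) pairs under a Cox-Gompertz model
    with Uniform(0, censor_max) censoring."""
    data = []
    for x in covariates:
        eta = sum(b * xi for b, xi in zip(beta, x))  # linear predictor beta^T x
        u = 1.0 - rng.random()                       # U in (0, 1], avoids log(0)
        event_time = math.log(1.0 - alpha * math.log(u) / (lam * math.exp(eta))) / alpha
        censor_time = rng.uniform(0.0, censor_max)
        delta = int(event_time <= censor_time)       # 1 = event observed, 0 = censored
        data.append((min(event_time, censor_time), delta))
    return data


rng = random.Random(0)
x = [[rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)] for _ in range(1000)]
params = dict(beta=[0.5, 0.3], alpha=0.2, lam=0.05)  # hypothetical values
for c_max in (1e9, 20.0, 5.0):
    sim = simulate_cox_gompertz(x, censor_max=c_max, rng=random.Random(1), **params)
    print(c_max, sum(d for _, d in sim) / len(sim))  # observed event rate
```

Reusing the same seed for each censoring bound keeps the underlying event times fixed, so only the censoring level changes between runs.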

Prediction of subject-level distribution In practice, for each subject we only observe one time $t$ from its underlying distribution. Our goal is to accurately predict the underlying distribution from the covariates alone (since $t$ and $\delta$ are not observed at test time), by learning from the observed instances. Figure 1 compares our VSI prediction with the ground truth for two random subjects: VSI accurately recovers the individual survival distribution for both observed (Figure 1(a)) and censored cases (Figure 1(b)).

To systematically evaluate the consistency between the predicted and the true distributions, we compare the average KS distance of models trained with various event rates in Table 1. Since the underlying generative process is based on the CoxPH model, we consider the results from CoxPH as the oracle reference, as there is no model misspecification. At 100% event rate (i.e., complete observation), apart from the oracle CoxPH, all models perform similarly. The VSI variants give slightly better results compared with MLP and AFT-Weibull. As the proportion of observed events decreases, VSI remains the best-performing model, closely followed by the parametric AFT-Weibull. Note that neither MLP nor VSI-NoQ matches the performance of VSI, which suggests that the full VSI design better accommodates censored observations.

Table 1: KS statistic for simulation study.

Average log-likelihood and C-Index To validate the effectiveness of VSI, we also provide a comprehensive summary of model performance against other popular or state-of-the-art alternatives in Table 2, under various simulation setups with different evaluation metrics. VSI consistently outperforms its counterparts in terms of the average log-likelihood and time-dependent C-Index. Together with the observation that VSI also yields better KS distance (see Table 1), converging evidence suggests our VSI better predicts the individual survival distributions relative to other competing solutions.

We also compared the raw C-Index and the corresponding confidence intervals, using the weighted average of the model-predicted survival time (defined in Sec 3.5) as the risk score, and did not find significant differences between alternative methods, as shown in Table 2 and the Supplemental Materials. Thus VSI delivers comparable performance relative to models that are compatible with the data-generating mechanism. The raw C-Index only assesses the ordering of pairs without considering the time horizon, so the distinctions among well-performing models are not significant.

To provide a more informative summary, we plot the test log-likelihood distributions for selected models in Figure 2. VSI log-likelihood estimates are tighter and higher for both observed and censored observations, especially when event rates are low. For the (0.10, 0.90) percentile ranges in the simulation studies, please refer to the SM.

Coverage Plots In Figure 3, VSI achieves relatively high coverage for both events (Figure 3(a)) and censored observations (Figure 3(b)), compared to the oracle method CoxPH in this synthetic example. Note that while RSF performs better for the observed events, its performance on censored cases falls well below other solutions.

We refer the readers to our Supplementary Materials for additional simulations and analyses based on toy models.

Table 2: Model performance summary for the simulation study based on $C^{td}$, C-Index and average test log-likelihood. Confidence intervals for C-Index are provided in the SM. For NA entries, the corresponding evaluation metric cannot be applied.

Figure 2: Test log-likelihood distributions for the 50% event rate simulation dataset. (left: events, right:censoring)

Figure 3: Test coverage rate for the 50% event rate simulation dataset. (left: events, right: censoring)

6.3 Real-World datasets

Moving beyond toy simulations, we further compare VSI to competing solutions on three real-world datasets:

- FLCHAIN [12], a public dataset used to determine whether elevation in the free light chain assay provides prognostic information on overall survival in the general population;
- SUPPORT [26], a public dataset from a prospective cohort study that estimates survival of seriously ill hospitalized adults over a 180-day period;
- SEER [39], a public dataset for studying cancer survival among adults, with information from 1988 to 2001 provided by the U.S. Surveillance, Epidemiology, and End Results (SEER) Program.

In these experiments, we use the breast cancer subcohort of the SEER dataset with 10-year follow-up. We follow the data pre-processing steps outlined in Chapfuwa et al. [8]. To handle missing values, we adopt the common practice of median imputation for continuous variables and mode imputation for discrete variables.
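The median/mode imputation rule described above can be sketched with NumPy as follows. This is an illustration only. The split into `fit_fill_values` and `apply_fill_values` is a hypothetical design. Computing the fill values on the training split and reusing them on test data is our assumption; the source does not say which split the statistics come from.

```python
import numpy as np


def fit_fill_values(X_train, is_discrete):
    """Per-column fill value: median for continuous columns, mode for discrete ones.

    X_train: 2-D float array with np.nan marking missing entries.
    is_discrete: one boolean per column.
    """
    X_train = np.asarray(X_train, dtype=float)
    fills = np.empty(X_train.shape[1])
    for j in range(X_train.shape[1]):
        col = X_train[:, j]
        col = col[~np.isnan(col)]
        if is_discrete[j]:
            values, counts = np.unique(col, return_counts=True)
            fills[j] = values[np.argmax(counts)]  # mode (ties -> smallest value)
        else:
            fills[j] = np.median(col)
    return fills


def apply_fill_values(X, fills):
    """Return a copy of X with every NaN replaced by its column's fill value."""
    X = np.array(X, dtype=float, copy=True)
    rows, cols = np.where(np.isnan(X))
    X[rows, cols] = fills[cols]
    return X
```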

Table 3: Summary Statistics for Real Datasets.

Summary statistics of the datasets are shown in Table 3. The table reports the total number of observations and the total number of variables after one-hot encoding. NaN(%) is the proportion of missing values in the covariates, and loss of information is the proportion of censored observations that occur after the maximum event time.
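The "loss of information" statistic reported in Table 3 can be computed as shown below. This is a small sketch under one assumption: the denominator is taken to be the number of censored observations. The source's wording also admits the total number of observations as the denominator. The function name is hypothetical.

```python
import numpy as np


def loss_of_information(times, events):
    """Fraction of censored observations whose time exceeds the largest observed event time.

    Assumption: the denominator is the number of censored observations.
    """
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    max_event_time = times[events].max()
    return float(np.mean(times[~events] > max_event_time))
```

For example, with times 1 to 5 and events only at times 1 and 3, two of the three censored times (4 and 5) lie beyond the last event, giving 2/3.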

In Table 4 we compare the C-Indices and average log-likelihoods. The advantage of VSI is more evident on the more challenging real datasets, especially when observed event rates are low. For example, on the SUPPORT dataset with a 30% event rate, the confidence interval of the raw C-Index is (0.809, 0.846) for VSI. The corresponding intervals are only (0.763, 0.805) for the standard CoxNet and (0.782, 0.813) for AFT, so they overlap with that of VSI marginally, if at all. Similar results were observed for the other datasets and baseline solutions. VSI is robust to data incompleteness in real-world scenarios and achieves the best results on all three metrics. For VSI, the raw C-Index is computed from the weighted average of the predicted distribution; please refer to the SM for more details. In Figure 4, the VSI log-likelihood distribution is more concentrated, in addition to having a higher mean. To quantify this concentration, Table 5 reports the difference between the 10% and 90% quantiles of the log-likelihood. The quantile ranges of VSI are considerably smaller than those of the alternative solutions under most experimental settings. This indicates that VSI is more robust than other popular alternatives, especially under high censoring rates.
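The concentration measure reported in Table 5 is the gap between the 90% and 10% quantiles of the per-subject log-likelihood. It can be sketched with `numpy.quantile` as below. Computing it separately for events and censored observations, as Figures 4 and S3 split them, is our illustrative choice. The function names are hypothetical.

```python
import numpy as np


def quantile_range(loglik, lo=0.10, hi=0.90):
    """Width of the [lo, hi] quantile interval of the log-likelihood values."""
    q_lo, q_hi = np.quantile(np.asarray(loglik, dtype=float), [lo, hi])
    return float(q_hi - q_lo)


def quantile_ranges_by_status(loglik, events, lo=0.10, hi=0.90):
    """Quantile ranges computed separately for observed events and censored cases."""
    loglik = np.asarray(loglik, dtype=float)
    events = np.asarray(events, dtype=bool)
    return {"events": quantile_range(loglik[events], lo, hi),
            "censored": quantile_range(loglik[~events], lo, hi)}
```

A smaller range means the model assigns similar log-likelihoods across test subjects, which is the sense of "concentrated" used here.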

Table 4: Summary for real datasets based on C-Index and average log-likelihood. Confidence intervals for the C-Index are provided in the SM. NA implies the corresponding evaluation metric cannot be evaluated.

Table 5: Quantile ranges for log-likelihood in Real Datasets. Note AFT did not converge to reasonable solutions for SEER.

Figure 4: Log-likelihood distributions for the SUPPORT dataset (left: events, right: censoring).

Figure 5: Coverage rates for the SUPPORT dataset (left: events, right: censoring).

The coverage plots in Figure 5 show that VSI has relatively high coverage for both events and censored cases. This indicates better performance in capturing the true event time in challenging real-world datasets. The consistency of these results has been verified through repeated runs on all three datasets. For more detailed results, please refer to the SM.

7 CONCLUSIONS

We presented an approach for learning time-to-event distributions conditioned on covariates in a nonparametric fashion by leveraging a principled variational inference formulation. The proposed approach, VSI, extends the variational inference framework to survival data with censored observations. On synthetic and diverse real-world datasets, we demonstrated that VSI recovers the underlying unobserved time-to-event distribution. It also provides point estimates of each subject's time-to-event, with performance metrics that consistently exceed those of feed-forward deep learning models and traditional statistical models.

As future work, we plan to extend the VSI framework to longitudinal studies, where a recurrent neural network (RNN) can be employed to account for temporal dependencies. For datasets with observations made at irregular intervals, the Neural-ODE model [10], for instance, can be applied. Our work can also be adapted to make dynamic predictions of event times to serve the needs of modern clinical practice.

Acknowledgements. The authors would like to thank the anonymous reviewers for their insightful comments. This research was supported in part by NIH/NIBIB R01-EB025020.

REFERENCES

[1] Odd Aalen, Ornulf Borgan, and Hakon Gjessing. 2008. Survival and event history analysis: a process point of view. Springer Science & Business Media.

[2] Odd O Aalen. 1994. Effects of frailty in survival analysis. Statistical Methods in Medical Research 3, 3 (1994), 227–243.

[3] Ahmed M Alaa and Mihaela van der Schaar. 2017. Deep multi-task gaussian processes for survival analysis with competing risks. In Proceedings of the 31st International Conference on Neural Information Processing Systems. Curran Associates Inc., 2326–2334.

[4] Laura Antolini, Patrizia Boracchi, and Elia Biganzoli. 2005. A time-dependent discrimination index for survival data. Statistics in Medicine 24, 24 (2005), 3927–3944.

[5] Ralf Bender, Thomas Augustin, and Maria Blettner. 2005. Generating survival times to simulate Cox proportional hazards models. Statistics in medicine 24, 11 (2005), 1713–1723.

[6] Imad Bou-Hamad, Denis Larocque, Hatem Ben-Ameur, et al. 2011. A review of survival trees. Statistics Surveys 5 (2011), 44–71.

[7] Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. 2015. Importance weighted autoencoders. arXiv preprint arXiv:1509.00519 (2015).

[8] Paidamoyo Chapfuwa, Chenyang Tao, Chunyuan Li, Courtney Page, Benjamin Goldstein, Lawrence Carin, and Ricardo Henao. 2018. Adversarial time-to-event modeling. arXiv preprint arXiv:1804.03184 (2018).

[9] Liqun Chen, Chenyang Tao, Ruiyi Zhang, Ricardo Henao, and Lawrence Carin Duke. 2018. Variational inference and model selection with generalized evidence bounds. In International Conference on Machine Learning. 892–901.

[10] Tian Qi Chen, Yulia Rubanova, Jesse Bettencourt, and David K Duvenaud. 2018. Neural ordinary differential equations. In Advances in neural information processing systems. 6571–6583.

[11] David R Cox. 1972. Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological) 34, 2 (1972), 187–202.

[12] Angela Dispenzieri, Jerry A Katzmann, Robert A Kyle, Dirk R Larson, Terry M Therneau, Colin L Colby, Raynell J Clark, Graham P Mead, Shaji Kumar, L Joseph Melton III, et al. 2012. Use of nonclonal serum immunoglobulin free light chains to predict overall survival in the general population. In Mayo Clinic Proceedings, Vol. 87. Elsevier, 517–523.

[13] David Faraggi and Richard Simon. 1995. A neural network model for survival data. Statistics in medicine 14, 1 (1995), 73–82.

[14] Mahtab Jahanbani Fard, Ping Wang, Sanjay Chawla, and Chandan K Reddy. 2016. A bayesian perspective on early stage event prediction in longitudinal data. IEEE Transactions on Knowledge and Data Engineering 28, 12 (2016), 3126–3139.

[15] Tamara Fernández, Nicolás Rivera, and Yee Whye Teh. 2016. Gaussian processes for survival analysis. In Advances in Neural Information Processing Systems. 5021–5029.

[16] Stephane Fotso. 2018. Deep Neural Networks for Survival Analysis Based on a Multi-Task Framework. arXiv preprint arXiv:1801.05512 (2018).

[17] Frank E Harrell, Robert M Califf, David B Pryor, Kerry L Lee, and Robert A Rosati. 1982. Evaluating the yield of medical tests. Jama 247, 18 (1982), 2543–2546.

[18] Philip Hougaard. 1995. Frailty models for survival data. Lifetime data analysis 1, 3 (1995), 255–273.

[19] Hemant Ishwaran, Udaya B Kogalur, Eugene H Blackstone, Michael S Lauer, et al. 2008. Random survival forests. The annals of applied statistics 2, 3 (2008), 841–860.

[20] John D Kalbfleisch and Ross L Prentice. 2011. The statistical analysis of failure time data. Vol. 360. John Wiley & Sons.

[21] Edward L Kaplan and Paul Meier. 1958. Nonparametric estimation from incomplete observations. Journal of the American statistical association 53, 282 (1958), 457–481.

[22] Jared L Katzman, Uri Shaham, Alexander Cloninger, Jonathan Bates, Tingting Jiang, and Yuval Kluger. 2016. Deep survival: A deep cox proportional hazards network. stat 1050 (2016), 2.

[23] Jared L Katzman, Uri Shaham, Alexander Cloninger, Jonathan Bates, Tingting Jiang, and Yuval Kluger. 2018. DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC medical research methodology 18, 1 (2018), 24.

[24] Faisal M Khan and Valentina Bayer Zubek. 2008. Support vector regression for censored data (SVRc): a novel tool for survival analysis. In 2008 Eighth IEEE International Conference on Data Mining. IEEE, 863–868.

[25] Diederik P Kingma and Max Welling. 2013. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114 (2013).

[26] William A Knaus, Frank E Harrell, Joanne Lynn, Lee Goldman, Russell S Phillips, Alfred F Connors, Neal V Dawson, William J Fulkerson, Robert M Califf, Norman Desbiens, et al. 1995. The SUPPORT prognostic model: objective estimates of survival for seriously ill hospitalized adults. Annals of internal medicine 122, 3 (1995), 191–203.

[27] Hongming Li, Pamela Boimel, James Janopaul-Naylor, Haoyu Zhong, Ying Xiao, Edgar Ben-Josef, and Yong Fan. 2019. Deep Convolutional Neural Networks for Imaging Data Based Survival Analysis of Rectal Cancer. arXiv preprint arXiv:1901.01449 (2019).

[28] Paulo JG Lisboa, H Wong, P Harris, and Ric Swindell. 2003. A Bayesian neural network approach for modelling censored data with an application to prognosis after surgery for breast cancer. Artificial intelligence in medicine 28, 1 (2003), 1–25.

[29] Margaux Luck, Tristan Sylvain, Héloïse Cardinal, Andrea Lodi, and Yoshua Bengio. 2017. Deep learning for patient-specific kidney graft survival analysis. arXiv preprint arXiv:1705.10245 (2017).

[30] Frank J Massey Jr. 1951. The Kolmogorov-Smirnov test for goodness of fit. Journal of the American statistical Association 46, 253 (1951), 68–78.

[31] SA Murphy, AJ Rossini, and Aad W van der Vaart. 1997. Maximum likelihood estimation in the proportional odds model. J. Amer. Statist. Assoc. 92, 439 (1997), 968–976.

[32] Radford M Neal. 2001. Annealed importance sampling. Statistics and computing 11, 2 (2001), 125–139.

[33] Adrian E Raftery. 1995. Bayesian model selection in social research. Sociological methodology 25 (1995), 111–164.

[34] Adrian E Raftery, David Madigan, and Chris T Volinsky. 1996. Accounting for model uncertainty in survival analysis improves predictive performance. Bayesian statistics 5 (1996), 323–349.

[35] Tom Rainforth, Adam R Kosiorek, Tuan Anh Le, Chris J Maddison, Maximilian Igl, Frank Wood, and Yee Whye Teh. 2018. Tighter variational bounds are not necessarily better. arXiv preprint arXiv:1802.04537 (2018).

[36] Rajesh Ranganath, Sean Gerrish, and David Blei. 2014. Black box variational inference. In Artificial Intelligence and Statistics. 814–822.

[37] Rajesh Ranganath, Adler Perotte, Noémie Elhadad, and David Blei. 2016. Deep survival analysis. arXiv preprint arXiv:1608.02158 (2016).

[38] Sidney Resnick. 2003. A probability path. Birkhauser Verlag AG.

[39] LA Gloeckler Ries, JL Young, GE Keel, MP Eisner, YD Lin, MJ Horner, et al. 2007. SEER survival monograph: cancer survival among adults: US SEER program, 1988-2001, patient and tumor characteristics. National Cancer Institute, SEER Program, NIH Pub 07-6215 (2007), 193–202.

[40] Guido Schwarzer, Werner Vach, and Martin Schumacher. 2000. On the misuses of artificial neural networks for prognostic and diagnostic classification in oncology. Statistics in medicine 19, 4 (2000), 541–561.

[41] Noah Simon, Jerome Friedman, Trevor Hastie, and Rob Tibshirani. 2011. Regularization paths for Cox's proportional hazards model via coordinate descent. Journal of statistical software 39, 5 (2011), 1.

[42] Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. 2014. Dropout: a simple way to prevent neural networks from overfitting. The journal of machine learning research 15, 1 (2014), 1929–1958.

[43] Jakub M Tomczak and Max Welling. 2017. VAE with a VampPrior. arXiv preprint arXiv:1705.07120 (2017).

[44] Martin J. Wainwright and Michael I. Jordan. 2008. Graphical Models, Exponential Families, and Variational Inference. Foundations and Trends in Machine Learning 1 (2008), 1–305.

[45] Ping Wang, Yan Li, and Chandan K Reddy. 2019. Machine learning for survival analysis: A survey. ACM Computing Surveys (CSUR) 51, 6 (2019), 110.

[46] Chun-Nam Yu, Russell Greiner, Hsiu-Chin Lin, and Vickie Baracos. 2011. Learning patient-specific cancer survival distributions as a sequence of dependent regressors. In Advances in Neural Information Processing Systems. 1845–1853.

[47] Hao Helen Zhang and Wenbin Lu. 2007. Adaptive Lasso for Cox's proportional hazards model. Biometrika 94, 3 (2007), 691–703.

[48] Quan Zhang and Mingyuan Zhou. 2018. Nonparametric Bayesian Lomax delegate racing for survival analysis with competing risks. In Advances in Neural Information Processing Systems. 5002–5013.

[49] Blaž Zupan, Janez Demšar, Michael W Kattan, J Robert Beck, and Ivan Bratko. 1999. Machine learning for survival analysis: a case study on recurrence of prostate cancer. In Joint European Conference on Artificial Intelligence in Medicine and Medical Decision Making. Springer, 346–355.

A IMPLEMENTING VSI

For model inputs, we encode each subject's time as a one-hot vector whose dimension equals the number of discretized time bins, based on the subject's observed event time. For an observed event, only the bin containing the event time equals 1, and all other entries equal 0. For censored observations, we have partial information about the true, unknown event time rather than complete missingness: we know for sure that the event would happen after the observed censoring time. The entries of the input vector for time points after the censoring time are therefore assigned a prior based on the population time-to-event distribution, estimated with the Nelson-Aalen estimator. This re-weighting strategy tells the model that the subject did not have an event before the censoring time. It also regularizes the model with respect to the empirical tail distribution. Alternatively, one may skip the encoding and add the time directly to the input. In this case, each subject in the training set has input of the form . To predict a new , the input for the prediction process is , where by default we are predicting the time-to-event. In our experiments, the two forms of input produce similar results, but the encoding strategy performs better in both simulation and real-world scenarios. For example, at a 30% event rate the KS statistic is 0.073 for the unencoded input, while VSI with encoding achieves 0.059.
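The hard one-hot encoding used for observed events can be sketched as follows. This is a minimal illustration, assuming half-open bins `[edges[k], edges[k+1])` given by a hypothetical `edges` array; times outside the range are clipped to the first or last bin. The soft encoding for censored subjects is not shown here.

```python
import numpy as np


def one_hot_event_times(times, edges):
    """Hard one-hot encoding: a 1 in the bin [edges[k], edges[k+1]) containing each event time."""
    times = np.asarray(times, dtype=float)
    n_bins = len(edges) - 1
    idx = np.searchsorted(np.asarray(edges, dtype=float), times, side="right") - 1
    idx = np.clip(idx, 0, n_bins - 1)  # clip out-of-range times to the end bins
    out = np.zeros((len(times), n_bins))
    out[np.arange(len(times)), idx] = 1.0
    return out
```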

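Population Survival Estimation Based on Nelson-Aalen Estimator The NA estimator gives the cumulative hazard as
$$\hat\Lambda(t) = \sum_{j:\, t_j \le t} \frac{d_j}{n_j},$$
where $d_j$ is the number of events at time $t_j$ and $n_j$ is the number of subjects at risk just before $t_j$. The population survival function then follows as $\hat S(t) = \exp\{-\hat\Lambda(t)\}$. Let the selected discretized time points be $\tau_0 < \tau_1 < \cdots < \tau_K$.

The population probability mass of each bin $k = 1, \dots, K$ is then $\hat p_k = \hat S(\tau_{k-1}) - \hat S(\tau_k)$. Suppose the censoring time of censored subject $i$ falls in bin $l$. Its soft encoding is
$$\tilde y_{ik} = \frac{\hat p_k}{\sum_{m=l}^{K} \hat p_m}\,\mathbb{1}\{k \ge l\}, \qquad k = 1, \dots, K.$$

The constructed $\tilde y_i$ therefore also satisfies $\sum_k \tilde y_{ik} = 1$. This gives the soft one-hot encoding for the original censored observations.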
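The Nelson-Aalen-based soft encoding for censored subjects can be sketched as below. This is an assumption-laden illustration, not the authors' code. It assumes that bins at and after the censoring bin receive population mass, and that this mass is renormalized to sum to 1. The uniform fallback, used when the population curve is flat over the tail, is a hypothetical choice to avoid division by zero. The function names are illustrative.

```python
import numpy as np


def nelson_aalen_survival(times, events, grid):
    """Population survival S(t) = exp(-Lambda(t)), Lambda(t) = sum_{t_j <= t} d_j / n_j."""
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    event_times = np.unique(times[events])
    d = np.array([np.sum((times == u) & events) for u in event_times])  # events at t_j
    n_at_risk = np.array([np.sum(times >= u) for u in event_times])     # risk-set size
    increments = d / n_at_risk
    cum_hazard = np.array([increments[event_times <= g].sum() for g in np.asarray(grid, dtype=float)])
    return np.exp(-cum_hazard)


def soft_censored_encoding(cens_times, edges, pop_surv_at_edges):
    """Soft one-hot for censored subjects: population bin mass from the censoring bin on, renormalized."""
    edges = np.asarray(edges, dtype=float)
    surv = np.asarray(pop_surv_at_edges, dtype=float)
    mass = surv[:-1] - surv[1:]  # p_k = S(tau_{k-1}) - S(tau_k)
    n_bins = len(edges) - 1
    idx = np.searchsorted(edges, np.asarray(cens_times, dtype=float), side="right") - 1
    idx = np.clip(idx, 0, n_bins - 1)
    out = np.zeros((len(idx), n_bins))
    for r, l in enumerate(idx):
        tail = mass[l:]
        total = tail.sum()
        # Hypothetical fallback: spread uniformly if the population curve is flat over the tail.
        out[r, l:] = tail / total if total > 0 else 1.0 / len(tail)
    return out
```

Usage: evaluate `nelson_aalen_survival` at the bin edges on the training data. Then pass the result to `soft_censored_encoding` for each censored subject's time.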

Architecture for VSI The input of VSI is . The parameters for and were obtained by two separate MLPs, each with hidden dimensions [32, 32] and output dimension [32]. The decoding arm has hidden dimensions [32, 32, 32] and outputs a vector of predicted logits. The activation function is leaky ReLU. Continuous variables are z-transformed, using statistics from the training datasets, before entering the model.
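A minimal NumPy forward-pass sketch of the layer sizes described above follows. Several details are hypothetical choices for illustration, since the source does not specify them:

- Each 32-dimensional encoder output is split into a 16-dimensional mean and log-variance.
- One encoder sees only covariates, and the other also sees the encoded time.
- The leaky-ReLU slope is 0.01, and initialization is He-style.
- The function names are illustrative.

No training loop is shown.

```python
import numpy as np


def leaky_relu(x, slope=0.01):
    return np.where(x > 0, x, slope * x)


def init_mlp(sizes, rng):
    """He-style random init for a stack of dense layers with the given widths."""
    return [(rng.normal(0.0, np.sqrt(2.0 / m), size=(m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]


def mlp_forward(params, x):
    """Dense layers with leaky ReLU in between; the final layer is linear."""
    for i, (W, b) in enumerate(params):
        x = x @ W + b
        if i < len(params) - 1:
            x = leaky_relu(x)
    return x


def softmax(logits):
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def zscore_fit(X_train):
    """Column means/stds from the training split; zero stds replaced by 1."""
    mean, std = X_train.mean(axis=0), X_train.std(axis=0)
    return mean, np.where(std > 0, std, 1.0)


def build_vsi_sketch(n_cov, n_bins, rng, latent_dim=16):
    # Hypothetical: each encoder outputs 32 = 2 * latent_dim values (mean, log-variance).
    return {
        "prior": init_mlp([n_cov, 32, 32, 2 * latent_dim], rng),               # x only
        "posterior": init_mlp([n_cov + n_bins, 32, 32, 2 * latent_dim], rng),  # x and encoded t
        "decoder": init_mlp([latent_dim, 32, 32, 32, n_bins], rng),            # z -> bin logits
    }


def posterior_pass(params, x, t_enc, rng):
    """Encode (x, t), sample z by reparameterization, decode to bin probabilities."""
    h = mlp_forward(params["posterior"], np.concatenate([x, t_enc], axis=1))
    mu, logvar = np.split(h, 2, axis=1)
    z = mu + np.exp(0.5 * logvar) * rng.standard_normal(mu.shape)
    return softmax(mlp_forward(params["decoder"], z)), mu, logvar
```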

B DERIVATION OF LIKELIHOOD LOWERBOUND FOR CENSORED OBSERVATIONS

To derive the likelihood lower bound for censored observations, we apply Fubini's theorem and Jensen's inequality.
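The original equations did not survive extraction. The following sketch shows how the two tools combine, under assumed notation:

- $p_\theta(z \mid x)$ is a conditional prior and $p_\theta(t \mid z)$ the decoder.
- $S_\theta(c \mid z) = P(T > c \mid z)$ is the survival function given $z$.
- $q_\phi$ is an approximate posterior, here assumed to condition on $x$ and the censoring time $c$.

Fubini's (Tonelli's) theorem justifies swapping the order of integration because the integrand is non-negative. Jensen's inequality yields the bound because $\log$ is concave. For discretized time, the integral over $t$ becomes a sum over bins.

```latex
\begin{aligned}
\log P(T > c \mid x)
  &= \log \int_{c}^{\infty} \int p_\theta(t \mid z)\, p_\theta(z \mid x)\, dz\, dt \\
  &= \log \int \Big( \int_{c}^{\infty} p_\theta(t \mid z)\, dt \Big)\, p_\theta(z \mid x)\, dz
     && \text{(Fubini)} \\
  &= \log \int S_\theta(c \mid z)\, \frac{p_\theta(z \mid x)}{q_\phi(z \mid x, c)}\, q_\phi(z \mid x, c)\, dz \\
  &\ge \mathbb{E}_{q_\phi(z \mid x, c)}\big[\log S_\theta(c \mid z)\big]
     - \mathrm{KL}\big(q_\phi(z \mid x, c) \,\|\, p_\theta(z \mid x)\big)
     && \text{(Jensen)}
\end{aligned}
```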

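Likelihood for Multivariate Normal Distribution The multivariate normal log-likelihood function used for the ELBO has the form
$$\log \mathcal{N}(z; \mu, \Sigma) = -\frac{d}{2}\log(2\pi) - \frac{1}{2}\log\lvert\Sigma\rvert - \frac{1}{2}(z-\mu)^\top \Sigma^{-1} (z-\mu),$$
where $d$ is the dimension of $z$.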
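With a diagonal covariance $\Sigma = \mathrm{diag}(\sigma^2)$, a common choice in VAE-style models, the multivariate normal log-density reduces to a sum over dimensions. The KL term between two such Gaussians also has a closed form. The sketch below implements both in NumPy. The diagonal parameterization via log-variances and the function names are our assumptions; the source does not state the covariance structure.

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)


def diag_gaussian_logpdf(z, mu, logvar):
    """log N(z; mu, diag(exp(logvar))), summed over the last axis."""
    z, mu, logvar = (np.asarray(a, dtype=float) for a in (z, mu, logvar))
    return -0.5 * np.sum(LOG_2PI + logvar + (z - mu) ** 2 / np.exp(logvar), axis=-1)


def diag_gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    """KL(N(mu_q, diag var_q) || N(mu_p, diag var_p)), summed over the last axis."""
    mu_q, logvar_q, mu_p, logvar_p = (np.asarray(a, dtype=float)
                                      for a in (mu_q, logvar_q, mu_p, logvar_p))
    return 0.5 * np.sum(logvar_p - logvar_q
                        + (np.exp(logvar_q) + (mu_q - mu_p) ** 2) / np.exp(logvar_p)
                        - 1.0, axis=-1)
```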

C LIKELIHOOD CALCULATION FOR STATISTICAL MODELS

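AFT Models The distributions of $T$ and $\log T$ correspond to each other, with
$$\log T = x^\top \beta + \sigma W,$$
where $W$ is an error term with a fixed standard distribution.

The likelihood function under general non-informative censoring has the form
$$L = \prod_{i=1}^{n} f(t_i \mid x_i)^{\delta_i}\, S(t_i \mid x_i)^{1-\delta_i},$$
where $\delta_i = 1$ for an observed event and $\delta_i = 0$ for a censored observation.

For the AFT model with the Weibull distribution, $W$ follows the extreme value distribution and $T$ follows the Weibull distribution. The Weibull distribution has scale parameter $\lambda > 0$ and shape parameter $k > 0$, so that $S(t) = \exp\{-(t/\lambda)^k\}$.

Thus, under the extreme value error, $\lambda = \exp(x^\top \beta)$ and $k = 1/\sigma$. Equivalently, $S(t \mid x) = \exp\{-\exp((\log t - x^\top \beta)/\sigma)\}$. The log-likelihood is therefore:
$$\ell(\beta, \sigma) = \sum_{i=1}^{n} \big[\delta_i \log f(t_i \mid x_i) + (1-\delta_i) \log S(t_i \mid x_i)\big],$$
with $f(t \mid x) = \frac{k}{\lambda}\left(\frac{t}{\lambda}\right)^{k-1} \exp\{-(t/\lambda)^k\}$.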

Note that, to compare log-likelihoods on the same scale, the log-likelihoods of AFT and the other parametric methods were calculated at the same discretized event times as VSI.
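This appendix states that AFT log-likelihoods were evaluated on the same discretized event times as VSI. The following sketch shows one way to do this for a Weibull AFT model. It rests on assumptions not taken from the source:

- Bins are given by a hypothetical `edges` array whose last entry is `np.inf`.
- An event contributes the Weibull probability mass of its bin.
- A censored observation contributes the survival probability at the start of its bin.

The Weibull parameters use the standard AFT correspondence: scale $\exp(x^\top\beta)$ and shape $1/\sigma$.

```python
import numpy as np


def weibull_survival(t, scale, shape):
    """Weibull survival S(t) = exp(-(t / scale) ** shape); S(inf) = 0."""
    return np.exp(-(np.asarray(t, dtype=float) / scale) ** shape)


def aft_weibull_params(X, beta, sigma):
    """log T = x^T beta + sigma * W with W ~ standard (min) extreme value
    => T ~ Weibull(scale = exp(x^T beta), shape = 1 / sigma)."""
    return np.exp(np.asarray(X, dtype=float) @ np.asarray(beta, dtype=float)), 1.0 / sigma


def discretized_weibull_loglik(times, events, edges, scale, shape, eps=1e-12):
    """Per-subject log-likelihood on bins [edges[k], edges[k+1]); edges[-1] assumed np.inf."""
    edges = np.asarray(edges, dtype=float)
    n_bins = len(edges) - 1
    k = np.clip(np.searchsorted(edges, np.asarray(times, dtype=float), side="right") - 1,
                0, n_bins - 1)
    s_lo = weibull_survival(edges[k], scale, shape)      # S at the start of the bin
    s_hi = weibull_survival(edges[k + 1], scale, shape)  # S at the end of the bin
    ll_event = np.log(s_lo - s_hi + eps)                 # mass of the event's bin
    ll_censored = np.log(s_lo + eps)                     # mass from the censoring bin onward
    return np.where(np.asarray(events, dtype=bool), ll_event, ll_censored)
```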

Then, based on the survival likelihood function, we obtain the likelihood for RSF.

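The Cox model is semiparametric: its baseline hazard is left unspecified. Therefore the partial likelihood is usually used instead of the full likelihood:
$$L(\beta) = \prod_{i:\, \delta_i = 1} \frac{\exp(x_i^\top \beta)}{\sum_{j \in R(t_i)} \exp(x_j^\top \beta)},$$
where $R(t_i)$ is the set of subjects at risk at time $t_i$. The partial likelihood above multiplies, over all subjects with an event, one conditional probability per subject. It is the probability that the subject with covariates $x_i$ is the one who dies at $t_i$, given that one subject in $R(t_i)$ dies at that time. $\hat\beta$ is obtained by maximizing $\log L(\beta)$.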
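The Cox partial log-likelihood above can be evaluated directly, as in the sketch below. It uses the simple Breslow convention for tied event times: every subject with $t_j \ge t_i$ stays in the risk set $R(t_i)$. The function name is hypothetical, and no optimizer for $\hat\beta$ is included.

```python
import numpy as np


def cox_partial_loglik(beta, X, times, events):
    """log L(beta) = sum over events of [x_i^T beta - log sum_{j in R(t_i)} exp(x_j^T beta)]."""
    X = np.asarray(X, dtype=float)
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    eta = X @ np.asarray(beta, dtype=float)  # linear predictors x_i^T beta
    ll = 0.0
    for i in np.flatnonzero(events):
        at_risk = times >= times[i]          # risk set R(t_i)
        ll += eta[i] - np.log(np.sum(np.exp(eta[at_risk])))
    return float(ll)
```

With $\beta = 0$, each event contributes $-\log\lvert R(t_i)\rvert$. This gives a quick sanity check.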

To compare the different methods, we calculate the full likelihood for Cox models empirically.

In the above equations, we use the estimated and predicted
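The empirical method used to obtain a full Cox likelihood is not stated in this section. One common choice, shown here only as an assumption, is Breslow's estimator of the baseline cumulative hazard. It gives $S(t \mid x) = \exp\{-\hat H_0(t)\, e^{x^\top\hat\beta}\}$, which can then be evaluated at bin edges like any other survival curve. The function names are hypothetical.

```python
import numpy as np


def breslow_cumulative_hazard(beta, X, times, events, grid):
    """Breslow estimator: H0(t) = sum_{t_j <= t} d_j / sum_{l in R(t_j)} exp(x_l^T beta)."""
    X = np.asarray(X, dtype=float)
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    risk = np.exp(X @ np.asarray(beta, dtype=float))
    event_times = np.unique(times[events])
    increments = np.array([np.sum((times == u) & events) / np.sum(risk[times >= u])
                           for u in event_times])
    return np.array([increments[event_times <= g].sum() for g in np.asarray(grid, dtype=float)])


def cox_survival(beta, X_new, baseline_cum_hazard):
    """S(t | x) = exp(-H0(t) * exp(x^T beta)); one row per subject, one column per grid time."""
    lin = np.asarray(X_new, dtype=float) @ np.asarray(beta, dtype=float)
    return np.exp(-np.outer(np.exp(lin), np.asarray(baseline_cum_hazard, dtype=float)))
```

With $\beta = 0$ the Breslow estimator reduces to the Nelson-Aalen estimator, which makes a convenient check.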

D ADDITIONAL RESULTS FOR SIMULATION DATASETS

Distribution of log-likelihood The log-likelihood distributions for the simulation datasets are shown in Figure S1. In all scenarios, VSI has the most concentrated log-likelihood distribution, which suggests that VSI is robust across event rates under CoxPH assumptions. MLP and VSI-NoQ performed equally well in this simple synthetic study. The percentile ranges of the log-likelihood in the simulation studies are shown in Table S1. When event rates are low, VSI stands out from the baseline methods.

Figure S1: Test log-likelihood distributions for the simulation datasets: 100% event rate (a), 50% event rate (observed: b, censored: c), 30% event rate (observed: d, censored: e)

Table S1: Quantile range [0.10, 0.90] of log-likelihood in Simulation Study

Confidence Intervals for Raw C-Index To better capture the differences in raw C-Index, we calculate confidence intervals for each simulation setting (see Table S2). VSI performs comparably to the oracle methods on this metric. More informative metrics are discussed in the main text.

Table S2: Raw C-Index with confidence intervals (in parentheses) for simulation studies. RSF and DeepSurv do not provide intrinsic methods to calculate confidence intervals.

Coverage Rate To visualize the proportion of observed times covered by the predicted personalized time-to-event distributions, we calculated the coverage rate for each percentile range. Among all methods, our model achieves the best balance between event coverage and censoring coverage. We compared coverage with CoxPH, AFT-Weibull, VSI-NoQ, MLP and RSF. CoxPH serves as the reference, since the simulation study is based entirely on CoxPH assumptions. The results are shown in Figure S2.
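The source does not give the exact coverage definition, so the following is a sketch under stated assumptions.

- For an observed event, the event counts as covered at level $p$ if its bin lies inside the central predicted interval between the $(1-p)/2$ and $(1+p)/2$ quantiles.
- For a censored subject, we assume it counts as covered if its censoring bin does not lie beyond the interval's upper end, so the interval can still contain the unobserved event time.

Both rules and the function names are hypothetical.

```python
import numpy as np


def quantile_bin(probs, q):
    """First bin where each subject's predicted CDF reaches q."""
    cdf = np.cumsum(np.asarray(probs, dtype=float), axis=1)
    return np.argmax(cdf >= q - 1e-12, axis=1)


def coverage_rates(probs, time_bins, events, levels):
    """For each level p, return (event coverage, censored coverage) under the assumed rules."""
    time_bins = np.asarray(time_bins)
    events = np.asarray(events, dtype=bool)
    results = []
    for p in levels:
        lo = quantile_bin(probs, (1.0 - p) / 2.0)
        hi = quantile_bin(probs, (1.0 + p) / 2.0)
        inside = (time_bins >= lo) & (time_bins <= hi)  # event inside central interval
        below_hi = time_bins <= hi                       # censoring not past upper end
        results.append((float(inside[events].mean()), float(below_hi[~events].mean())))
    return results
```

Sweeping `levels` over, for example, 0.1 to 0.9 produces a curve like those in the coverage figures, under the assumptions above.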

E ADDITIONAL RESULTS FOR REAL DATASETS

Comparison of Weighted Average and Median In Table S3, we compare the C-Index calculated from the weighted average and from the median of the predicted distribution on the real datasets. The weighted average generally yields a higher C-Index. In the simulation studies, the two values are similar.
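The two point summaries compared in Table S3 can be computed from a discrete predicted distribution as sketched below. The sketch assumes a representative time per bin supplied by the caller. It takes the median as the first bin where the predicted CDF reaches 0.5. The function names are illustrative.

```python
import numpy as np


def weighted_average_time(probs, bin_times):
    """Mean of each discrete predicted distribution (the weighted average)."""
    return np.asarray(probs, dtype=float) @ np.asarray(bin_times, dtype=float)


def median_time(probs, bin_times):
    """Representative time of the first bin where the predicted CDF reaches 0.5."""
    cdf = np.cumsum(np.asarray(probs, dtype=float), axis=1)
    return np.asarray(bin_times, dtype=float)[np.argmax(cdf >= 0.5, axis=1)]
```

For a bimodal prediction such as masses (0.6, 0, 0.4) on bins with times (1, 2, 3), the weighted average is 1.8 while the median is 1. So the two summaries can rank subjects differently.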

Figure S2: Test coverage rates for the simulation datasets: 100% event rate (a), 50% event rate (observed: b, censored: c), 30% event rate (observed: d, censored: e)

Table S3: Comparison of the median and weighted average of the predicted time-to-event distribution for the C-Index in real datasets

Distribution for Log-likelihood The log-likelihood distributions for the real datasets are shown in Figure S3. As in the case discussed in the main text, VSI performs consistently better on all three real-world datasets.

Confidence Intervals for Raw C-Index As shown in Table S4, the raw C-Index of VSI is significantly better than that of the other methods, with relatively tight confidence intervals.

Coverage rate for real datasets Coverage rates for the real datasets are shown in Figure S4. VSI has relatively high coverage for both events and censoring in all three datasets, similar to the simulation datasets.

Among the comparison models, VSI, MLP, VSI-NoQ and AFT-Weibull are distribution-based and can provide all statistics used in the simulation studies. RSF and CoxPH provide estimated cumulative hazards, from which the corresponding survival function and time-to-event can be calculated. For DeepSurv and CoxNet, only the raw C-Index can be given.

Table S4: Raw C-Index with confidence intervals (in parentheses) for real data. NA indicates that the corresponding evaluation metric cannot be evaluated. RSF and DeepSurv do not provide intrinsic methods to calculate confidence intervals.

Figure S3: Likelihood distributions for real datasets. FLCHAIN (observed: a, censored: b), SUPPORT (observed: c, censored: d), SEER (observed: e, censored: f)

Figure S4: Coverage rates for real datasets. FLCHAIN (observed: a, censored: b), SUPPORT (observed: c, censored: d), SEER (observed: e, censored: f)