Time-to-event models have primarily focused on estimating either a (point-estimate) risk score or an individualized time-to-event distribution. Parametric models estimate the time-to-event distribution conditional on covariates by assuming a parametric form for the event distribution, e.g., exponential, Weibull, or log-normal. Parametric models fall under the Accelerated Failure Time (AFT) [50] framework, provided they assume that covariates either accelerate or decelerate the time-to-event. Assuming a parametric distribution is inflexible because the hazard function depends on the selected baseline distribution; for example, assuming an exponential distribution yields a constant hazard rate function. Alternatively, Cox Proportional Hazards (CoxPH) [14], a semi-parametric, linear model for estimating relative risks, is widely used in practice, as it does not require one to specify the baseline distribution. For pre-specified time horizons, the non-parametric Random Survival Forest (RSF) [27] was proposed to estimate the cumulative hazard function based on an ensemble of binary decision trees, albeit often limited by scaling problems for large and high-dimensional datasets.
With recent advances in machine learning, deep learning methods have improved on classical survival analysis methods by leveraging non-linear relationships between covariates for improved time-to-event or risk score predictions. Deep learning methods inspired by CoxPH or AFT have been proposed, e.g., DeepSurv [30], Deep Survival Analysis (DSA) [44], Deep Regularized Accelerated Failure Time (DRAFT) [11], Gaussian-process-based models [4, 19], and the Survival Continuous Ranked Probability Score (S-CRPS) [5]. Sampling-based nonparametric methods have been proposed as well, e.g., normalizing-flow-based DSA [39], adversarial-learning-based Deep Adversarial Time to Event (DATE) [11] and Survival Function Matching (SFM) [12]. Another class of nonparametric methods discretizes time-to-event to predict survival probability within pre-specified discrete time intervals using logistic-regression-based methods [20, 35, 53]. Further, deep learning methods have also successfully addressed calibration [5, 12, 36] and competing risks [4, 54].
Clustering based on risk profiles in survival analysis is relatively under-explored in machine learning, but is critical in applications such as (clinical) decision making. Identifying phenotypically heterogeneous subpopulations in the context of risk prediction is an important step toward machine-learning-based models for precision medicine [13, 16]. Existing clustering methods for stratifying risks in survival analysis include feature-based K-means (see Figure 1(a)) or hierarchical clustering [2, 17, 48]. Principal component cluster analysis has also been considered [3]. However, it is well understood that feature-based clustering in covariate space may produce clusters that are not consistent with survival outcomes [6, 21], particularly for high-dimensional datasets, such as gene expression data.
Methods that account for survival outcomes in clustering include CoxPH-inspired techniques [6, 21], implemented as a two-step process: first, covariates with high CoxPH scores are selected, then a classical clustering approach like K-means is applied (see Figure 1(b)). However, CoxPH-based approaches are limited by the proportional hazards assumption. Alternatively, Xia et al. [51] proposed an outcome-driven, attention-based multi-task deep learning model for classification and then applied K-means on the latent representations to cluster subjects with acute coronary syndrome. More recently, Mouli et al. [40] introduced DeepCLife, a method that learns clusters by maximizing the pairwise differences between the survival functions of all cluster pairs. This is done by indirectly maximizing the logrank score [38]. Unlike DeepCLife, which aims to optimize clusters but not predictions, our goal is to jointly characterize time-to-event predictive distributions from a clustered latent space conditioned on covariates (see Figure 1(c)).
We propose a model for time-to-event predictions equipped with a structured latent representation that allows for clustering via a prior for an infinite mixture of distributions. We circumvent the challenges associated with infinite mixtures in stochastic learning by leveraging a truncated Dirichlet process (DP) with a stick-breaking representation. The proposed model, termed Survival Cluster Analysis (SCA), is specified as: i) a deterministic encoder that maps covariates into a latent representation; ii) a stochastic time-to-event predictor that feeds from the latent representation; and iii) a distribution matching objective that encourages latent representations to behave as a mixture of distributions following a DP structure. This approach allows identification and analysis of phenotypically heterogeneous subpopulations. Our experiments demonstrate that SCA yields consistent improvements in predictive performance and cluster quality relative to existing methods.
Figure 1: Cluster-specific Kaplan-Meier survival profiles for three clustering methods on the sleep dataset (see Section 4 for details). (a) Standard K-means. (b) CoxPH-based covariate selection followed by K-means. (c) Proposed approach for joint learning of individualized time-to-event predictions and clustering. By jointly learning clustering with respect to both the covariates x and predicted time-to-event t, our model (SCA) can identify high-, medium- and low-risk individuals, demonstrating the need to account for time information via a non-linear transformation of covariates when clustering survival datasets.
In a conventional time-to-event (survival analysis) setup, we are given N observations. Individual observations are described by triplets $\mathcal{D} = \{(x_i, t_i, l_i)\}_{i=1}^{N}$, where $x_i \in \mathbb{R}^d$ is a d-dimensional vector of covariates, $t_i$ is the time-to-event and $l_i \in \{0, 1\}$ is the censoring indicator. When $l_i = 0$ (censored), the subject has not experienced an event up to time $t_i$, while $l_i = 1$ indicates observed (ground truth) event times.
Time-to-event models are conditional on covariates and are characterized by the event time density function f(t|x), the hazard rate (risk score) function h(t|x), or the survival function $S(t|x) = P(T > t \mid x) = 1 - F(t|x)$, also known as the probability of failure occurring after time t, where F(t|x) is the cumulative distribution function. From standard survival function definitions [33], the relationship between these three characterizations is formulated as f(t|x) = h(t|x)S(t|x).
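As a quick numerical check of the identity f(t|x) = h(t|x)S(t|x), the following sketch uses a Weibull distribution without covariates. The function names and parameter values are illustrative assumptions and are not part of the formulation in this paper. The exponential special case (shape equal to one) also shows the constant hazard mentioned earlier.

```python
import math


def weibull_survival(t, shape, scale):
    """S(t) = exp(-(t / scale) ** shape)."""
    return math.exp(-((t / scale) ** shape))


def weibull_hazard(t, shape, scale):
    """h(t) = (shape / scale) * (t / scale) ** (shape - 1)."""
    return (shape / scale) * (t / scale) ** (shape - 1)


def weibull_density(t, shape, scale):
    """f(t) computed through the identity f(t) = h(t) * S(t)."""
    return weibull_hazard(t, shape, scale) * weibull_survival(t, shape, scale)


def numerical_density(t, shape, scale, eps=1e-6):
    """f(t) = -dS/dt, approximated with a central finite difference."""
    upper = weibull_survival(t + eps, shape, scale)
    lower = weibull_survival(t - eps, shape, scale)
    return -(upper - lower) / (2.0 * eps)
```

Comparing `weibull_density` with `numerical_density` at any positive t confirms that the product of hazard and survival recovers the density obtained by differentiating the survival function.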
In practice, modern (often large) datasets are not homogeneous but composed of phenotypically heterogeneous subpopulations, i.e., subsets of observations that cluster according to both covariates and time-to-event similarities. In a clinical setting, for instance, identification of high-, medium- and low-risk subpopulations that are equipped with accurate estimates of time-to-event has the potential to result in a more cost-effective way of targeting interventions, treatments or care delivery. We formulate an approach to jointly learn individualized time-to-event distributions and clusters informed by time-to-event profiles.
The Bayesian nonparametric approach formulated below encourages latent representations to behave as a mixture of distributions, following a Dirichlet Process (DP) structure via a distribution matching approach. Further, we learn to cluster the latent space in a stochastic manner, without knowing the number of clusters in advance. To demonstrate the efficacy of our clustering algorithm, we also present a time-to-event prediction formulation, leveraging current state-of-the-art time-to-event prediction models. See the Supplementary Material for the list of variable definitions used in our formulation.
3.1 Clustering with Dirichlet Process
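A DP is formally defined as $G \sim \mathrm{DP}(\alpha, G_0)$ and parametrized by the base probability measure $G_0$ and concentration parameter $\alpha > 0$ [18]. With probability one [47]:
$$G = \sum_{k=1}^{\infty} \pi_k \delta_{\theta_k}, \qquad \theta_k \sim G_0, \qquad (1)$$
where $\delta_{\theta_k}$ represents a probability measure concentrated at $\theta_k$ and $\pi = \{\pi_k\}_{k=1}^{\infty}$ are weights with statistics that depend on parameter $\alpha$. The sequence $\pi$ satisfies $\sum_{k=1}^{\infty} \pi_k = 1$, with probability one, such that $\pi \sim \mathrm{GEM}(\alpha)$ [42]. Further, note that $\pi_k$ represents the likelihood that a draw from $G$ takes the value $\theta_k$.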
Discrete distribution G is suitable as a prior for mixture components in infinite mixture models [45]. Further, the stick-breaking process [47] that generates $\pi$ results in a mechanism that allows one to learn the number of mixture components (clusters) from data. In fact, the number of distinct atoms has been shown to grow with the size of the data as O(log N) [1]. So motivated, we assume that data embedded in a latent space are distributed according to a mixture of distributions with parameters specified by the atoms $\{\theta_k\}$, as described below.
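To make the stick-breaking mechanism concrete, the sketch below draws mixture weights by repeatedly breaking off a Beta(1, alpha) fraction of the remaining stick, which is the standard stick-breaking construction of a DP. The function name, the stopping tolerance `tol` and the `seed` argument are illustrative assumptions, not part of the proposed model.

```python
import random


def sample_stick_breaking(alpha, tol=1e-6, seed=0):
    """Draw weights pi_k = v_k * prod_{j<k} (1 - v_j) with v_k ~ Beta(1, alpha).

    Sampling stops once the unassigned stick length falls below `tol`;
    the leftover mass is added to the last weight so the weights sum to one.
    """
    rng = random.Random(seed)
    weights, remaining = [], 1.0
    while remaining > tol:
        v = rng.betavariate(1.0, alpha)
        weights.append(v * remaining)  # mass assigned to the next component
        remaining *= 1.0 - v           # stick length still unassigned
    weights[-1] += remaining
    return weights
```

Larger values of alpha tend to break off smaller pieces, so the mass is typically spread over more components before the stick is exhausted.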
Assuming exchangeable latent representations $\{z_n\}_{n=1}^{N}$, we propose generating event times following the generative process below
$$t_n = g_\phi(z_n, \epsilon), \quad \epsilon \sim p(\epsilon), \quad z_n \sim p(z_n \mid \theta_{c_n}), \quad c_n \sim p(c), \qquad (2)$$
$$p(z_n \mid \theta_k) \propto \left(1 + \|z_n - \theta_k\|^2 / \nu\right)^{-(\nu + 1)/2}, \qquad (3)$$
where $g_\phi(\cdot)$ is a function that implicitly represents the conditional time-to-event density, f(t|x), specified as a neural network with parameters $\phi$. The source of stochasticity, $\epsilon$, is set to a simple distribution $p(\epsilon)$, e.g., uniform or Gaussian. The latent representation for the n-th observation, $z_n$, is distributed according to $p(z_n \mid \theta_{c_n})$, where $c_n$ is the mixture component membership indicator for $z_n$. Lastly, together with (3), p(c) in (2) represents an infinite mixture of Student's t-distributions with $\nu$ degrees of freedom and means $\theta_k$, each of which is drawn independently from the base probability measure $G_0$.
Figure 2: Illustration of Survival Cluster Analysis (SCA). The latent space has a mixture-of-distributions structure, illustrated as three mixture components. Observation x is mapped into its latent representation z via a deterministic encoding; z belongs to one of the mixture components and is then used to stochastically predict (via sampling) the time-to-event.
The Student's t distribution in (3) is a general yet parametrically simple distribution, robust to outliers and amenable to efficient computations and gradient estimates. It has been widely used in machine learning for mixture modeling [45], clustering [52] and visualization [37]. Further, we formulate the t distribution according to the normal-inverse-gamma likelihood, where marginalizing out the variance yields a Student's t distribution, see [8] for details. Interestingly, as special cases, when $\nu = 1$ the distribution in (3) is Cauchy, while for $\nu \to \infty$ it approaches a Gaussian distribution.
The generative process above further requires learning a mapping function from covariates to latent space, $z = r_\omega(x)$ with parameters $\omega$, that is globally consistent with the mixture model prior in (2) and (3), parameterized by $\theta = \{\theta_k\}$. In addition, we also need to learn the parameters $\phi$ of the time-to-event generating function $g_\phi(\cdot)$. This specification, illustrated in Figure 2, constitutes the proposed Survival Cluster Analysis (SCA).
Note that unlike existing unsupervised and supervised autoencoding approaches [28, 32, 49], we do not seek to model the covariates, x. Rather, we make time-to-event predictions based on a latent representation specified as a function of observed covariates, required to be consistent with a mixture of distributions prior. Consequently, we need not specify a decoding arm to reconstruct the covariates, x.
In practice, learning the mixture component assignments and a potentially infinite number of mixture components with minibatches (stochastically) is challenging, because the former constitutes a discrete random variable and the latter requires keeping track of the number of non-empty mixture components during learning. To circumvent this, we learn the mixture component assignments probabilistically as $q(c_n = k \mid z_n)$, and use a truncated representation of the DP formulation [9, 26], which for a large enough truncation number, denoted as K, is virtually indistinguishable from a standard DP [26].
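A minimal sketch of probabilistic (soft) cluster assignments under a truncated mixture with K components is given below. It assumes the standard mixture-model responsibilities, i.e., mixture proportions multiplied by an unnormalized, unit-scale Student's t kernel and then normalized over components. The function names are hypothetical and the exact form used by SCA may differ.

```python
def student_t_kernel(z, center, nu=1.0):
    """Unnormalized Student's t kernel (1 + ||z - center||^2 / nu)^(-(nu + 1) / 2)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(z, center))
    return (1.0 + sq_dist / nu) ** (-(nu + 1.0) / 2.0)


def soft_assignments(z, centers, proportions, nu=1.0):
    """Membership probabilities over the K components for one latent vector z."""
    scores = [
        p * student_t_kernel(z, c, nu)  # proportion-weighted kernel value
        for p, c in zip(proportions, centers)
    ]
    total = sum(scores)
    return [s / total for s in scores]
```

For example, `soft_assignments([0.1, 0.0], [[0.0, 0.0], [5.0, 5.0]], [0.5, 0.5])` places most of the probability on the first component, whose center is closest to z.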
3.2 Latent-Space Representation
Following the conventional maximum likelihood formulation for mixture models [8], we can approximate the distributions for the mixture assignments and mixture proportions on each minibatch of size M, as in (5), which introduces a variational parameter and replaces the latent variable in (3) with the encoding of covariates into latent space. However, (5) is not necessarily consistent with the DP in (2) and its stick-breaking prior, which from (1) should result in the stick-breaking form in (6); in practice, this is complicated by the need to sample the mixture proportion weights. In our implementation, instead of sampling these weights, we use their expectation. Alternatively, we could also use a reparameterizable distribution such as the Kumaraswamy distribution, which is closely related to the Beta distribution, as in [41]. However, we found that using expectations, which is common in variational formulations [9, 29], works well in practice.
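The use of expectations in place of samples can be sketched as follows, assuming the standard stick-breaking construction in which each stick fraction is Beta(1, alpha) distributed with mean 1 / (1 + alpha), and a truncation at K components in which the last component absorbs the remaining mass. The function name is hypothetical.

```python
def expected_stick_breaking(alpha, num_components):
    """Expected truncated stick-breaking weights for concentration alpha.

    Each stick fraction is replaced by its mean, E[Beta(1, alpha)] = 1 / (1 + alpha);
    the last of the `num_components` weights takes whatever mass is left.
    """
    v = 1.0 / (1.0 + alpha)
    weights, remaining = [], 1.0
    for _ in range(num_components - 1):
        weights.append(v * remaining)
        remaining *= 1.0 - v
    weights.append(remaining)  # truncation: final stick fraction set to one
    return weights
```

Because the stick fractions are independent, the product of their means equals the mean of each weight, so these values are the exact expected weights under the truncated prior. With alpha = 3 the first two weights are 0.25 and 0.1875, and the weights decay geometrically until the final, mass-absorbing component.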
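In order to make the approximation q in (5) and the prior p in (6) consistent, we want their distributions to match, i.e., we seek to learn the model parameters so that the approximation q matches the desired stick-breaking behavior of (6). For this purpose, we minimize
$$\ell_{\mathrm{KL}} = \mathrm{KL}(q \,\|\, p). \qquad (7)$$
The KL divergence between the two Dirichlet distributions q and p, with respect to their corresponding parameters $\gamma = (\gamma_1, \ldots, \gamma_K)$ and $\beta = (\beta_1, \ldots, \beta_K)$, has a desirable closed-form formulation defined as
$$\mathrm{KL}(q \,\|\, p) = \log \Gamma(\gamma_0) - \sum_{k=1}^{K} \log \Gamma(\gamma_k) - \log \Gamma(\beta_0) + \sum_{k=1}^{K} \log \Gamma(\beta_k) + \sum_{k=1}^{K} (\gamma_k - \beta_k)\left(\psi(\gamma_k) - \psi(\gamma_0)\right), \qquad (8)$$
where $\gamma_0 = \sum_{k} \gamma_k$, $\beta_0 = \sum_{k} \beta_k$, $\psi(\cdot)$ is the digamma function and $\Gamma(\cdot)$ is the Gamma function.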
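The closed-form KL divergence between two Dirichlet distributions can be evaluated with a few lines of code. The sketch below uses only the Python standard library; since the standard library has no digamma function, a simple recurrence plus asymptotic series is used as an approximation. The function names are hypothetical, and the direction KL(q || p) is an assumption of this example.

```python
import math


def digamma(x):
    """Approximate digamma for x > 0 via recurrence and an asymptotic series."""
    result = 0.0
    while x < 6.0:  # shift the argument up, using psi(x) = psi(x + 1) - 1 / x
        result -= 1.0 / x
        x += 1.0
    inv = 1.0 / x
    inv2 = inv * inv
    series = inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 / 252.0))
    return result + math.log(x) - 0.5 * inv - series


def dirichlet_kl(a, b):
    """KL(Dir(a) || Dir(b)) for positive concentration vectors of equal length."""
    a0, b0 = sum(a), sum(b)
    kl = math.lgamma(a0) - math.lgamma(b0)
    for ak, bk in zip(a, b):
        kl += math.lgamma(bk) - math.lgamma(ak)
        kl += (ak - bk) * (digamma(ak) - digamma(a0))
    return kl
```

As a sanity check, the divergence from Dir(1, 1) to Dir(2, 2) equals 2 - log 6, which can be verified by direct integration because Dir(1, 1) is uniform on the unit interval.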
This loss function is used during learning to update the encoder parameters and the mixture component parameters $\theta$. For $\pi$, the mixture proportions, we use a simple updating procedure akin to online expectation-maximization (EM) [10]. In particular, we update $\pi$ iteratively as
where 0 1 is the step size and we initialize
. In practice, we set
9; however,
can also be selected using grid search. The online approach in (9) is widely used to update global parameters in stochastic learning procedures. For instance, it has been used to learn the population mean and variance in batch normalization [25].
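The online update of the mixture proportions can be sketched as an exponential moving average, in the same spirit as the running statistics of batch normalization. The convex-combination form, the use of the minibatch mean of the soft assignments as the new estimate, and the function name below are assumptions of this example rather than the exact update in (9).

```python
def update_proportions(proportions, batch_assignments, step_size):
    """One online update: (1 - step) * old proportions + step * minibatch estimate.

    `batch_assignments` holds one list of K soft-assignment probabilities
    per observation in the minibatch.
    """
    batch_size = len(batch_assignments)
    updated = []
    for k, pi_k in enumerate(proportions):
        batch_mean = sum(q[k] for q in batch_assignments) / batch_size
        updated.append((1.0 - step_size) * pi_k + step_size * batch_mean)
    return updated
```

Since both the current proportions and the minibatch estimate sum to one, the updated proportions also sum to one for any step size between zero and one.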
3.3 Time-to-Event Distributions
In addition to the clustered, mixture representation of the latent space, we also seek a high-performing time-to-event model that yields concentrated, accurate and calibrated time-to-event predictions, while accounting for censored event times ($l = 0$). We borrow the accuracy objective from DATE [11] and the calibration objective from SFM [12]. Below we describe these objectives in the context of our formulation.
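Accuracy Objective. The dataset D is split into two disjoint sets $\mathcal{D}_c$ and $\mathcal{D}_{nc}$, where $p_c(t, x)$ and $p_{nc}(t, x)$ represent censored and non-censored empirical distributions for these sets, respectively. We leverage the accuracy objective from DATE [11] formulated as
$$\ell_{\mathrm{acc}} = \mathbb{E}_{(t, x) \sim p_{nc},\, \epsilon \sim p(\epsilon)}\left[\,|t - \hat{t}\,|\,\right] + \mathbb{E}_{(t, x) \sim p_{c},\, \epsilon \sim p(\epsilon)}\left[\max(0,\, t - \hat{t}\,)\right], \qquad (10)$$
where $\hat{t}$ is a time-to-event sample from the model for covariates x and $\epsilon$ has a simple distribution (uniform or Gaussian). $\ell_{\mathrm{acc}}$ encourages that time-to-event samples from the model, evaluated on censored observations, $l = 0$, are larger than the censoring time, while close to the ground truth for non-censored (observed) events, $l = 1$.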
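The behavior described for the accuracy objective can be illustrated with a small, framework-free sketch: an absolute error on observed events plus a hinge penalty that is active only when a predicted time falls before the censoring time. The function names and the equal weighting of the two terms are assumptions of this example; see DATE [11] for the actual objective.

```python
def _mean(values):
    """Average of a list, defined as zero for an empty list."""
    return sum(values) / len(values) if values else 0.0


def accuracy_loss(times, censoring, predictions):
    """Absolute error on observed events plus hinge penalty on censored ones.

    times: observed event or censoring times.
    censoring: 1 if the event was observed, 0 if censored.
    predictions: time-to-event samples drawn from the model.
    """
    observed_errors, censored_penalties = [], []
    for t, l, t_hat in zip(times, censoring, predictions):
        if l == 1:
            observed_errors.append(abs(t - t_hat))
        else:
            # no penalty when the prediction is beyond the censoring time
            censored_penalties.append(max(0.0, t - t_hat))
    return _mean(observed_errors) + _mean(censored_penalties)
```

For instance, a censored subject with censoring time 20 and predicted time 25 contributes nothing to the loss, whereas a prediction of 24 for a subject censored at 30 contributes a penalty of 6.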
Calibration Objective. We desire that samples generated from the model match the empirical marginal distribution p(t). We borrow the calibration objective from SFM [12], given in (11), which is defined over the set of distinct and ordered observed event times (censored and non-censored). The objective compares survival function estimates obtained with the estimator $\hat{S}_{\mathrm{PKM}}(\cdot)$ in (12), which is written in terms of the Heaviside step function $H(\cdot)$. When evaluating the objective in (11), $\hat{t}$ is either a sample from the model or an observed time, for the model-based and empirical estimates, respectively. Expression $\hat{S}_{\mathrm{PKM}}(\cdot)$ represents the point-estimate-based formulation of the Kaplan-Meier estimator, see [12] for details.
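For reference, the classical (product-limit) Kaplan-Meier estimator can be computed as in the sketch below. This is the textbook estimator, not the point-estimate-based variant of [12] used in the calibration objective, and the function name is hypothetical.

```python
def kaplan_meier(times, events):
    """Return [(t, S(t))] at each distinct time with at least one observed event.

    times: observed event or censoring times.
    events: 1 if the event was observed at that time, 0 if censored.
    """
    at_risk = len(times)
    survival, curve = 1.0, []
    for t in sorted(set(times)):
        deaths = sum(1 for ti, ei in zip(times, events) if ti == t and ei == 1)
        leaving = sum(1 for ti in times if ti == t)  # events and censorings at t
        if deaths > 0:
            survival *= 1.0 - deaths / at_risk
            curve.append((t, survival))
        at_risk -= leaving
    return curve
```

Censored subjects do not cause a drop in the curve, but they reduce the number at risk for all later event times.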
3.4 Learning
For joint learning of all model parameters, we optimize both the latent representation and time-to-event (accuracy and calibration) objectives. The complete objective function for the proposed Survival Cluster Analysis (SCA) model is
$$\ell = \ell_{\mathrm{KL}} + \lambda_1 \ell_{\mathrm{acc}} + \lambda_2 \ell_{\mathrm{cal}}, \qquad (13)$$
where $\{\lambda_1, \lambda_2\} > 0$ are hyper-parameters controlling the trade-off between accuracy and calibration objectives, relative to the clustering objective in (7). For simplicity and comparability with SFM, we set $\lambda_1 = \lambda_2 = 1$. The objective in (13) is optimized using stochastic gradient descent on minibatches from D. In practice, $\theta$ is updated according to stochastic gradient descent by optimizing the KL objective (7), and is initialized with K-means after pretraining (13) without the clustering objective. During inference, we assign a new observation, $x^*$, to a cluster by first evaluating $q(c = k \mid z^*)$ for k = 1, . . . , K, where $z^*$ is the latent encoding of $x^*$, then obtaining a hard assignment according to
$$c^* = \arg\max_{k} \; q(c = k \mid z^*). \qquad (14)$$
The maximum number of mixture components K is fixed during learning. However, given that the KL divergence (7) encourages mixture proportions to follow a stick-breaking process, the effective number of mixture components, i.e., those with non-empty observation assignments, will be smaller than K, thus effectively resulting in the model learning the number of mixture components. This is illustrated in Figure 3 and described below in the experiments. The number of degrees of freedom, $\nu$, is a hyperparameter, set to 1 in our experiments.
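The hard assignment rule and the notion of an effective number of (non-empty) mixture components can be sketched as follows. The function names are hypothetical, and the inputs are assumed to be lists of K soft-assignment probabilities, one list per observation.

```python
def hard_assignment(probabilities):
    """Index of the most probable mixture component for one observation."""
    return max(range(len(probabilities)), key=lambda k: probabilities[k])


def effective_num_clusters(all_probabilities):
    """Number of components that receive at least one hard assignment."""
    return len({hard_assignment(q) for q in all_probabilities})
```

With a truncation level of K = 3, three observations with assignment probabilities `[0.9, 0.05, 0.05]`, `[0.8, 0.1, 0.1]` and `[0.1, 0.1, 0.8]` occupy only two components, so the effective number of clusters is 2.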
The comparisons presented below are made across a diverse range of six datasets, as summarized in Table 1. Refer to the Supplementary Material for all details concerning the experimental setup. Throughout the experiments, we set K = 25 and select $\alpha \in \{2, 3, 4, 8\}$ via grid search cross-validation on the training sets. TensorFlow code to replicate experiments can be found at https://github.com/paidamoyo/survival_cluster_analysis.
Table 1: Summary statistics of the datasets used in the experiments. The time range is noted in days except for seer, for which time is measured in months.
Figure 3: Inferred clusters on the testing set of the sleep dataset, with K = 25 and $\alpha = 3$, where: (a) the corresponding individual cluster assignment probability distributions are approximated according to (5), (b) joint t-SNE plot of centroids with latent representations z and (c) density plot of the inferred number of clusters K during training.
Datasets. Table 1 shows the summary statistics of the six datasets considered. The datasets are diverse in number of observations N, number of categorical (cat) and continuous covariates d, proportion of non-censored events, missingness rate in the covariate matrix, and time horizon (measured in days, except for seer, which is measured in months). Following the conclusions of [39], which are based on the information-theoretic data processing inequality and show insignificant performance change relative to pre-imputation, we impute continuous and categorical covariates with the median and mode, respectively. In our experiments we do not convert time to a common scale but model it as is.
Publicly accessible datasets include: i) flchain: a study of the effects of nonclonal serum immunoglobulin free light chains on survival time [15]. ii) support: a study of the survival time of seriously ill hospitalized adults [34]. iii) seer: accessible from the Surveillance, Epidemiology, and End Results (SEER) Program. The dataset is preprocessed according to the details described in [46]. We restrict the dataset to a 10-year follow-up breast cancer subcohort.
The following datasets are available upon request: iv) ehr: a large study from the Duke University Health System centered around multiple inpatient visits due to comorbidities in patients with Type-2 diabetes [11]. v) sleep: a subset of the Sleep Heart Health Study (SHHS) [43], a multi-center cohort study implemented by the National Heart Lung & Blood Institute to determine the cardiovascular and other consequences of sleep-disordered breathing. We focus on the baseline clinical visit and aggregated demographics, medications and questionnaire data as covariates. vi) framingham: a subset (Framingham Offspring) of the longitudinal study of heart disease [7] dataset, initially for predicting 10-year risk for future coronary heart disease (CHD).
Clustering Baselines. We consider the standard K-means and the CoxPH-based SSC-Bair [6] as strong clustering baselines for SCA. We provide quantitative evaluations in terms of the logrank score [38], and qualitative visualizations of the clustering-based Kaplan-Meier sub-population survival curves.
Time-to-Event Baselines. We compare SCA to the following time-to-event baselines: SFM [12], DATE [11], S-CRPS [5], CoxPH [14], MTLR [53] and DRAFT [11]. Of these, SFM and DATE are key to our comparisons because we incorporate components of their formulations into SCA; namely, the accuracy loss from DATE and the distribution matching loss from SFM. In that sense, we expect SCA to perform as well as SFM and DATE, but with the added benefit of producing clusters with distinct risk profiles. We present quantitative evaluations in terms of C-index, calibration slope, Relative Absolute Error (RAE), and mean Coefficient of Variation (CoV). Details of these metrics are provided in the Supplementary Material.
4.1 Qualitative Results
Figure 3 shows, for the sleep dataset: a) estimated individualized cluster assignment probability distributions (rows) evaluated according to (5); b) t-SNE plots of the model-inferred centroids, as well as the individual latent representations z; and c) density plot of the inferred number of (non-empty) clusters K during training. See the Supplementary Material for similar figures for all the other datasets, where we also include corresponding Kaplan-Meier curves, as in Figure 1.
Interestingly, the cluster-specific covariate statistics for the Framingham dataset, which has the smallest number of covariates, are shown in Table 2 and are consistent with findings from the Framingham Heart Study [7], which identified high blood cholesterol and high blood pressure as major risk factors for cardiovascular disease.
Table 2: Inferred cluster-specific covariate information on the testing set for the framingham dataset. The inferred cluster assignments are according to the corresponding individual probability distribution, approximated according to (5). Ranges in parentheses are 50% empirical ranges over (median) test-set predictions for the continuous covariates and proportions for the categorical covariates.
Figure 4: Inferred cluster-specific Kaplan-Meier curves on the testing set of the Framingham dataset. The inferred cluster assignments are according to the corresponding individual probability distribution, approximated according to (5).
We obtain the cluster-specific Kaplan-Meier curves illustrated in Figure 4, with corresponding cluster-specific covariate information shown in Table 2. The inferred individual cluster assignment is obtained according to the individual probability distribution, approximated according to (5). We consider clusters with curves above the population average to be low-risk and those with curves below it to be high-risk.
Therefore, our model identifies three high-risk clusters, indexed by 2, 5 and 6. In detail: i) cluster 6 and cluster 5 have similar statistics, as they both consist disproportionately of diabetic individuals on hypertension medication with elevated total cholesterol (normal is below 200), high systolic blood pressure (normal is below 120), and noticeably low HDL (normal is greater than 60); ii) cluster 2 is also driven by age, which is expected; about 40% of its members are on hypertension medication, and it has the worst systolic blood pressure and cholesterol compared to other clusters; iii) lower-risk clusters 1, 3 and 4 are mostly comprised of females with normal levels of HDL, total cholesterol and systolic blood pressure; iv) cluster 0 represents the average statistics of the Framingham dataset, thus its survival curve closely follows the empirical population survival. Finally, note that the three high-risk clusters (2, 5 and 6) have a substantial over-representation of African Americans, known to have an increased risk for cardiovascular disease [7]. See the Supplementary Material for additional inferred cluster-specific Kaplan-Meier curves on all datasets.
We demonstrate that by jointly learning clustering with respect to both the covariates x and predicted time-to-event t, our model SCA can identify high-, medium- and low-risk individuals, which is essential for clinical decision making. During inference, the risk profile and the individualized time-to-event together provide a comprehensive prediction mechanism for identifying cluster-based risk factors, cluster-based risk profiles and individualized time predictions. Further, matching the empirical mixture distribution with a (truncated) DP yields sparse predictions of cluster assignment probabilities, manifested as the high-confidence cluster assignments shown as a heatmap in Figure 3(a).
Calibration Curves. We visually compare calibration curves from DATE, DRAFT, SCA, SFM, S-CRPS and CoxPH. Figure 5 shows the estimated population-based model survival functions according to [12] and the empirical Kaplan-Meier estimates for the Framingham and sleep datasets. Error bars (shaded area) are calculated according to Greenwood's formula [23]. For all datasets, SCA- and SFM-estimated population survival functions closely match the empirical ground truth survival function, which is consistent with the calibration slope results in Table 3. See the Supplementary Material for additional calibration and survival function results on all datasets.
Figure 5: Survival function estimates for (a) Framingham and (b) sleep data. Ground truth (Empirical) is compared to predictions from six models (DATE, DRAFT, SCA (our proposed model), SFM, S-CRPS and CoxPH). Error bars (shaded area) are calculated according to Greenwood's formula [23].
4.2 Quantitative Results
Below we describe performance metrics across all datasets and models. Specifically, calibration slope, mean CoV (coefficient of variation), C-index [24] and Relative Absolute Error [RAE, 53] provide a comprehensive evaluation, as they offer insights into consistency of time-to-event predictions, concentration of predicted distributions, pairwise ranking consistency, and accuracy of event time predictions, respectively. The results demonstrate that by jointly modeling the time-to-event and cluster assignments we obtain a better calibrated model that is competitive in C-index, concentration and accuracy of predictions. Table 3 shows the calibration slopes and RAE across all datasets and models. See the Supplementary Material for detailed RAE, mean CoV and C-index results.
Table 4 presents the clustering performance of the best performing K-means, SSC-Bair and SCA algorithms, measured in terms of the logrank score [38]. For SSC-Bair and K-means, we selected the best performing model from the set K = {2, 3, 4, 5, 6}.
Calibration Slope. For calibration we use the framework developed in SFM to evaluate the models [12]. An ideal calibration slope is 1, while slopes below and above 1 indicate that the model tends to underestimate or overestimate risk, respectively. The clustering objective in SCA augments the calibration objective we borrow from SFM, thus improving the calibration even for non-iid observations, such as Framingham and ehr, which are considered poorly calibrated, as illustrated in SFM. Given that SCA leverages the calibration objective of SFM, it is not surprising that both SCA and SFM are competitive, followed by DATE, S-CRPS, MTLR, CoxPH and lastly DRAFT. See the Supplementary Material for qualitative calibration plots.
Relative Absolute Error (RAE). We compute RAE for both censored and non-censored events. In Table 3 we present the RAE for non-censored event times ($l = 1$) for models that predict absolute event times, thus excluding scoring-based models (CoxPH and MTLR). The results demonstrate that the nonparametric methods (DATE, SFM and SCA) outperform the parametric methods (DRAFT and S-CRPS), which is expected since the former all use a similar accuracy-aware objective function. For censored events ($l = 0$), RAE provides a lower bound on the error, given that the censoring time provides tail information of p(t|x); here, parametric methods (DRAFT and S-CRPS) have a small advantage over nonparametric methods (SFM, DATE and SCA). See the Supplementary Material for additional results on censored event times.
Concordance Index (C-index). C-index is a ranking metric that does not account for uncertainty in time-to-event predictions. Therefore, to evaluate the time-to-event models (except CoxPH) in terms of C-index, we use point summaries of the individualized time-to-event distributions, computed from samples drawn from the trained model on the test set. Apart from the Framingham dataset, which has few covariates and a very low event rate, and the small, high-event-rate support dataset, none of the models has a clear advantage on the C-index metric. This is not surprising because C-index with a very low event rate is heavily influenced by the censored observations. Note that for MTLR, although we can compute the C-index at prespecified thresholds, we are unable to compute a global C-index.
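A minimal sketch of the C-index for point predictions of event times is given below. It follows the common Harrell-style definition, in which a pair is comparable only if the subject with the shorter time experienced the event; the function name and the half-credit convention for tied predictions are assumptions of this example, and the exact evaluation protocol is described in the Supplementary Material.

```python
def concordance_index(times, events, predicted_times):
    """C-index for predicted event times (larger prediction means longer survival).

    A pair (i, j) is comparable when subject i has an observed event and
    times[i] < times[j]; it is concordant when the predictions preserve
    that order. Tied predictions count as half-concordant.
    """
    concordant, comparable = 0.0, 0
    n = len(times)
    for i in range(n):
        if events[i] != 1:
            continue  # censored subjects cannot anchor a comparable pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if predicted_times[i] < predicted_times[j]:
                    concordant += 1.0
                elif predicted_times[i] == predicted_times[j]:
                    concordant += 0.5
    return concordant / comparable if comparable else float("nan")
```

Because only pairs anchored by an observed event are counted, datasets with very low event rates leave few comparable pairs that involve two observed events, which is one way to see why censored observations can dominate the metric.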
Coefficient of Variation (CoV). Models that characterize the event time density function f(t|x) result in uncertainty-aware time-to-event predictions. In practice, it is highly desirable for a model to generate concentrated time-to-event predictions. The CoV measures the dispersion of a distribution; a CoV > 1 indicates a high-variance distribution, while a CoV < 1 indicates a low-variance distribution. CoV results provided in the Supplementary Material demonstrate that DATE, SCA and SFM consistently produce low-variance distributions, followed by S-CRPS and lastly DRAFT.
Table 3: Calibration slope and RAE metrics on test data.
Table 4: Logrank score and standard errors in parentheses. The best performing K-means and SSC-Bair models were selected from the set K = {2, 3, 4, 5, 6}.
We cannot compute CoV for either MTLR or CoxPH. CoxPH estimates a risk score, and therefore cannot be evaluated on CoV. MTLR does not specify the conditional hazards, h(t|x), and thus we cannot recover f(t|x) = S(t|x)h(t|x).
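For sampling-based models, the CoV of an individualized prediction can be estimated from time-to-event samples as the ratio of their standard deviation to their mean, and then averaged over subjects. The sketch below assumes the population (not sample) standard deviation and positive sample means; the function names are hypothetical.

```python
import math


def coefficient_of_variation(samples):
    """Standard deviation divided by the mean of one subject's samples."""
    mean = sum(samples) / len(samples)
    variance = sum((s - mean) ** 2 for s in samples) / len(samples)
    return math.sqrt(variance) / mean


def mean_cov(per_subject_samples):
    """Average CoV over subjects, given one list of samples per subject."""
    covs = [coefficient_of_variation(s) for s in per_subject_samples]
    return sum(covs) / len(covs)
```

A subject whose samples are all identical has a CoV of zero, the most concentrated prediction possible.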
Logrank Score. The logrank score is a nonparametric statistic that evaluates the similarity between a pair of survival functions, yielding high values for curves that are highly unlikely to be similar [38]. Further, the logrank statistic is especially powerful for measuring differences between survival functions that follow the Cox proportional hazards assumption, i.e., the survival functions do not cross. For K clusters, we compute all $\binom{K}{2}$ pairwise comparisons. Table 4 demonstrates that our proposed SCA is the best performing method, followed by SSC-Bair and lastly K-means. Interestingly, SCA is unable to recover any clustering structure from the ehr dataset, as it is a homogeneous population of Type-2 diabetes subjects, whereas K-means and SSC-Bair are always able to produce clusters (which may be misleading for homogeneous datasets). This supports the need to account for survival information when clustering survival datasets, as both SCA and SSC-Bair incorporate time information in their clustering approaches.
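The pairwise comparison can be sketched with the standard two-sample logrank chi-square statistic, applied to every pair of clusters. The function names are hypothetical, and how the pairwise statistics are aggregated into the single score reported in Table 4 is not specified here, so the sketch simply returns all pairwise values.

```python
from itertools import combinations


def logrank_statistic(times_a, events_a, times_b, events_b):
    """Two-sample logrank chi-square statistic (one degree of freedom)."""
    pooled = [(t, e, 0) for t, e in zip(times_a, events_a)]
    pooled += [(t, e, 1) for t, e in zip(times_b, events_b)]
    event_times = sorted({t for t, e, _ in pooled if e == 1})
    obs_minus_exp, variance = 0.0, 0.0
    for t in event_times:
        n = sum(1 for ti, _, _ in pooled if ti >= t)  # at risk, both groups
        n_a = sum(1 for ti, _, g in pooled if ti >= t and g == 0)
        d = sum(1 for ti, e, _ in pooled if ti == t and e == 1)  # events at t
        d_a = sum(1 for ti, e, g in pooled if ti == t and e == 1 and g == 0)
        obs_minus_exp += d_a - d * n_a / n
        if n > 1:
            variance += d * (n_a / n) * (1.0 - n_a / n) * (n - d) / (n - 1)
    return obs_minus_exp ** 2 / variance if variance > 0 else 0.0


def pairwise_logrank(groups):
    """Logrank statistic for every pair of clusters.

    groups: list of (times, events) tuples, one per cluster.
    """
    return {
        (i, j): logrank_statistic(*groups[i], *groups[j])
        for i, j in combinations(range(len(groups)), 2)
    }
```

Two clusters with identical survival experience give a statistic of zero, while well-separated clusters give large values; with three clusters, `pairwise_logrank` returns three statistics.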
We have developed the first time-to-event model for inferring individualized risk-based cluster assignments, while jointly predicting the time-to-event. Leveraging a Bayesian nonparametric stick-breaking representation of the Dirichlet Process, we have presented a method for learning a clustering structure in a latent representation, for which the number of clusters is unknown. We have demonstrated the need to account for time information when clustering survival datasets. Our model identifies interpretable and phenotypically heterogeneous subpopulations, which are critical in a clinical setting for identifying subjects with diverse risk profiles. Extensive experiments demonstrate that the joint modeling approach yields substantial performance gains in calibration and logrank scores, while remaining competitive in preserving pairwise ordering and in predicting concentrated and accurate distributions. In the future, we plan to extend this work to account for locally consistent, calibrated and accurate predictions within identified subpopulations.
The authors would like to thank the anonymous reviewers for their insightful comments. This work was supported by NIH/NIBIB R01-EB025020.
[1] M. Abramowitz, I. A. Stegun, and R. H. Romer. Handbook of mathematical functions with formulas, graphs, and mathematical tables, 1988.
[2] E. Ahlqvist, P. Storm, A. Käräjämäki, M. Martinell, M. Dorkhan, A. Carlsson, P. Vikman, R. B. Prasad, D. M. Aly, P. Almgren, et al. Novel subgroups of adult-onset diabetes and their association with outcomes: a data-driven cluster analysis of six variables. The Lancet Diabetes & Endocrinology, 2018.
[3] T. Ahmad, M. J. Pencina, P. J. Schulte, E. O'Brien, D. J. Whellan, I. L. Piña, D. W. Kitzman, K. L. Lee, C. M. O'Connor, and G. M. Felker. Clinical implications of chronic heart failure phenotypes defined by cluster analysis. Journal of the American College of Cardiology, 2014.
[4] A. M. Alaa and M. van der Schaar. Deep Multi-task Gaussian Processes for Survival Analysis with Competing Risks. In NeurIPS, 2017.
[5] A. Avati, T. Duan, K. Jung, N. H. Shah, and A. Ng. Countdown regression: Sharp and calibrated survival predictions. arXiv, 2018.
[6] E. Bair and R. Tibshirani. Semi-supervised methods to predict patient survival from gene expression data. PLoS biology, 2(4):e108, 2004.
[7] E. J. Benjamin, D. Levy, S. M. Vaziri, R. B. D’agostino, A. J. Belanger, and P. A. Wolf. Independent risk factors for atrial fibrillation in a population-based cohort: the framingham heart study. Jama, 1994.
[8] C. M. Bishop. Pattern recognition and machine learning. springer, 2006.
[9] D. M. Blei, M. I. Jordan, et al. Variational inference for dirichlet process mixtures. Bayesian analysis, 1(1):121–143, 2006.
[10] O. Cappé and E. Moulines. On-line expectation–maximization algorithm for latent data models. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 2009.
[11] P. Chapfuwa, C. Tao, C. Li, C. Page, B. Goldstein, L. Carin, and R. Henao. Adversarial time-to-event modeling. In ICML, 2018.
[12] P. Chapfuwa, C. Tao, L. Carin, and R. Henao. Survival function matching for calibrated time-to-event predictions. arXiv preprint arXiv:1905.08838, 2019.
[13] F. S. Collins and H. Varmus. A new initiative on precision medicine. New England journal of medicine, 2015.
[14] D. R. Cox. Regression models and life-tables. In Breakthroughs in statistics. Wiley Online Library, 1992.
[15] A. Dispenzieri, J. A. Katzmann, R. A. Kyle, D. R. Larson, T. M. Therneau, C. L. Colby, R. J. Clark, G. P. Mead, S. Kumar, L. J. Melton, et al. Use of nonclonal serum immunoglobulin free light chains to predict overall survival in the general population. In Mayo Clinic Proceedings, 2012.
[16] U. Djuric, G. Zadeh, K. Aldape, and P. Diamandis. Precision histology: how deep learning is poised to revitalize histomorphology for personalized cancer care. NPJ precision oncology, 2017.
[17] M. B. Eisen, P. T. Spellman, P. O. Brown, and D. Botstein. Cluster analysis and display of genome-wide expression patterns. Proceedings of the National Academy of Sciences, 95(25):14863–14868, 1998.
[18] T. S. Ferguson. A bayesian analysis of some nonparametric problems. The annals of statistics, 1973.
[19] T. Fernández, N. Rivera, and Y. W. Teh. Gaussian processes for survival analysis. In NeurIPS, 2016.
[20] S. Fotso. Deep neural networks for survival analysis based on a multi-task framework. arXiv, 2018.
[21] S. Gaynor and E. Bair. Identification of relevant subtypes via preweighted sparse clustering. Biostatistics, 2013.
[22] X. Glorot and Y. Bengio. Understanding the difficulty of training deep feedforward neural networks. In AISTATS, 2010.
[23] M. Greenwood et al. A report on the natural duration of cancer. A Report on the Natural Duration of Cancer., 1926.
[24] F. E. Harrell, K. L. Lee, R. M. Califf, D. B. Pryor, and R. A. Rosati. Regression modelling strategies for improved prognostic prediction. Statistics in medicine, 1984.
[25] S. Ioffe and C. Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In ICML, 2015.
[26] H. Ishwaran and L. F. James. Gibbs sampling methods for stick-breaking priors. Journal of the American Statistical Association, 2001.
[27] H. Ishwaran and M. Lu. Random survival forests. Wiley StatsRef: Statistics Reference Online, 2008.
[28] Z. Jiang, Y. Zheng, H. Tan, B. Tang, and H. Zhou. Variational deep embedding: An unsupervised and generative approach to clustering. In IJCAI, 2017.
[29] M. I. Jordan, Z. Ghahramani, T. S. Jaakkola, and L. K. Saul. An introduction to variational methods for graphical models. Machine learning, 1999.
[30] J. L. Katzman, U. Shaham, A. Cloninger, J. Bates, T. Jiang, and Y. Kluger. Deepsurv: personalized treatment recommender system using a cox proportional hazards deep neural network. BMC medical research methodology, 2018.
[31] D. P. Kingma and J. Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
[32] D. P. Kingma and M. Welling. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114, 2013.
[33] D. G. Kleinbaum and M. Klein. Survival analysis. Springer, 2010.
[34] W. A. Knaus, F. E. Harrell, J. Lynn, L. Goldman, R. S. Phillips, A. F. Connors, N. V. Dawson, W. J. Fulkerson, R. M. Califf, N. Desbiens, et al. The SUPPORT prognostic model: objective estimates of survival for seriously ill hospitalized adults. Annals of internal medicine, 1995.
[35] C. Lee, W. R. Zame, J. Yoon, and M. van der Schaar. Deephit: A deep learning approach to survival analysis with competing risks. In AAAI, 2018.
[36] C. Lee, W. R. Zame, A. M. Alaa, and M. van der Schaar. Temporal quilting for survival analysis. In AISTATS, 2019.
[37] L. v. d. Maaten and G. Hinton. Visualizing data using t-sne. Journal of machine learning research, 2008.
[38] N. Mantel. Evaluation of survival data and two new rank order statistics arising in its consideration. Cancer Chemother Rep, 1966.
[39] X. Miscouridou, A. Perotte, N. Elhadad, and R. Ranganath. Deep survival analysis: Nonparametrics and missingness. In Machine Learning for Healthcare Conference, 2018.
[40] S. C. Mouli, L. Teixeira, B. Ribeiro, and J. Neville. Deep lifetime clustering. arXiv, 2019.
[41] E. Nalisnick and P. Smyth. Stick-breaking variational autoencoders. In ICLR, 2017.
[42] J. Pitman. Poisson–dirichlet and gem invariant distributions for split-and-merge transformations of an interval partition. Combinatorics, Probability and Computing, 2002.
[43] S. F. Quan, B. V. Howard, C. Iber, J. P. Kiley, F. J. Nieto, G. T. O’connor, D. M. Rapoport, S. Redline, J. Robbins, J. M. Samet, et al. The sleep heart health study: design, rationale, and methods. Sleep, 1997.
[44] R. Ranganath, A. Perotte, N. Elhadad, and D. Blei. Deep survival analysis. In Machine Learning for Healthcare Conference, 2016.
[45] C. E. Rasmussen. The infinite gaussian mixture model. In NeurIPS, 2000.
[46] L. A. G. Ries, J. L. Young Jr, G. E. Keel, M. P. Eisner, Y. D. Lin, and M.-J. D. Horner. Cancer survival among adults: US SEER program, 1988–2001. Patient and tumor characteristics SEER Survival Monograph Publication, 2007.
[47] J. Sethuraman. A constructive definition of dirichlet priors. Statistica sinica, 1994.
[48] S. J. Shah, D. H. Katz, S. Selvaraj, M. A. Burke, C. W. Yancy, M. Gheorghiade, R. O. Bonow, C.-C. Huang, and R. C. Deo. Phenomapping for novel classification of heart failure with preserved ejection fraction. Circulation, 2015.
[49] P. Vincent, H. Larochelle, I. Lajoie, Y. Bengio, and P.-A. Manzagol. Stacked denoising autoencoders: Learning useful representations in a deep network with a local denoising criterion. Journal of machine learning research, 2010.
[50] L. J. Wei. The accelerated failure time model: A useful alternative to the cox regression model in survival analysis. Statistics in Medicine, 1992.
[51] E. Xia, X. Du, J. Mei, W. Sun, S. Tong, Z. Kang, J. Sheng, J. Li, C. Ma, J. Dong, et al. Outcome-driven clustering of acute coronary syndrome patients using multi-task neural network with attention. arXiv preprint arXiv:1903.00197, 2019.
[52] J. Xie, R. Girshick, and A. Farhadi. Unsupervised deep embedding for clustering analysis. In ICML, 2016.
[53] C.-N. Yu, R. Greiner, H.-C. Lin, and V. Baracos. Learning patient-specific cancer survival distributions as a sequence of dependent regressors. In NeurIPS, 2011.
[54] Q. Zhang and M. Zhou. Nonparametric Bayesian lomax delegate racing for survival analysis with competing risks. In NeurIPS, 2018.
Refer to Table 1 for summary descriptions of notation used in the SCA formulation.
In all experiments, SCA, SFM, DATE, DRAFT and S-CRPS are specified as two-layer MLPs of 50 hidden units with Rectified Linear Unit (ReLU) activation functions, batch normalization [3], and dropout of p = 0.2 on all layers. We set the minibatch size to M = 350 and use the Adam [4] optimizer with the following hyperparameters: learning rate 3 × 10^[…], first moment 0.9, second moment 0.99, and epsilon 1 × 10^[…]. We initialize all network weights according to Xavier [2]. SFM and DATE inject noise in all layers (see [1] for more details), while SCA injects noise only in the last layer. Datasets are split into training, validation and test sets as 80%, 10% and 10% partitions, respectively, stratified by non-censored event proportion. The validation set is used for early stopping and for selecting model hyperparameters. All models are trained using one NVIDIA P100 GPU with 16GB memory.
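The 80%/10%/10% split stratified by event proportion can be sketched as follows. This is a minimal NumPy illustration rather than the code used for the experiments; the function name `stratified_split`, the rounding rule and the random seed are assumptions.

```python
import numpy as np

def stratified_split(events, fractions=(0.8, 0.1, 0.1), seed=0):
    """Return (train, validation, test) index arrays, stratified by the
    event indicator (1 = observed event, 0 = censored)."""
    events = np.asarray(events)
    rng = np.random.default_rng(seed)
    parts = [[], [], []]
    # Split censored and non-censored subjects separately so that each
    # partition keeps roughly the same non-censored proportion.
    for value in np.unique(events):
        idx = rng.permutation(np.flatnonzero(events == value))
        n_train = int(round(fractions[0] * len(idx)))
        n_valid = int(round(fractions[1] * len(idx)))
        parts[0].append(idx[:n_train])
        parts[1].append(idx[n_train:n_train + n_valid])
        parts[2].append(idx[n_train + n_valid:])  # remainder goes to test
    return tuple(np.sort(np.concatenate(p)) for p in parts)
```

The validation indices would then serve for early stopping and hyperparameter selection, with the test indices held out for the reported metrics.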
See Table 2 for additional quantitative evaluations on C-index, mean CoV and RAE.
The model calibration and survival plots for datasets support, flchain, sleep, seer, framingham, and ehr are shown in Figures 7 - 12.
We provide qualitative visualizations of the latent-space representation, namely: a) estimated individualized cluster assignment probability distributions; b) t-SNE plots of the centroids together with z; c) cluster-specific Kaplan-Meier curves. Refer to Figures 1-6 for SCA results on the framingham, support, flchain, sleep, seer and ehr datasets, respectively.
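For reference, cluster-specific Kaplan-Meier curves of the kind listed in item c) can be computed with the standard product-limit estimator. The sketch below assumes each subject has already been given a hard cluster label (for example, the most probable cluster); the function names are hypothetical and this is not the plotting code behind the figures.

```python
import numpy as np

def kaplan_meier(times, events):
    """Product-limit estimate of the survival function.

    Returns the distinct event times and the estimated survival
    probability just after each of them.
    """
    times = np.asarray(times, float)
    events = np.asarray(events, int)
    event_times = np.unique(times[events == 1])
    survival = np.empty(len(event_times))
    s = 1.0
    for i, t in enumerate(event_times):
        at_risk = np.sum(times >= t)                    # not yet failed or censored
        deaths = np.sum((times == t) & (events == 1))   # events exactly at t
        s *= 1.0 - deaths / at_risk
        survival[i] = s
    return event_times, survival

def cluster_kaplan_meier(times, events, clusters):
    """One Kaplan-Meier curve per cluster label (hypothetical helper)."""
    times, events, clusters = map(np.asarray, (times, events, clusters))
    return {
        int(c): kaplan_meier(times[clusters == c], events[clusters == c])
        for c in np.unique(clusters)
    }
```

Plotting each returned curve as a step function gives one survival curve per inferred cluster.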
Figure 1: Inferred clusters on the testing set of framingham dataset, with K = 25 and 8 with corresponding individual probability distribution
Figure 2: Inferred clusters on the testing set of support dataset, with corresponding individual probability distribution
Figure 3: Inferred clusters on the testing set of flchain dataset, with corresponding individual probability distribution
Figure 4: Inferred clusters on the testing set of sleep dataset, with corresponding individual probability distribution
Figure 5: Inferred clusters on the testing set of seer dataset, with K = 25 and 2 with corresponding individual probability distribution
Figure 6: Inferred clusters on the testing set of ehr dataset, with K = 25 and 2 with corresponding individual probability distribution
Figure 7: Calibration (left) and Survival function estimates (right) for support data. Ground truth (Empirical) is compared to predictions from six models (DATE, DRAFT, SCA (our proposed model), SFM, S-CRPS and CoxPH).
Figure 8: Calibration (left) and Survival function estimates (right) for flchain data. Ground truth (Empirical) is compared to predictions from six models (DATE, DRAFT, SCA (our proposed model), SFM, S-CRPS and CoxPH).
Figure 9: Calibration (left) and Survival function estimates (right) for sleep data. Ground truth (Empirical) is compared to predictions from six models (DATE, DRAFT, SCA (our proposed model), SFM, S-CRPS and CoxPH).
Figure 10: Calibration (left) and Survival function estimates (right) for framingham data. Ground truth (Empirical) is compared to predictions from six models (DATE, DRAFT, SCA (our proposed model), SFM, S-CRPS and CoxPH).
Figure 11: Calibration (left) and Survival function estimates (right) for seer data. Ground truth (Empirical) is compared to predictions from six models (DATE, DRAFT, SCA (our proposed model), SFM, S-CRPS and CoxPH).
Figure 12: Calibration (left) and Survival function estimates (right) for ehr data. Ground truth (Empirical) is compared to predictions from six models (DATE, DRAFT, SCA (our proposed model), SFM, S-CRPS and CoxPH).
[1] P. Chapfuwa, C. Tao, C. Li, C. Page, B. Goldstein, L. Carin, and R. Henao. Adversarial time-to-event modeling. In ICML, 2018.
[2] X. Glorot and Y. Bengio. Understanding the difficulty of training deep feedforward neural networks. In AISTATS, 2010.
[3] S. Ioffe and C. Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In ICML, 2015.