Dipole: Diagnosis Prediction in Healthcare via Attention-based Bidirectional Recurrent Neural Networks
2017·arXiv
ABSTRACT

Predicting the future health information of patients from historical Electronic Health Records (EHR) is a core research task in the development of personalized healthcare. Patient EHR data consist of sequences of visits over time, where each visit contains multiple medical codes, including diagnosis, medication, and procedure codes. The most important challenges for this task are to model the temporality and high dimensionality of sequential EHR data and to interpret the prediction results. Existing work solves this problem by employing recurrent neural networks (RNNs) to model EHR data and utilizing a simple attention mechanism to interpret the results. However, RNN-based approaches suffer from the problem that the performance of RNNs drops when the length of sequences is large, and the relationships between subsequent visits are ignored by current RNN-based approaches. To address these issues, we propose Dipole, an end-to-end, simple and robust model for predicting patients’ future health information. Dipole employs bidirectional recurrent neural networks to remember all the information of both the past visits and the future visits, and it introduces three attention mechanisms to measure the relationships of different visits for the prediction. With the attention mechanisms, Dipole can interpret the prediction results effectively. Dipole also allows us to interpret the learned medical code representations, which are confirmed positively by medical experts. Experimental results on two real world EHR datasets show that the proposed Dipole can significantly improve the prediction accuracy compared with the state-of-the-art diagnosis prediction approaches and provide clinically meaningful interpretation.

This work was mostly done when the first author was an intern at Xerox. This work was done when the fifth author was part of Xerox.

CCS CONCEPTS

Information systems → Data mining; Applied computing → Health informatics;

KEYWORDS

Healthcare informatics, bidirectional recurrent neural networks, attention mechanism

ACM Reference format:

Fenglong Ma, Radha Chitta, Jing Zhou, Quanzeng You, Tong Sun, and Jing Gao. 2017. Dipole: Diagnosis Prediction in Healthcare via Attention-based Bidirectional Recurrent Neural Networks. In Proceedings of KDD’17, August 13–17, 2017, Halifax, NS, Canada, 10 pages.

image

Electronic Health Records (EHR), consisting of longitudinal patient health data, including demographics, diagnoses, procedures, and medications, have been utilized successfully in several predictive modeling tasks in healthcare [9–11, 31]. EHR data are temporally sequenced by patient medical visits that are represented by a set of high dimensional clinical variables (i.e., medical codes). One critical task is to predict the future diagnoses based on a patient’s historical EHR data, i.e., diagnosis prediction. When predicting diagnoses, each patient’s visit and the medical codes in each visit may have varying importance. Thus, the most important and challenging issues in diagnosis prediction are: how to correctly model such temporal and high dimensional EHR data to significantly improve the performance of prediction; and how to reasonably interpret the importance of visits and medical codes in the prediction results.

In order to model sequential EHR data, recurrent neural networks (RNNs) have been employed in the literature for deriving accurate and robust representations of patient visits in the diagnosis prediction task [10, 11]. RETAIN [11] and GRAM [10] are two state-of-the-art models utilizing RNNs for predicting the future diagnoses. RETAIN applies an RNN with reverse time ordered EHR sequences, while GRAM uses an RNN when modeling time ordered patient visits. Both models achieve good prediction accuracy. However, they are constrained by the forgetfulness associated with models using RNNs, i.e., RNNs cannot handle long sequences effectively. The predictive power of these models drops significantly when the length of the patient visit sequences is large. Bidirectional recurrent neural networks (BRNNs) [24], which can be trained using all available input information in the past and future, have been used to alleviate the effect of the long sequence problem and improve the predictive performance.

However, it is infeasible to interpret the outputs of models incorporating either RNNs or BRNNs. Interpretability is crucial in the healthcare domain, as it can lead to the identification of potential risk factors and the design of suitable intervention mechanisms. Non-temporal models such as Med2Vec [9] generate easily interpretable low-dimensional representations of the medical codes, but do not account for the temporal nature of the EHR data.

To model the temporal EHR data and interpret the prediction results simultaneously, attention-based neural networks can be applied, which aim to learn the relevance of the data samples to the task. For example, RETAIN [11] employs location-based attention to predict the future diagnosis. It calculates the attention weights for a visit at time t, using the medical information in the current visit and the hidden state of the recurrent neural network at time t, to predict the visit at time t + 1. It ignores the relationships between all the visits from time 1 to time t. We believe that accounting for all the past visit information may help the predictive models to improve the accuracy and provide better interpretation.

To tackle all the aforementioned issues and challenges, we propose an efficient and accurate diagnosis prediction model (Dipole) using attention-based bidirectional recurrent neural networks for learning low-dimensional representations of the patient visits, and employ the learned representations for future diagnosis prediction. The learned representations are easily interpretable, and can also be used to learn how important each visit is to the future diagnosis prediction. Specifically, it first embeds the high dimensional medical codes (i.e., clinical variables) into a low-dimensional code-level space, and then feeds the code representations into an attention-based bidirectional recurrent neural network to generate the hidden state representation. The hidden representation is fed through a softmax layer to predict the medical codes in future visits. We experiment with three types of attention mechanisms: (i) location-based, (ii) general, and (iii) concatenation-based, to calculate the attention weights for all the prior visits for each patient. These mechanisms model the inter-visit relationships, where the attention weights represent the importance of each visit.

We demonstrate that the proposed Dipole achieves significantly higher prediction accuracy when compared to the state-of-the-art approaches in diagnosis prediction, using two datasets derived from Medicaid claims data. A case study is conducted to show that the proposed model accurately assigns varying attention weights to past visits. We evaluate the interpretability of the learned representations through qualitative analysis. Finally, we illustrate the reasonableness of employing bidirectional recurrent neural networks to model temporal patient visits. In summary, our main contributions are as follows:

We propose Dipole, an end-to-end, simple and robust model to accurately predict the future visit information and reasonably interpret the prediction results, without depending on any expert medical knowledge.

Dipole models patient visit information in a time-ordered and reverse time-ordered way and employs three attention mechanisms to calculate the weights for previous visits.

We empirically show that the proposed Dipole outperforms existing methods in diagnosis prediction on two large real world EHR datasets.

We analyze the experimental results with clinical experts to validate the interpretability of the learned medical code representations.

The rest of this paper is organized as follows: In Section 2, we discuss the connection of the proposed approaches to related work. Section 3 presents the details of the proposed Dipole. The experimental results are presented in Section 4. Section 5 concludes the paper.

This section reviews the existing work for mining Electronic Healthcare Records (EHR) data. In particular, we focus on the state-of-the-art models for the diagnosis prediction task. We also introduce some work using attention mechanisms.

2.1 EHR Data Mining

Mining EHR data is a hot research topic in healthcare informatics. The tackled tasks include electronic genotyping and phenotyping [5, 16, 19, 32], disease progression [12, 26, 27, 33], adverse drug event detection [21], diagnosis prediction [9–11, 25, 31], and so on. In most tasks, deep learning models can significantly improve the performance. Recurrent neural networks (RNNs) can be used for modeling multivariate time series data in healthcare with missing values [6, 18]. Convolutional neural networks (CNNs) are used to predict unplanned readmission [23] and risk [7] with EHR. Stacked denoising autoencoders (SDAs) are employed to detect the characteristic patterns of physiology in clinical time series data [5].

Diagnosis prediction is an important and difficult task in healthcare. Med2Vec [9] aims to learn the representations of medical codes, which can be used to predict the future visit information. This method ignores long-term dependencies of medical codes among visits. RETAIN [11] is an interpretable predictive model, which employs a reverse time attention mechanism in an RNN for a binary prediction task. GRAM [10] is a graph-based attention model for healthcare representation learning, which uses medical ontologies to learn robust representations and an RNN to model patient visits. Both RETAIN and GRAM apply attention mechanisms and improve the prediction performance.

Compared with the aforementioned predictive models, the proposed approaches not only employ bidirectional neural networks when modeling visit information but also design different attention mechanisms to assign different weights for the past visits. Relying on these two properties, the proposed Dipole can improve the prediction performance significantly and interpret the meanings of medical codes reasonably.

2.2 Attention-based Neural Networks

Attention-based neural networks have been successfully used in many tasks [1, 2, 13, 14, 17, 20, 28, 29]. Specifically for neural machine translation [2, 20], given a sentence in the original language (i.e., original space), RNNs were adopted to generate the word representations in the sentence $h_1, \cdots, h_{|S|}$, where $|S|$ is the number of words in this sentence. In order to find the t-th word in the target language (or target space), a weight $\alpha_{ti}$, i.e., attention score, is assigned to each word in the original language. Then, a context vector $c_t = \sum_{i=1}^{|S|} \alpha_{ti} h_i$ is calculated to predict the t-th word in the target language. However, the diagnosis prediction task is different from the language translation task as all the visits for each patient are in the same space.

In this section, we first introduce the structure of EHR data and some basic notations. Then we describe the details of the proposed Dipole neural network. Finally, we describe the interpretation for the learned code representations and visit representations.

3.1 Basic Notations

We denote all the unique medical codes from the EHR data as $c_1, c_2, \cdots, c_{|C|} \in C$, where $|C|$ is the number of unique medical codes. Assuming there are N patients, the n-th patient has $T^{(n)}$ visit records in the EHR data. The patient can be represented by a sequence of visits $V_1, V_2, \cdots, V_{T^{(n)}}$. Each visit $V_t$, containing a subset of medical codes ($V_t \subseteq C$), is denoted by a binary vector $x_t \in \{0, 1\}^{|C|}$, where the i-th element is 1 if $V_t$ contains the code $c_i$. Each diagnosis code can be mapped to a node of the International Classification of Diseases (ICD-9), and each procedure code can be mapped to a node in the Current Procedural Terminology (CPT). Below we use a simple example to illustrate the problem.

Suppose there are two diagnosis codes, 250 (Diabetes mellitus) and 254 (Diseases of thymus gland), and one procedure code, 11720 (Debride nail, 1-5), in the whole dataset, i.e., $|C| = 3$. If the medical codes in the t-th patient visit are 250 and 254, then $x_t = [1, 1, 0]$.

Both ICD-9 and CPT systems are coded hierarchically, which means that each medical code has a “parent”, i.e., category label. For example, the diagnosis codes 250 and 254 belong to the same category Diseases of other endocrine glands, and the procedure code 11720 is in the category Surgical procedures on the nails. Thus, each visit $V_t$ has a corresponding coarse-grained category representation $y_t \in \{0, 1\}^{|G|}$, where $|G|$ is the number of unique categories. In the above example, $|G| = 2$ and $y_t = [1, 0]$. For simplicity, we describe the proposed algorithm for a single patient and drop the superscript (n) when it is unambiguous. The input of the proposed Dipole model is a time-ordered sequence of patient visits.
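The encoding in this example can be sketched in a few lines of Python. The three-code vocabulary and the two categories come from the example above; the names `CODES`, `CODE_TO_CATEGORY`, `encode_visit`, and `encode_categories` are hypothetical and only serve the illustration.

```python
# Toy vocabulary from the running example: |C| = 3 codes, |G| = 2 categories.
CODES = ["250", "254", "11720"]
CATEGORIES = [
    "Diseases of other endocrine glands",
    "Surgical procedures on the nails",
]
# Hypothetical lookup from each code to the index of its parent category.
CODE_TO_CATEGORY = {"250": 0, "254": 0, "11720": 1}


def encode_visit(visit_codes):
    """Return the binary code vector x_t for one visit."""
    present = set(visit_codes)
    return [1 if code in present else 0 for code in CODES]


def encode_categories(visit_codes):
    """Return the binary category vector y_t for one visit."""
    hit = {CODE_TO_CATEGORY[code] for code in visit_codes}
    return [1 if g in hit else 0 for g in range(len(CATEGORIES))]


x_t = encode_visit(["250", "254"])
y_t = encode_categories(["250", "254"])
print(x_t, y_t)
```

For a visit containing codes 250 and 254, this reproduces the vectors stated in the text.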

3.2 Model

The goal of the proposed algorithm is to predict the (t + 1)-th visit’s category-level medical codes. Figure 1 shows the high-level overview of the proposed model. Given the visit information from time 1 to t, the i-th visit’s medical codes $x_i$ can be embedded into a vector representation $v_i$. The vector $v_i$ is fed into the Bidirectional Recurrent Neural Network (BRNN) [24], which outputs a hidden state $h_i$ as the representation of the i-th visit. Along with the set of hidden states $\{h_i\}_{i=1}^{t-1}$, we are able to compute the relative importance vector $\alpha_t$ for the current visit t. Subsequently, a context state $c_t$ is computed from the relative importance $\alpha_t$ and $\{h_i\}_{i=1}^{t-1}$. This procedure is known as the attention model [2], which will be detailed in the following sections. Next, from the context state $c_t$ and the current hidden state $h_t$, we can obtain an attentional hidden state $\tilde{h}_t$, which is used to predict the category-level medical codes appearing in the (t + 1)-th visit, i.e., $\hat{y}_t$. The proposed neural network can be trained end-to-end.

image

Figure 1: The Proposed Dipole Model.

image

Given a visit $x_i \in \{0, 1\}^{|C|}$, we can obtain its vector representation $v_i \in \mathbb{R}^m$ as follows:

image

where m is the size of the embedding dimension, $W_v \in \mathbb{R}^{m \times |C|}$ is the weight matrix of medical codes, and $b_c \in \mathbb{R}^m$ is the bias vector. ReLU is the rectified linear unit defined as ReLU(v) = max(v, 0), where max() applies element-wise to vectors. The reason we employ the rectified linear unit as the activation function is that ReLU enables the learned vector representations to be interpretable [9].
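Eq. (1) itself is not reproduced in this text. Given the stated shapes of $W_v$ and $b_c$ and the ReLU activation, a consistent reading is $v_i = \mathrm{ReLU}(W_v x_i + b_c)$. The sketch below implements that assumed form in plain Python with toy weights; the function name `embed_visit` and all numeric values are illustrative only.

```python
def embed_visit(x, W_v, b_c):
    """Assumed form of Eq. (1): v = ReLU(W_v x + b_c).

    W_v is a list of m rows, each of length |C|; b_c has length m.
    """
    v = []
    for row, bias in zip(W_v, b_c):
        pre_activation = sum(w * xi for w, xi in zip(row, x)) + bias
        v.append(max(pre_activation, 0.0))  # element-wise ReLU
    return v


# Toy parameters with m = 2 and |C| = 3 (illustrative values only).
W_v = [[0.5, 1.0, 2.0],
       [-0.3, 0.1, 0.4]]
b_c = [0.1, 0.0]

v_t = embed_visit([1, 1, 0], W_v, b_c)
print(v_t)
```

Because of the ReLU, every coordinate of the embedding is non-negative, which is the property the interpretation step in Section 3.3 relies on.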

image

Recurrent Neural Networks (RNNs) provide a very elegant way of modeling sequential healthcare data [10, 11]. However, one drawback of RNNs is that the prediction performance will drop when the length of the sequence is very large [24]. In order to overcome this drawback, we employ Bidirectional Recurrent Neural Networks (BRNNs) in the proposed model, which can be trained using all the available input visits’ information from two directions to improve the prediction performance. Note that we use “RNNs” to denote any Recurrent Neural Network variant dealing with the vanishing gradient problem [3], such as Long Short-Term Memory (LSTM) [15] and Gated Recurrent Unit (GRU) [8]. In our implementation, we use GRU to adaptively capture dependencies among patient visit information.

A BRNN consists of a forward and a backward RNN. The forward RNN $\overrightarrow{f}$ reads the input visit sequence from $x_1$ to $x_T$ and calculates a sequence of forward hidden states $(\overrightarrow{h}_1, \cdots, \overrightarrow{h}_T)$ ($\overrightarrow{h}_i \in \mathbb{R}^p$, where p is the dimensionality of hidden states). The backward RNN $\overleftarrow{f}$ reads the visit sequence in the reverse order, i.e., from $x_T$ to $x_1$, resulting in a sequence of backward hidden states $(\overleftarrow{h}_1, \cdots, \overleftarrow{h}_T)$ ($\overleftarrow{h}_i \in \mathbb{R}^p$). By concatenating the forward hidden state $\overrightarrow{h}_i$ and the backward one $\overleftarrow{h}_i$, we can obtain the final latent vector representation as $h_i = [\overrightarrow{h}_i; \overleftarrow{h}_i]^\top$ ($h_i \in \mathbb{R}^{2p}$). Note that the future visit information is only used when training the model. Only the past visit information is provided to predict the future visits during the testing phase.
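The forward/backward construction can be illustrated with a minimal Python sketch. To stay short it uses a toy one-dimensional tanh cell rather than the GRU used in the paper; `run_rnn`, `bidirectional_states`, and `toy_step` are hypothetical names, and the cell is a placeholder, not the authors' implementation.

```python
import math


def run_rnn(step, inputs, p):
    """Apply step(h_prev, x) -> h along inputs, starting from a zero state."""
    h = [0.0] * p
    states = []
    for x in inputs:
        h = step(h, x)
        states.append(h)
    return states


def bidirectional_states(fwd_step, bwd_step, inputs, p):
    """Return h_i = [forward h_i ; backward h_i] for every position i."""
    forward = run_rnn(fwd_step, inputs, p)
    # Run over the reversed sequence, then flip back to visit order.
    backward = run_rnn(bwd_step, inputs[::-1], p)[::-1]
    return [f + b for f, b in zip(forward, backward)]


def toy_step(h_prev, x):
    # Placeholder cell with p = 1 (NOT a GRU): h = tanh(0.5 * h_prev + sum(x)).
    return [math.tanh(0.5 * h_prev[0] + sum(x))]


visits = [[1.0], [0.0], [2.0]]
H = bidirectional_states(toy_step, toy_step, visits, p=1)
print(H)
```

Each concatenated state has dimension 2p. The forward half at position i depends only on visits 1..i, while the backward half depends on visits i..T.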

image

In the diagnosis prediction task, the final goal is to predict the category-level medical codes of the (t + 1)-th visit, i.e., $y_t$, according to the visits from $x_1$ to $x_t$. The output for the t-th visit $x_t$, i.e., $h_t$, is the estimated vector representation of the (t + 1)-th visit. However, it may contain only partial information about the visit to be predicted. Thus, how to derive a context vector $c_t$ that captures relevant information to help predict the future visit $y_t$ is the key issue. There are three methods that can be used to compute the context vector $c_t$:

Location-based Attention. A location-based attention function calculates the weights solely from the current hidden state $h_i$ as follows:

image

where $W_\alpha \in \mathbb{R}^{2p}$ and $b_\alpha \in \mathbb{R}$ are the parameters to be learned. According to Eq. (2), we can obtain an attention weight vector $\alpha_t$ using the softmax function as follows:

image

Then the context vector $c_t \in \mathbb{R}^{2p}$ can be calculated based on the weights obtained from Eq. (3) and the hidden states from $h_1$ to $h_{t-1}$ as follows:

image

Since the location-based attention mechanism only considers each individual hidden state, it does not capture the relationships between the current hidden state and all the previous hidden states. To utilize the information from all the previous hidden states, we adopt the following two attention mechanisms in the proposed Dipole.
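Eqs. (2) to (4) are not reproduced in this text. From the parameter shapes and the description, a consistent reading is $\alpha_{ti} = W_\alpha^\top h_i + b_\alpha$, followed by a softmax over $i = 1, \cdots, t-1$ and the weighted sum $c_t = \sum_{i=1}^{t-1} \alpha_{ti} h_i$. The following sketch implements that assumed form; `softmax` and `location_attention` are illustrative names and the numbers are toy values.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def location_attention(prev_states, W_alpha, b_alpha):
    """Return (attention weights, context vector) over h_1..h_{t-1}.

    Assumed score: alpha_ti = W_alpha . h_i + b_alpha, using h_i only.
    """
    scores = [sum(w * h for w, h in zip(W_alpha, h_i)) + b_alpha
              for h_i in prev_states]
    alpha = softmax(scores)
    dim = len(prev_states[0])
    context = [sum(a * h_i[d] for a, h_i in zip(alpha, prev_states))
               for d in range(dim)]
    return alpha, context


prev_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy h_1..h_3
alpha, c_t = location_attention(prev_states, W_alpha=[1.0, 1.0], b_alpha=0.0)
print(alpha, c_t)
```

Note that the score of each visit ignores $h_t$ entirely, which is the limitation the paragraph above points out.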

General Attention. An easy way to capture the relationship between $h_t$ and $h_i$ ($1 \leqslant i \leqslant t - 1$) is using a matrix $W_\alpha \in \mathbb{R}^{2p \times 2p}$, and calculating the weight as:

image

and the context vector $c_t$ can be obtained using Eq. (3) and Eq. (4).
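Eq. (5) is not reproduced in this text. With a square matrix $W_\alpha \in \mathbb{R}^{2p \times 2p}$ relating $h_t$ and $h_i$, a natural reading is the bilinear score $\alpha_{ti} = h_t^\top W_\alpha h_i$. The sketch below assumes that form and then applies the softmax and weighted sum described for Eqs. (3) and (4); all names and values are illustrative.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def general_attention(h_t, prev_states, W_alpha):
    """Assumed bilinear score h_t^T W_alpha h_i, then softmax and weighted sum."""
    scores = []
    for h_i in prev_states:
        W_h = [sum(w * x for w, x in zip(row, h_i)) for row in W_alpha]
        scores.append(sum(a * b for a, b in zip(h_t, W_h)))
    alpha = softmax(scores)
    context = [sum(a * h_i[d] for a, h_i in zip(alpha, prev_states))
               for d in range(len(h_t))]
    return alpha, context


identity = [[1.0, 0.0], [0.0, 1.0]]  # toy W_alpha: score reduces to a dot product
alpha, c_t = general_attention([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], identity)
print(alpha, c_t)
```

With the identity matrix the score is the plain dot product, so the previous state most aligned with $h_t$ receives the largest weight.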

Concatenation-based Attention. Another way to calculate the context vector $c_t$ is using a multi-layer perceptron (MLP) [2]. We first concatenate the current hidden state $h_t$ and the previous state $h_i$, and then a latent vector can be obtained by multiplying a weight matrix $W_\alpha \in \mathbb{R}^{q \times 4p}$, where q is the latent dimensionality. We select tanh as the activation function. The attention weight vector is generated as follows:

image

where $v_\alpha \in \mathbb{R}^q$ is the parameter to be learned, and we can obtain the context vector $c_t$ with Eq. (3) and Eq. (4).
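Eq. (6) is not reproduced in this text. From the description (concatenate the two states, multiply by $W_\alpha$, apply tanh, project with $v_\alpha$), a consistent reading is $\alpha_{ti} = v_\alpha^\top \tanh(W_\alpha [h_t; h_i])$. The sketch below computes these raw scores under that assumption; they would then be normalized with a softmax as in Eq. (3). Names and values are illustrative.

```python
import math


def concat_attention_scores(h_t, prev_states, W_alpha, v_alpha):
    """Assumed score: v_alpha . tanh(W_alpha [h_t ; h_i]) for each previous h_i."""
    scores = []
    for h_i in prev_states:
        joined = h_t + h_i  # list concatenation plays the role of [h_t ; h_i]
        latent = [math.tanh(sum(w * x for w, x in zip(row, joined)))
                  for row in W_alpha]
        scores.append(sum(v * z for v, z in zip(v_alpha, latent)))
    return scores


# Toy setting: states of size 1, latent size q = 1 (illustrative values only).
scores = concat_attention_scores(
    h_t=[1.0],
    prev_states=[[0.0], [1.0]],
    W_alpha=[[1.0, 1.0]],
    v_alpha=[2.0],
)
print(scores)
```

Unlike the location-based variant, each score here depends on both $h_t$ and $h_i$.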

Diagnosis Prediction

Given the context vector $c_t$ and the current hidden state $h_t$, we employ a simple concatenation layer to combine the information from both vectors to generate an attentional hidden state as follows:

image

where $W_c \in \mathbb{R}^{r \times 4p}$ is the weight matrix. The attentional vector $\tilde{h}_t$ is fed through the softmax layer to produce the (t + 1)-th visit information defined as:

image

where $W_s \in \mathbb{R}^{|G| \times r}$ and $b_s \in \mathbb{R}^{|G|}$ are the parameters to be learned.
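Eqs. (7) and (8) are not reproduced in this text. The sketch below assumes $\tilde{h}_t = \tanh(W_c [c_t; h_t])$ and $\hat{y}_t = \mathrm{Softmax}(W_s \tilde{h}_t + b_s)$. The softmax and the parameter shapes follow the description; the tanh in the concatenation layer is an assumption, since the activation is not stated here. All names and numbers are illustrative.

```python
import math


def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_next_visit(c_t, h_t, W_c, W_s, b_s):
    """Combine context and current state, then score the |G| categories."""
    # Assumed activation: tanh over the concatenation layer.
    h_tilde = [math.tanh(z) for z in matvec(W_c, c_t + h_t)]
    logits = [z + b for z, b in zip(matvec(W_s, h_tilde), b_s)]
    return softmax(logits)


c_t, h_t = [0.2, 0.4], [0.1, -0.3]           # toy vectors of size 2 each
W_c = [[0.5, -0.2, 0.1, 0.3],
       [0.0, 0.4, -0.6, 0.2]]                # r = 2 rows, 4 input columns
W_s = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]   # |G| = 3 categories
b_s = [0.0, 0.0, 0.0]

y_hat = predict_next_visit(c_t, h_t, W_c, W_s, b_s)
print(y_hat)
```

The output is a distribution over categories, so ranking its entries gives the top-k predicted category codes for the next visit.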

Objective Function

Based on Eq. (8), we use the cross-entropy between the ground truth visit information $y_t$ and the predicted visit $\hat{y}_t$ to calculate the loss for all the patients as follows:

image
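The exact form of Eq. (9), including its normalization over patients and visits, is not recoverable from this text. As a hedged illustration only, the sketch below uses one common cross-entropy form for multi-label targets, $-\left(y_t^\top \log \hat{y}_t + (1 - y_t)^\top \log(1 - \hat{y}_t)\right)$, averaged over the predictions of one patient. The function names, the clipping constant `eps`, and the averaging are all assumptions.

```python
import math


def visit_loss(y, y_hat, eps=1e-12):
    """Cross-entropy between a binary target y and predicted probabilities y_hat."""
    total = 0.0
    for target, prob in zip(y, y_hat):
        prob = min(max(prob, eps), 1.0 - eps)  # clip to avoid log(0)
        total += target * math.log(prob) + (1 - target) * math.log(1.0 - prob)
    return -total


def patient_loss(targets, predictions):
    """Average the per-visit loss over all predictions for one patient."""
    losses = [visit_loss(y, y_hat) for y, y_hat in zip(targets, predictions)]
    return sum(losses) / len(losses)


good = visit_loss([1, 0], [0.9, 0.1])
bad = visit_loss([1, 0], [0.1, 0.9])
print(good, bad)
```

A prediction that puts most of its mass on the true category yields a smaller loss than one that does not, which is the behavior any variant of this objective should share.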

3.3 Interpretation

In healthcare, the interpretability of the learned representations of medical codes and visits is important. We need to understand the clinical meaning of each dimension of medical code representations, and analyze which visits are crucial to the prediction.

Since the proposed model is based on attention mechanisms, it is easy to find the importance of each visit for prediction by analyzing the attention scores. For the t-th prediction, if the attention score $\alpha_{ti}$ is large, then the (i + 1)-th visit information is highly likely to be related to the current prediction. We employ the simple method proposed in [9] to interpret the code representations. We first use $\mathrm{ReLU}(W_v^\top)$, a non-negative matrix, to represent the medical codes. Then we rank the codes by value in descending order for each dimension of the hidden state vector. Finally, the top k codes with the largest values are selected as follows:

image

where $W_v^\top[:, i]$ represents the i-th column or dimension of $W_v^\top$. By analyzing the selected medical codes, we can obtain the clinical interpretation of each dimension. Detailed examples and analysis are given in Section 4.4 and 4.5.
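The selection step can be sketched as follows. Since $W_v \in \mathbb{R}^{m \times |C|}$, the i-th column of $W_v^\top$ is the i-th row of $W_v$, so the code applies ReLU to that row and returns the k codes with the largest values. The function name `top_k_codes` and the toy weights are illustrative; Eq. (10) itself is not reproduced in this text.

```python
def top_k_codes(W_v, code_names, dim, k):
    """Return the k codes with the largest ReLU(W_v^T)[:, dim] values."""
    column = [max(w, 0.0) for w in W_v[dim]]  # row dim of W_v = column dim of W_v^T
    order = sorted(range(len(column)), key=lambda j: column[j], reverse=True)
    return [code_names[j] for j in order[:k]]


# Toy embedding matrix with m = 2 dimensions and |C| = 3 codes.
W_v = [[0.2, 0.9, -0.5],
       [0.7, 0.1, 0.3]]
code_names = ["250", "254", "11720"]

print(top_k_codes(W_v, code_names, dim=0, k=2))
print(top_k_codes(W_v, code_names, dim=1, k=1))
```

Reading the top codes of one dimension side by side is what allows a clinician to attach a medical concept to that dimension.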

In this section, we evaluate the performance of the proposed Dipole model on two real world insurance claims datasets, compare its performance with other state-of-the-art prediction models, and show that it yields higher accuracy.

4.1 Data Description

The datasets we used in the experiments are the Medicaid claims and the Diabetes claims.

The Medicaid Dataset

Our first dataset consists of Medicaid claims over the year 2011. It consists of data corresponding to 147,810 patients and 1,055,011 visits. The patient visits were grouped by week, and we excluded patients who made fewer than two visits.

The Diabetes Dataset

The Diabetes dataset consists of Medicaid claims over the years 2012 and 2013, corresponding to patients who have been diagnosed with diabetes (i.e., Medicaid members who have the ICD-9 diagnosis code 250.xx in their claims). It contains data corresponding to 22,820 patients with 466,732 visits. The patient visits were aggregated by week, and we excluded patients who made fewer than five visits.

For both datasets, each visit information includes the ICD-9 diagnosis codes and procedure codes, categorized in accordance with the Current Procedural Terminology (CPT). Table 1 lists more details about the two datasets.

Table 1: Statistics of Diabetes and Medicaid Dataset.

image

4.2 Experimental Setup

In this subsection, we first describe the state-of-the-art approaches for EHR representation learning and diagnosis prediction which are used as baselines, and then outline the measures used for evaluation. Finally, we introduce the implementation details.

image

To validate the performance of the proposed model for the diagnosis prediction task, we compare it with several state-of-the-art models. We select three existing approaches as baselines:

Med2Vec [9]. Med2Vec, which follows the idea of Skip-gram [22], is a simple and robust algorithm to efficiently learn medical code representations and predict the medical codes appearing in the following visit based on the current visit information.

RETAIN [11]. RETAIN is an interpretable predictive model in healthcare with a reverse time attention mechanism, a two-level neural attention model. It can find influential past visits and important medical codes within those visits. Since the original RETAIN is used for a binary prediction task, we change the final softmax function to support multiple variable prediction, i.e., diagnosis prediction.

RNN. We first embed visit information into vector representations according to Eq. (1), then feed this embedding to the GRU. The hidden states produced by the GRU are directly used to predict the medical codes of the (t + 1)-th visit using softmax according to Eq. (8). All the parameters are trained together with the GRU.

image

Since all the attention mechanisms proposed in Section 3.2 can be used with the RNN model, we propose three variants of RNN as follows:

$\mathrm{RNN}_l$. We add the location-based attention model to RNN. The attention scores are calculated by Eq. (2). Then we can obtain the context vectors according to Eq. (4). Based on the context vectors, we can generate attentional hidden states using Eq. (7). Finally, we can predict the medical codes of the (t + 1)-th visit using Eq. (8).

$\mathrm{RNN}_g$. $\mathrm{RNN}_g$ is similar to $\mathrm{RNN}_l$, but uses the general attention model, i.e., Eq. (5), to calculate attention scores.

$\mathrm{RNN}_c$. $\mathrm{RNN}_c$ uses the concatenation-based attention mechanism (Eq. (6)) to calculate attention weights.

The proposed Dipole model is a general framework for predicting diagnoses in healthcare. We show the performance of the following four approaches in the experiments.

$\mathrm{Dipole}_-$. This model only uses the hidden states generated by the BRNN to predict the next visit information, i.e., without employing any attention mechanisms.

$\mathrm{Dipole}_l$. It is based on the location-based attention mechanism with Eq. (2).

$\mathrm{Dipole}_g$. $\mathrm{Dipole}_g$ uses the general attention model when calculating the context vectors, i.e., Eq. (5).

$\mathrm{Dipole}_c$. Similar to $\mathrm{Dipole}_l$ and $\mathrm{Dipole}_g$, $\mathrm{Dipole}_c$ employs the concatenation-based attention mechanism (Eq. (6)) in the predictive model.

Evaluation Strategies

To evaluate the performance of predicting future medical codes for each method, we use two measures: accuracy and accuracy@k. Accuracy is defined as the number of correct medical codes ranked in the top k divided by $|y_t|$, where $|y_t|$ is the number of medical codes in the (t + 1)-th visit and k equals $|y_t|$. Accuracy@k is defined as the number of correct medical codes in the top k divided by $\min(k, |y_t|)$. In our experiments, we vary k from 5 to 30.
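Both measures can be written directly from these definitions. In the sketch below, `scores` holds one predicted score per category and `true_labels` is the set of category indices present in the (t + 1)-th visit; the function names and the example numbers are illustrative.

```python
def accuracy_at_k(scores, true_labels, k):
    """Correct codes among the top-k predictions, divided by min(k, |y_t|)."""
    ranked = sorted(range(len(scores)), key=lambda g: scores[g], reverse=True)
    correct = sum(1 for g in ranked[:k] if g in true_labels)
    return correct / min(k, len(true_labels))


def accuracy(scores, true_labels):
    """Accuracy as defined above: accuracy@k with k = |y_t|."""
    return accuracy_at_k(scores, true_labels, len(true_labels))


scores = [0.05, 0.5, 0.3, 0.15]  # toy predicted scores for 4 categories
true_labels = {1, 3}             # categories present in the next visit

print(accuracy(scores, true_labels))
print(accuracy_at_k(scores, true_labels, 3))
```

Dividing by $\min(k, |y_t|)$ keeps the measure within [0, 1] even when k exceeds the number of true codes in the visit.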

Implementation Details

We implement all the approaches with Theano 0.7.0 [4]. For training models, we use Adadelta [30] with a mini-batch of 100 patients. We randomly divide the dataset into the training, validation and testing set in a 0.75:0.1:0.15 ratio. The validation set is used to determine the best values of parameters. We also use regularization ($\ell_2$ norm with the coefficient 0.001) and drop-out strategies (the drop-out rate is 0.5) for all the approaches. In the experiments, we set the same m = 256, p = 256 and q = 128 for baselines and our approaches. We perform 100 iterations and report the best performance for each method.

4.3 Results of Diagnosis Prediction

Table 2 shows the accuracy of the proposed approaches and baselines on both Diabetes and Medicaid datasets for the diagnosis prediction task. #C represents the average number of correct predictions. The number of visits or predictions in the test set is 65,975 in the Diabetes dataset, and 136,023 in the Medicaid dataset.

Table 2: The Accuracy of Diagnosis Prediction Task.

image

In Table 2, we can observe that the accuracy of the proposed approaches, including Dipole and RNN variants, is higher than that of baselines on the Diabetes dataset. Since most medical codes are about diabetes, Med2Vec can correctly learn vector representations on the Diabetes dataset. Thus, Med2Vec achieves the best results among the three baselines. For the Medicaid dataset, the accuracy of RETAIN is better than that of Med2Vec. The reason is that there are many diseases in the Medicaid dataset, and there are more categories of medical codes than in the Diabetes dataset. In this case, the attention mechanism can help RETAIN to learn reasonable parameters and make correct predictions.

The accuracy of RNN is the lowest among all the approaches on both datasets. This is because the prediction of RNN mostly depends on the recent visits’ information. It cannot memorize all the past information. However, RETAIN and the proposed RNN variants, $\mathrm{RNN}_l$, $\mathrm{RNN}_g$ and $\mathrm{RNN}_c$, can fully take all the previous visit information into consideration, assign different attention scores for past visits, and achieve better performance when compared to RNN.

Since most of the visits on the Diabetes dataset are related to diabetes, it is easy to predict the medical codes in the next visit according to the past visit information. RETAIN uses a reverse time attention mechanism for prediction, which will decrease the prediction performance compared with the approaches using a time ordered attention mechanism. Thus, the performance of the three proposed RNN variants is better than that of RETAIN. However, the accuracy of RETAIN is better than that of the proposed RNN variants on the Medicaid dataset, as the data there are about different diseases. Using the reverse time attention mechanism can help us to learn the correct relationships among visits.

Both RNN and the proposed $\mathrm{Dipole}_-$ do not use any attention mechanism, but the accuracy of $\mathrm{Dipole}_-$ is higher than that of RNN on both the Diabetes and Medicaid datasets. It shows that modeling visit information from two directions can improve the prediction performance. Thus, it is reasonable to employ bidirectional recurrent neural networks for the diagnosis prediction task.

The proposed $\mathrm{Dipole}_c$ and $\mathrm{Dipole}_l$ achieve the best performance on the Diabetes and Medicaid datasets, respectively, which shows that both modeling visits from two directions and assigning a different weight to each visit can improve the accuracy for the diagnosis prediction task in healthcare. On the Diabetes dataset, $\mathrm{Dipole}_l$ and $\mathrm{Dipole}_c$ outperform all the baselines and the proposed RNN variants. On the Medicaid dataset, the performance of all three proposed approaches, $\mathrm{Dipole}_l$, $\mathrm{Dipole}_g$ and $\mathrm{Dipole}_c$, is better than that of the baselines and RNN variants.

Table 3 shows the experimental results with the accuracy@k measurement on the Diabetes and Medicaid datasets separately. We can observe that as k increases, the performance of all the approaches improves, but the proposed Dipole approaches show superior predictive performance, demonstrating their applicability in predictive healthcare modeling. In Table 3, RETAIN can achieve comparable performance with the proposed approaches on the Medicaid dataset. The overall performance of location-based attention methods, $\mathrm{Dipole}_l$ and $\mathrm{RNN}_l$, is better than that of other methods, which indicates that location-based attention performs well on this dataset. RETAIN also uses a location-based attention mechanism. Thus, it can obtain high accuracy.

4.4 Case Study

To demonstrate the benefit of applying attention mechanisms in the diagnosis prediction task, we analyze the attention weights learned from one of the proposed approaches, $\mathrm{Dipole}_c$, which uses the concatenation-based attention mechanism. Figure 2 shows a case study for predicting the medical codes in the sixth visit ($y_5$) based on the previous visits on the Diabetes dataset. The concatenation-based attention weights are calculated for the visits from the second visit to the fifth visit according to the hidden states $h_1$, $h_2$, $h_3$ and $h_4$. Thus, we have four attention scores. In Figure 2, the X-axis represents patients, and the Y-axis is the attention weight calculated for each visit. In this case study, we select five patients. We can observe that for different patients, the attention scores learned by the attention mechanism are different.

image

Figure 2: Attention Mechanism Analysis.

To illustrate the correctness of the learned attention weights, we provide an example. For the second patient in Figure 2, we list all the diagnosis codes in Table 4. In order to predict the medical codes in the sixth visit, we first obtain the attention scores α = [0.2386, 0.0824, 0.3028, 0.3762]. Analyzing this attention vector, we can conclude that the medical codes in the second, fourth and fifth visits significantly contribute to the final prediction. From Table 4, we can observe that the patient suffered from essential hypertension in the second and fourth visits, and was diagnosed with diabetes in the fifth visit. Thus, the probability of the sixth visit’s medical codes being about diabetes and diseases related to essential hypertension is high. According to the proposed approach, we can predict the correct diagnoses that this patient suffers from diabetes and hypertensive heart disease. This case study demonstrates that we can learn an accurate attention weight for each visit, and the experimental results in Section 4.3 also illustrate that the appropriate attention models can significantly improve the performance of the diagnosis prediction task in healthcare.

Table 3: The Accuracy@k of Diagnosis Prediction Task.

image

Table 4: Diagnosis Codes in Each Visit for Patient 2 in the Case Study.

image

4.5 Code Representation Analysis

The interpretability of medical codes is important in healthcare. In order to analyze the representations of medical codes learned by the proposed model $\mathrm{Dipole}_g$, we show the top ten diagnosis codes with the largest value in each of six columns selected from the hidden representation matrix $W_v^\top \in \mathbb{R}^{|C| \times m}$ in Table 5. In this way, we can demonstrate the characteristic of each column and map each dimension from the code embedding space to the medical concept.

image

In Table 5, we can clearly observe that the codes in all the six dimensions are about diabetes complications, which are in accordance with the complications listed by the American Diabetes Association. Dimension 10 is related to eye complications and Alzheimer’s disease. Diabetes can damage the blood vessels of the retina (diabetic retinopathy), potentially leading to blindness, and Type 2 diabetes may increase the risk of Alzheimer’s disease. Dimension 38 relates to the complications of neuropathy (nerve damage). Dimension 77 is about heart diseases. It has been established that there is a high correlation between diabetes, heart disease, and stroke. In fact, two out of three patients with diabetes die from heart disease or stroke. Patients with diabetes have a greater risk of depression than people without diabetes. Dimension 79 includes the codes related to mental health. Dimension 141 shows that diabetes may cause skin problems, including bacterial and fungal infections. High blood pressure is one common complication of diabetes shown in dimension 142, which also raises the risk for heart attack, stroke, eye problems, and kidney disease.

4.6 Assumption Validation

In the proposed model, we adopt bidirectional recurrent neural networks, rather than unidirectional ones, to model patient visits. To illustrate the benefit of this choice, we analyze the mean accuracy of RNN and Dipole- in detail, as shown in Figure 3. We first divide patients into different groups based on the number of visits. The group label, shown on the X-axis of Figure 3, is the quotient of the number of visits divided by 15 for the Diabetes dataset and by 7 for the Medicaid dataset. Then we calculate the weighted average accuracy (Y-axis) of each group, i.e., $\frac{\sum_n MA_n \cdot C_n}{\sum_n C_n}$, where $MA_n$ is the mean accuracy of all the patients with $n$ visits, and $C_n$ is the number of patients with $n$ visits. From Figure 3, we can observe that the average accuracy of Dipole- is better than that of RNN across the different groups. On the Diabetes dataset, the weighted mean accuracy of RNN increases as the number of visits grows. This is because the codes of visits on the Diabetes dataset are all about diabetes, and RNN can make correct predictions from the information in recent visits. However, the codes on the Medicaid dataset are related to multiple diseases, and it is hard to correctly predict the future visit information when the sequences are too long. Thus, the weighted mean accuracy drops significantly when the number of visits is large on the Medicaid dataset.

Table 5: Interpretation for Diagnosis Code Representations on the Diabetes Dataset.

Figure 3: Weighted Mean Accuracy of Different Groups.
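The grouping and weighting used for Figure 3 can be made concrete with a short Python sketch. It is an illustration under stated assumptions rather than the authors' evaluation code: the function name, the `(num_visits, accuracy)` input format, and the toy values are hypothetical, and "quotient" is read as integer division.

```python
from collections import defaultdict

def weighted_mean_accuracy_by_group(patients, group_size):
    """Compute sum_n(MA_n * C_n) / sum_n(C_n) within each visit-count group.

    patients: iterable of (num_visits, accuracy) pairs, one per patient
              (assumed input format; accuracy is a per-patient score).
    group_size: 15 for the Diabetes dataset, 7 for the Medicaid dataset.
    """
    # Collect the accuracy sum and patient count C_n for each visit count n.
    acc_sum = defaultdict(float)
    count = defaultdict(int)
    for n, acc in patients:
        acc_sum[n] += acc
        count[n] += 1

    numerator = defaultdict(float)
    denominator = defaultdict(int)
    for n, c_n in count.items():
        ma_n = acc_sum[n] / c_n      # MA_n: mean accuracy of patients with n visits
        group = n // group_size      # group label: quotient of visits by group size
        numerator[group] += ma_n * c_n
        denominator[group] += c_n
    return {g: numerator[g] / denominator[g] for g in sorted(denominator)}

# Toy data: (number of visits, accuracy) for five hypothetical patients.
toy = [(3, 0.5), (3, 1.0), (10, 0.25), (16, 1.0), (20, 0.0)]
print(weighted_mean_accuracy_by_group(toy, group_size=15))
```

Because each $MA_n$ is weighted by $C_n$, the result for a group equals the plain mean accuracy over all patients in that group. Running the function once per model and subtracting the per-group values gives the kind of difference that Figure 4 plots.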

Figure 4 shows the difference in weighted mean accuracy between Dipole- and RNN in different groups. We can observe that as the number of visits increases, the difference also grows dramatically. This demonstrates that bidirectional recurrent neural networks can “remember” more information when the sequences are long, and make correct predictions with their memories. Thus, modeling patient visits with bidirectional recurrent neural networks is reasonable.

Figure 4: Difference of Weighted Mean Accuracy.

5 Conclusions

Diagnosis prediction is a challenging and important task, and interpreting the prediction results is a hard and vital problem for predictive models in healthcare. Much existing work in diagnosis prediction employs deep learning techniques, such as recurrent neural networks (RNNs), to model the temporal and high-dimensional EHR data. However, RNN-based approaches may not fully remember all the previous visit information, which leads to incorrect predictions. To interpret the prediction results, existing work introduces a location-based attention model, but this mechanism ignores the relationships between the current visit and the past visits.

In this paper, we propose a novel model, named Dipole, to address the challenges of modeling EHR data and interpreting the prediction results. By employing bidirectional recurrent neural networks, Dipole can remember the hidden knowledge learned from the previous and future visits. Three attention mechanisms allow us to interpret the prediction results reasonably. Experimental results on two large real-world EHR datasets demonstrate the effectiveness of the proposed Dipole for the diagnosis prediction task. Analysis shows that the attention mechanisms can assign different weights to previous visits when predicting the future visit information. We demonstrate that the learned representations of medical codes are meaningful. Finally, an experiment is conducted to validate the reasonableness and effectiveness of modeling patient visits with bidirectional recurrent neural networks.

Acknowledgments

The authors would like to thank the anonymous referees for their valuable comments and helpful suggestions. This work is supported in part by the US National Science Foundation under grants IIS-1319973, IIS-1553411 and IIS-1514204. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation.

[1] Jimmy Ba, Volodymyr Mnih, and Koray Kavukcuoglu. 2015. Multiple Object Recognition with Visual Atention. In Proceedings of the 3rd International Con- ference on Learning Representations (ICLR’15).

[2] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. 2015. Neural Machine Translation by Jointly Learning to Align and Translate. In Proceedings of the 3rd International Conference on Learning Representations (ICLR’15).

[3] Yoshua Bengio, Patrice Simard, and Paolo Frasconi. 1994. Learning Long-Term Dependencies with Gradient Descent is Difcult. IEEE Transactions on Neural Networks 5, 2 (1994), 157–166.

[4] James Bergstra, Olivier Breuleux, Fr´ed´eric Bastien, Pascal Lamblin, Razvan Pas- canu, Guillaume Desjardins, Joseph Turian, David Warde-Farley, and Yoshua Bengio. 2010. Teano: A CPU and GPU Math Compiler in Python. In Proceed- ings of the 9th Python in Science Conference (SciPy’10). 1–7.

[5] Zhengping Che, David Kale, Wenzhe Li, Mohammad Taha Bahadori, and Yan Liu. 2015. Deep computational phenotyping. In Proceedings of the 21th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD’15). ACM, 507–516.

[6] Zhengping Che, Sanjay Purushotham, Kyunghyun Cho, David Sontag, and Yan Liu. 2016. Recurrent NeuralNetworks for Multivariate Time Series with Missing Values. arXiv preprint arXiv:1606.01865 (2016).

[7] Yu Cheng, Fei Wang, Ping Zhang, and Jianying Hu. 2016. Risk Prediction with Electronic Health Records: A Deep Learning Approach. In Proceedings of the 2016 SIAM International Conference on Data Mining (SDM’16). 432–440.

[8] Kyunghyun Cho, Bart Van Merri¨enboer, Dzmitry Bahdanau, and Yoshua Ben- gio. 2014. On the Properties of Neural Machine Translation: Encoder-Decoder Approaches. arXiv preprint arXiv:1409.1259 (2014).

[9] Edward Choi, Mohammad Taha Bahadori, Elizabeth Searles, Catherine Cof- fey, Michael Tompson, James Bost, Javier Tejedor-Sojo, and Jimeng Sun. 2016.

Multi-layer Representation Learning for Medical Concepts. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD’16). 1495–1504.

[10] Edward Choi, Mohammad Taha Bahadori, Le Song, Walter F Stewart, and Ji- meng Sun. 2017. GRAM: Graph-based Atention Model for Healthcare Representation Learning. In Proceedings of the 23rd ACM SIGKDD International Con- ference on Knowledge Discovery and Data Mining (KDD’17). ACM.

[11] Edward Choi, Mohammad Taha Bahadori, Jimeng Sun, Joshua Kulas, Andy Schuetz, and Walter Stewart. 2016. RETAIN: An Interpretable Predictive Model for Healthcare using Reverse Time Atention Mechanism. In Advances in Neural Information Processing Systems (NIPS’16). 3504–3512.

[12] Edward Choi, Nan Du, Robert Chen, Le Song, and Jimeng Sun. 2015. Constructing Disease Network and Temporal Progression Model via Contextsensitive Hawkes Process. In 2015 IEEE International Conference on Data Mining (ICDM’15). IEEE, 721–726.

[13] Jan K Chorowski, Dzmitry Bahdanau, Dmitriy Serdyuk, Kyunghyun Cho, and Yoshua Bengio. 2015. Atention-based Models for Speech Recognition. In Ad- vances in Neural Information Processing Systems (NIPS’15). 577–585.

[14] Karl Moritz Hermann, Tomas Kocisky, Edward Grefenstete, Lasse Espeholt, Will Kay, Mustafa Suleyman, and Phil Blunsom. 2015. Teaching Machines to Read and Comprehend. In Advances in Neural Information Processing Systems (NIPS’15). 1693–1701.

[15] Sepp Hochreiter and J¨urgen Schmidhuber.1997. Long Short-Term Memory. Neural Computation 9, 8 (1997), 1735–1780.

[16] Peter B Jensen, Lars J Jensen, and Søren Brunak. 2012. Mining Electronic Health Records: Towards Beter Research Applications and Clinical Care. Nature Reviews Genetics 13, 6 (2012), 395–405.

[17] Alex M Lamb, Anirudh Goyal ALIAS PARTH GOYAL, Ying Zhang, Saizheng Zhang, Aaron C Courville, and Yoshua Bengio. 2016. Professor Forcing: A New Algorithm for Training Recurrent Networks. In Advances In Neural Information Processing Systems (NIPS’16). 4601–4609.

[18] Zachary C Lipton, David C Kale, and Randall Wetzel. 2016. Modeling Missing Data in Clinical Time Series with RNNs. In Proceedings of Machine Learning for Healthcare (MLHC’16).

[19] ChuanrenLiu, Fei Wang, Jianying Hu, and Hui Xiong. 2015. Temporal Phenotyp- ing from Longitudinal Electronic Health Records: A Graph based Framework. In Proceedings of the 21th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD’15). ACM, 705–714.

[20] Minh-Tang Luong, Hieu Pham, and Christopher D Manning. 2015. Efective Approaches to Atention-based Neural Machine Translation. In Proceedings of the 2015 Conference on Empirical Methods in Natural Language Processing (EMNLP’15). 1412–1421.

[21] Fenglong Ma, Chuishi Meng, Houping Xiao, Qi Li, Jing Gao, Lu Su, and Aidong Zhang. 2017. Unsupervised Discovery of Drug Side-Efects from Heterogeneous Data Sources. In Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD’17). ACM.

[22] Tomas Mikolov, Ilya Sutskever, Kai Chen, Greg S Corrado, and Jef Dean. 2013. Distributed Representations of Words and Phrases and Teir Compositionality. In  Advances in Neural Information Processing Systems (NIPS’13). 3111–3119.

[23] Phuoc Nguyen, Truyen Tran, Nilmini Wickramasinghe, and Svetha Venkatesh. 2016. Deepr: A Convolutional Net for Medical Records. IEEE Journal of Biomedical and Health Informatics (2016).

[24] Mike Schuster and Kuldip K Paliwal. 1997. Bidirectional Recurrent Neural Net- works. IEEE Transactions on Signal Processing 45, 11 (1997), 2673–2681.

[25] Qiuling Suo, Fenglong Ma, Giovanni Canino, Jing Gao, Aidong Zhang, Pierangelo Veltri, and Agostino Gnasso. 2017. A Multi-task Framework for Monitoring Health Conditions via Atention-based Recurrent Neural Networks. In Proceedings of the AMIA 2017 Annual Symposium (AMIA’17).

[26] Xiang Wang, David Sontag, and Fei Wang. 2014. Unsupervised Learning of Dis- ease Progression Models. In Proceedings of the 20th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD’14). ACM, 85–94.

[27] Houping Xiao, Jing Gao, Long Vu, and Deepak S. Turaga.2017. Learning Tempo- ral State of Diabetes Patients via Combining Behavioral and Demographic Data. In Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD’17). ACM.

[28] Kelvin Xu, Jimmy Ba, Ryan Kiros, Kyunghyun Cho, Aaron Courville, Ruslan Salakhutdinov, Richard S Zemel, and Yoshua Bengio. 2015. Show, Atend and Tell: Neural Image Caption Generation with Visual Atention. In Proceedings of the 32nd International Conference on Machine Learning (ICML’15). CoRR, 2048– 2057.

[29] Qanzeng You, Hailin Jin, Zhaowen Wang, Chen Fang, and Jiebo Luo. 2016. Im- age Captioning with Semantic Atention. In Proceedings of 2016 IEEE Conference on Computer Vision and Patern Recognition (CVPR’16). 4651–4659.

[30] MathewD Zeiler. 2012. ADADELTA: An Adaptive Learning Rate Method. arXiv preprint arXiv:1212.5701 (2012).

[31] Jiayu Zhou, Jimeng Sun, Yashu Liu, Jianying Hu, and Jieping Ye. 2013. Patient Risk Prediction Model via Top-k Stability Selection. In Proceedings of the 13th SIAM International Conference on Data Mining (SDM’13). SIAM, 55–63.

[32] Jiayu Zhou, Fei Wang, Jianying Hu, and Jieping Ye. 2014. From Micro to Macro: Data Driven Phenotyping by Densifcation of Longitudinal Electronic Medical Records. In Proceedings of the 20th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD’14). ACM, 135–144.

[33] Jiayu Zhou, Lei Yuan, Jun Liu, and Jieping Ye. 2011. A Multi-Task Learning Formulation for Predicting Disease Progression. In Proceedings of the 17th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD’11). ACM, 814–822.

