Continual Learning for Domain Adaptation in Chest X-ray Classification
2020·arXiv
Abstract

Over the past years, Deep Learning has been successfully applied to a broad range of medical applications. Especially in the context of chest X-ray classification, results have been reported which are on par with, or even superior to, those of experienced radiologists. Despite this success in controlled experimental environments, it has been noted that the ability of Deep Learning models to generalize to data from a new domain (with potentially different tasks) is often limited. To address this challenge, we investigate techniques from the field of Continual Learning (CL), including Joint Training (JT), Elastic Weight Consolidation (EWC) and Learning Without Forgetting (LWF). Using the ChestX-ray14 and MIMIC-CXR datasets, we demonstrate empirically that these methods provide promising options to improve the performance of Deep Learning models on a target domain and to effectively mitigate catastrophic forgetting for the source domain. Overall, the best performance was obtained using JT, while LWF achieved competitive results, even without accessing data from the source domain.

Keywords: Convolutional Neural Networks, Continual Learning, Catastrophic Forgetting, Chest X-Ray, ChestX-ray14, MIMIC-CXR, Joint Training, Elastic Weight Consolidation, Learning Without Forgetting.

The availability of multiple hospital-scale chest X-ray datasets and advances in the field of Deep Learning have facilitated the development of techniques for automatic image interpretation. Using Convolutional Neural Networks (CNNs), performance levels on par with, or even superior to, those of experienced radiologists have been reported for multiple findings. Following the promising results of CNNs for Pneumonia detection in chest X-rays (Rajpurkar et al., 2017), the success of these methods has been transferred to Cardiomegaly, Edema, and Pleural Effusion (Irvin et al., 2019). More recently, for all findings in the ChestX-ray14 dataset (Wang et al., 2017), a performance similar to that of radiologists was reported (Majkowska et al., 2019). At the same time, it has been noted that these models can suffer substantial performance degradation when applied to samples from another dataset or domain (Zhang et al., 2019; Yao et al., 2019). Within a single chest X-ray study, the data is commonly assumed to be independent and identically distributed across the training and test sets, a standard assumption in Machine Learning. In contrast, when comparing different publicly available chest X-ray datasets, a significant domain bias can often be observed.

Such differences in the data distribution pose a severe challenge for the development, evaluation and validation of medical devices. These domain-dependent distribution shifts could be explained, for example, by hospital-specific processes (including machine protocols and treatment policies), the patient population, and demographic factors (Zhang et al., 2019). Furthermore, the data collection strategy and the employed labeling techniques affect the specific characteristics of a dataset. The development of domain-invariant predictors has received increased interest, including methods based on bias-regularized loss functions and domain augmentation (Zhang et al., 2019) as well as simultaneous training on multiple datasets (Yao et al., 2019). These approaches have shown great promise for mitigating the effect of a domain shift, but they were developed for a one-time optimization prior to model deployment. On the other hand, for most Deep Learning algorithms it is rather straightforward to implement basic functionality that allows the model to learn continuously (even after deployment) and to improve over time, including the adaptation to new domains and new tasks. This has also been noted by the FDA in a recent discussion of the regulatory implications of modifications to AI/ML-based medical devices, in contrast to "locked" software (U.S. Food & Drug Administration, 2019).

Therefore, in this contribution, we approach this challenge from a different perspective, i.e., using methods from the field of Continual Learning (CL). Traditionally, CL considers methods for the sequential learning of individual tasks (Parisi et al., 2019), a concept with great potential for the adaptation of chest X-ray models to a new domain. However, a fundamental problem in CL is catastrophic forgetting (McCloskey and Cohen, 1989), i.e., the performance degradation on previously learned tasks when a model is adapted to a new task. For chest X-ray classification, this could result not only in a reduced detection performance for findings unique to the source domain; the model could also unlearn to classify data from the source domain in general. Baseline techniques such as Joint Training (JT) try to alleviate this problem by integrating data from the source domain into the learning process, an approach which is not always feasible when sensitive healthcare information is involved. Regularization-based CL techniques, such as Elastic Weight Consolidation (EWC) (Kirkpatrick et al., 2017) and Learning Without Forgetting (LWF) (Li and Hoiem, 2017), introduce prior information or soft targets in order to avoid the need to store old data. To evaluate the feasibility of CL techniques, we assess their performance in an empirical study using the ChestX-ray14 and MIMIC-CXR (Johnson et al., 2019) datasets.

In this section, we provide a brief summary of three Continual Learning concepts: JT, EWC and LWF. The latter two methods follow simple regularization paradigms and do not require storing training data from previous tasks or domains. All methods are easy to implement and do not entail a large computational overhead compared to the original model training. Following the conventions of the CL literature, we use a task-centered formalism to describe the methods. In our chest X-ray scenario described above, the first and second tasks correspond to solving the ChestX-ray14 and MIMIC-CXR classification problems, respectively. In the rapidly growing field of CL, other methods have been proposed, for example relying on episodic memory, generative models or architectural changes of the network. For a broader overview, we refer the interested reader to (Parisi et al., 2019), (Caruana, 1997) and the references therein.

2.1. Joint Training (JT)

Suppose that $T_i = \{(x_{i,j}, y_{i,j})\}_{j=1,\dots,N_i}$, $i \in I$, is a sequence of tasks, where $x_{i,j}$ denotes the $j$-th sample of task $i$ and $y_{i,j}$ the corresponding label. A neural network with weight vector $\theta$ is used to model the predictive distribution $p(y \mid \theta, x)$ of unobserved labels $y$ associated with observed samples $x$. The model fit is typically conducted by empirical risk minimization. Hence, for each individual task $T_i$, the task-specific optimal weight vector $\theta_i$ is obtained by solving a minimization problem of the type

$$\theta_i \in \arg\min_{\theta} \sum_{j=1}^{N_i} -\log p(y_{i,j} \mid \theta, x_{i,j}). \tag{1}$$

A joint training strategy (JT) aims at improving the model performance on different tasks simultaneously by combining the task-specific training datasets. For example, given a subset of tasks $J \subset I$, the optimal weight vector $\theta_J$ on the combined task $T_J := \bigcup_{i \in J} T_i$ is obtained by solving a minimization problem of the type

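$$\theta_J \in \arg\min_{\theta} \sum_{i \in J} \sum_{j=1}^{N_i} -\log p(y_{i,j} \mid \theta, x_{i,j}), \tag{2}$$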

allowing the learning process to exploit commonalities and differences across tasks. This can improve the predictive performance compared to training multiple task-specific models separately, or to training a single model in a simple sequential fashion, which is prone to catastrophic forgetting, cf. (Caruana, 1997). Unfortunately, in many real-world scenarios, the aggregation of large heterogeneous training datasets (e.g., for chest X-ray classification) is subject to various limitations. In particular, task-specific data used for model training may no longer be available at a later time point, when data associated with a new task is obtained and model fine-tuning becomes necessary. This situation occurs, for example, in a clinical setting when an already deployed model needs to be adjusted to data acquired on-site.
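The JT-p% variants used in the experiments amount to a data-mixing step before training. The following Python sketch assumes that the source fraction is drawn uniformly at random from the ChestX-ray14 training samples (whether sampling happened per image or per patient is an assumption here); the function name joint_training_set is hypothetical, and string identifiers stand in for actual images.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def joint_training_set(
    source: Sequence[T], target: Sequence[T], source_fraction: float, seed: int = 0
) -> List[Tuple[str, T]]:
    """Return all target samples plus a random fraction of the source samples.

    source_fraction=0.0 reproduces plain fine-tuning on the target (JT-0%),
    source_fraction=1.0 the full union of both training sets (JT-100%).
    """
    if not 0.0 <= source_fraction <= 1.0:
        raise ValueError("source_fraction must lie in [0, 1]")
    rng = random.Random(seed)
    n_source = round(source_fraction * len(source))
    kept_source = rng.sample(list(source), n_source)
    # Tag every sample with its domain so that, during training, the loss can
    # be restricted to the labels annotated in that sample's dataset.
    combined = [("source", s) for s in kept_source]
    combined += [("target", t) for t in target]
    rng.shuffle(combined)
    return combined


# Usage with placeholder identifiers instead of real images.
cxr14_ids = [f"cxr14_{k}" for k in range(10)]
mimic_ids = [f"mimic_{k}" for k in range(4)]
jt20 = joint_training_set(cxr14_ids, mimic_ids, source_fraction=0.2)
print(len(jt20))  # 2 source + 4 target samples -> 6
```

Keeping the domain tag with each sample is one simple way to compute the loss only on the labels of the sample's own dataset, as done in the experiments for all scenarios except LWF.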

2.2. Elastic Weight Consolidation (EWC)

Various CL approaches for explicitly modeling cross-correlations between distinct tasks have been proposed. Elastic weight consolidation assumes a prior distribution $p(\theta \mid T_{i-1})$ on the network weights $\theta$ during the model adaptation to task $T_i$. The prior $p(\theta \mid T_{i-1})$ is selected such that it captures basic statistical properties of the empirical distribution of the network weights associated with the previous task $T_{i-1}$. Finally, the optimal parameter for the current task $T_i$ is obtained as the maximum a posteriori estimate

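$$\theta_i \in \arg\max_{\theta} \Big[ \sum_{j=1}^{N_i} \log p(y_{i,j} \mid \theta, x_{i,j}) + \log p(\theta \mid T_{i-1}) \Big]. \tag{3}$$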

In contrast to memory-based methods, EWC acts as a simple regularizer on the training objective and does not rely on storing any additional data associated with previous tasks. The key assumption of EWC is that enough information about previous tasks can be encoded in the prior distribution over the model weights to prevent a severe performance degradation when moving to a new task. Owing to their computational tractability, frequent choices for $p(\theta \mid T_{i-1})$ are multivariate Laplace or Gaussian distributions. In the Gaussian case $p(\theta \mid T_{i-1}) = \mathcal{N}(\theta \mid \mu_{i-1}, \Sigma_{i-1})$, we obtain

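$$\theta_i \in \arg\min_{\theta} \sum_{j=1}^{N_i} -\log p(y_{i,j} \mid \theta, x_{i,j}) + \frac{\lambda}{2} (\theta - \mu_{i-1})^\top \Sigma_{i-1}^{-1} (\theta - \mu_{i-1}) \tag{4}$$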

with a constant $\lambda > 0$ that regulates the impact of the prior. Choosing the parameters $\mu_{i-1} = \theta_{i-1}$ and $\Sigma_{i-1}^{-1} = \operatorname{diag}(F_{i-1})$, where $F_{i-1}$ denotes the empirical Fisher matrix associated with task $T_{i-1}$, i.e.,

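$$F_{i-1} = \frac{1}{N_{i-1}} \sum_{j=1}^{N_{i-1}} \nabla_\theta \log p(y_{i-1,j} \mid \theta_{i-1}, x_{i-1,j}) \, \nabla_\theta \log p(y_{i-1,j} \mid \theta_{i-1}, x_{i-1,j})^\top, \tag{5}$$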

yields the EWC objective of (Kirkpatrick et al., 2017). It is well known that, under mild regularity assumptions, (5) constitutes an approximation to the empirical Hessian of the negative log-likelihood (NLL) with respect to $\theta$, i.e.,

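$$F_{i-1} \approx -\frac{1}{N_{i-1}} \sum_{j=1}^{N_{i-1}} \nabla_\theta^2 \log p(y_{i-1,j} \mid \theta_{i-1}, x_{i-1,j}) \tag{6}$$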

holds. Consequently, the entries of $\operatorname{diag}(F_{i-1})$ may be considered approximations to the non-mixed second derivatives of the NLL, which reflect to some extent the sensitivity of the model output to marginal changes in the network weights. As argued in (Kirkpatrick et al., 2017), second derivatives of large magnitude indicate a high importance of the corresponding model parameter for solving task $T_{i-1}$. Consequently, the quadratic penalty term in (4) discourages strong deviations from the previous task's parameter $\theta_{i-1}$ along the sensitive directions of the weight space.
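To make the empirical Fisher matrix (5) and the quadratic penalty in (4) concrete, the following NumPy sketch computes the diagonal Fisher entries for a toy logistic-regression model, whose per-sample score has the closed form (y - sigmoid(θᵀx)) x. This toy model is an illustrative assumption: for a DenseNet121 the per-sample gradients would instead come from automatic differentiation. All function names and the synthetic data are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def per_sample_scores(theta, X, y):
    """Rows are grad_theta log p(y_j | theta, x_j) for a logistic model.

    With p(y=1 | theta, x) = sigmoid(theta @ x), the score is (y - sigmoid(theta @ x)) * x.
    """
    residual = y - sigmoid(X @ theta)  # shape (N,)
    return residual[:, None] * X       # shape (N, D)


def diagonal_empirical_fisher(theta, X, y):
    """Main diagonal of the empirical Fisher matrix: mean of the squared scores."""
    return np.mean(per_sample_scores(theta, X, y) ** 2, axis=0)


def ewc_penalty(theta, theta_prev, fisher_diag, lam):
    """(lam / 2) * sum_k F_kk * (theta_k - theta_prev_k)^2, i.e. mu = theta_prev, Sigma^{-1} = diag(F)."""
    return 0.5 * lam * np.sum(fisher_diag * (theta - theta_prev) ** 2)


# Toy "previous task": 200 samples, 3 features, labels depend on feature 0 only.
rng = np.random.default_rng(0)
X_prev = rng.normal(size=(200, 3))
y_prev = (X_prev[:, 0] > 0).astype(float)
theta_prev = np.array([2.0, 0.0, 0.0])  # stands in for the trained weights theta_{i-1}

F_diag = diagonal_empirical_fisher(theta_prev, X_prev, y_prev)
print(F_diag)
print(ewc_penalty(theta_prev + 0.1, theta_prev, F_diag, lam=1.0))
```

During adaptation to the new task, ewc_penalty would simply be added to the new task's loss, so weights with large Fisher entries are pulled back towards theta_prev more strongly than the others.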

In our experiments, we used the Gaussian EWC objective (4) with a binarized Fisher information. That is, given a threshold $\rho > 0$, we chose the binary inverse covariance matrix $\Sigma_{i-1}^{-1} = \operatorname{diag}(F_{i-1} > \rho)$, where the comparison is applied entry-wise to the main diagonal of $F_{i-1}$ and yields ones and zeros. Consequently, all network parameters with a sensitivity below $\rho$ are not affected by the regularization, while all other parameters are shrunk uniformly towards $\mu_{i-1}$ with rate $\lambda$. We found it useful to select $\rho$ based on the distribution of the main diagonal entries of $F_{i-1}$. For example, setting $\rho$ to the 95% quantile imposes a uniform regularization on the 5% most sensitive network weights. The intuition behind this binarized EWC version is simple: we decompose the weight space of the neural network into a subspace spanned by the sensitive dimensions and its complement, and then apply a uniform L2 regularization to the weight vector projected onto the "sensitive" subspace. The computational overhead of this binary EWC is lower than that of classic EWC. In summary, by imposing a prior $p(\theta \mid T_{i-1})$ on the model weights, deviations from $\theta_{i-1}$ are penalized while learning the parameter $\theta_i$ of task $T_i$. The magnitude of the penalty depends on the choice of the prior. For example, prior distributions which are highly concentrated at $\theta_{i-1}$ may severely constrain the flexibility of the model to adapt to the new task $T_i$ in favor of preserving the model performance on $T_{i-1}$. Elastic weight consolidation acts as a regularizer on the current task's model weights and does not require storing the training data of previous tasks.
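The binarized variant only changes how the diagonal of the inverse covariance is formed. The following minimal NumPy sketch assumes that the Fisher diagonals of all network parameters have already been flattened into one vector; the helper names are hypothetical. The threshold can be passed directly (the experiments use ρ = 0.001) or derived from a quantile of the Fisher entries.

```python
import numpy as np


def binarized_precision(fisher_diag, rho=None, quantile=None):
    """Binary main diagonal of Sigma^{-1}: 1.0 where F_kk > rho, else 0.0.

    Pass either an absolute threshold rho (e.g. 0.001) or a quantile of the
    Fisher diagonal (e.g. 0.95) from which rho is derived.
    """
    if (rho is None) == (quantile is None):
        raise ValueError("specify exactly one of rho or quantile")
    if quantile is not None:
        rho = np.quantile(fisher_diag, quantile)
    return (fisher_diag > rho).astype(float)


def binary_ewc_penalty(theta, theta_prev, mask, lam):
    """Uniform L2 pull towards theta_prev, applied to the 'sensitive' weights only."""
    return 0.5 * lam * np.sum(mask * (theta - theta_prev) ** 2)


# Placeholder sensitivities for 100 weights; a real network would first
# flatten and concatenate the per-parameter Fisher diagonals.
fisher = np.linspace(0.0, 1.0, 100)
mask = binarized_precision(fisher, quantile=0.95)
print(int(mask.sum()))  # the 5 largest entries are regularized -> 5
```

Because the mask is binary, the penalty reduces to an ordinary L2 term on a fixed subset of weights, which is the uniform shrinkage described above.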

2.3. Learning Without Forgetting (LWF)

The key idea of the Learning Without Forgetting method is to introduce a soft-target regularization into the training loss of the current task, which reflects the behavior of the previous task's model on the dataset at hand.

In more detail: when moving to a new task $T_i = \{(x_{i,j}, y_{i,j})\}_{j=1,\dots,N_i}$, we apply the previous model $M_{\theta_{i-1}}$, which was trained on $T_{i-1}$, to the current task's training samples $x_{i,j}$ in order to generate "synthetic labels" $\hat{y}_{i,j} := M_{\theta_{i-1}}(x_{i,j})$ that record the model behavior. Note that, depending on the implementation, the raw model outputs $\hat{y}_{i,j}$ are float-valued tensors rather than integer class assignments. By adding a regularization term to the loss functional (1), a bias towards consistent behavior of the models $M_{\theta_i}$ and $M_{\theta_{i-1}}$ on the current task's training samples is introduced. The optimal model weight $\theta_i$ for task $T_i$ is then obtained by solving a minimization problem of the type

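$$\theta_i \in \arg\min_{\theta} \sum_{j=1}^{N_i} \Big[ -\log p(y_{i,j} \mid \theta, x_{i,j}) + \lambda \, \ell\big(\hat{y}_{i,j}, M_{\theta}(x_{i,j})\big) \Big], \tag{7}$$

where $\ell$ denotes a soft-target loss comparing the recorded outputs of the previous model with those of the current model.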

Increasing the parameter $\lambda > 0$ decreases the relevance of the "hard labels" $y_{i,j}$ associated with $T_i$ and instead rewards output patterns that are consistent with the previous model. For a detailed discussion of LWF in the classification setting, we refer the reader to (Li and Hoiem, 2017). This basic concept can be implemented and extended in various ways. For example, in the classification setting, the soft-target concept can be used to fill in missing labels when fine-tuning a model on a new dataset for which only partial annotations are available. Similar to EWC, this approach acts as a mere regularizer on the current task's model weights, and access to the previous task's training data is not required.
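A minimal NumPy sketch of the LWF objective for sigmoid (multi-label) outputs follows. It assumes binary cross-entropy for both the hard-label term and the soft-target term; the original formulation of Li and Hoiem uses a temperature-scaled knowledge-distillation loss on softmax outputs instead, so this is an adaptation rather than their exact loss. Function names and the toy batch are hypothetical.

```python
import numpy as np


def bce(probs, targets, eps=1e-7):
    """Mean binary cross-entropy; targets may be hard {0, 1} labels or soft values in [0, 1]."""
    probs = np.clip(probs, eps, 1.0 - eps)
    return float(-np.mean(targets * np.log(probs) + (1.0 - targets) * np.log(1.0 - probs)))


def lwf_objective(probs_hard, hard_labels, probs_soft, soft_targets, lam):
    """Hard-label loss on the current task plus lam times a soft-target loss.

    soft_targets are the recorded outputs y_hat = M_{theta_{i-1}}(x) of the previous
    model on the *current* training images and are treated as constants.
    probs_hard / probs_soft are the current model's sigmoid outputs for the labels
    supervised by hard labels and by soft targets, respectively.
    """
    return bce(probs_hard, hard_labels) + lam * bce(probs_soft, soft_targets)


# Toy batch: 2 images, 2 hard-labelled outputs and 1 soft-target output each.
probs_hard = np.array([[0.9, 0.2], [0.3, 0.6]])
hard_labels = np.array([[1.0, 0.0], [0.0, 1.0]])
probs_soft = np.array([[0.4], [0.7]])
soft_targets = np.array([[0.35], [0.8]])  # placeholder outputs of the previous model
print(lwf_objective(probs_hard, hard_labels, probs_soft, soft_targets, lam=2.0))
```

Binary cross-entropy with soft targets is minimized when the predicted probabilities equal the targets, so a larger λ pulls the new model's outputs towards those of the previous model.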

2.4. Datasets

In the following, we consider the datasets ChestX-ray14 (Wang et al., 2017) and MIMIC-CXR (Johnson et al., 2019). The ChestX-ray14 data was released in 2017 by the NIH Clinical Center and consists of 112120 chest X-ray images (AP/PA) from 30805 patients. The images were annotated with respect to 14 different findings using an NLP-based analysis of the radiology reports (with an additional "No Findings" label, which is typically not considered).

The MIMIC-CXR dataset (consortium version v2.0.0) consists of X-ray images (DICOM) and radiology reports from the Beth Israel Deaconess Medical Center in Boston. For model training and evaluation, we filtered the DICOM images (based on the DICOM attributes ImageType, PresentationIntentType, PhotometricInterpretation, BodyPartExamined, ViewPosition and PatientOrientation), resulting in a dataset with 226483 images from 62568 patients. To generate annotations, we applied the CheXpert labeler to the impression section of the reports, yielding annotations for 13 findings and a "No Finding" label (Irvin et al., 2019).¹ In contrast to the ChestX-ray14 dataset, no official train/test split is available for MIMIC-CXR. Therefore, we randomly selected 80% of the patients for training, while the remaining 20% were assigned to the test split. For the following experiments, it is assumed that matching labels (including "Effusion" and "Pleural Effusion") represent comparable concepts in both datasets. Consequently, we consider a total of 21 labels: 7 findings unique to each dataset and 7 findings occurring in both datasets.
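Splitting at the patient level keeps all images of one patient in the same partition, which avoids leakage between training and test data. The following is a minimal sketch of a random 80/20 patient split, assuming a mapping from image to patient identifiers is available; the function name, the seed and the toy identifiers are hypothetical.

```python
import random
from typing import Dict, Hashable, List, Sequence, Tuple


def patient_level_split(
    image_ids: Sequence[Hashable],
    patient_of: Dict[Hashable, Hashable],
    train_fraction: float = 0.8,
    seed: int = 0,
) -> Tuple[List[Hashable], List[Hashable]]:
    """Randomly assign whole patients to train/test so no patient appears in both."""
    # dict.fromkeys keeps first-occurrence order, giving a deterministic patient list.
    patients = list(dict.fromkeys(patient_of[i] for i in image_ids))
    rng = random.Random(seed)
    rng.shuffle(patients)
    n_train = round(train_fraction * len(patients))
    train_patients = set(patients[:n_train])
    train = [i for i in image_ids if patient_of[i] in train_patients]
    test = [i for i in image_ids if patient_of[i] not in train_patients]
    return train, test


# Toy data: 10 patients with 3 images each.
patient_of = {f"img{k}": f"pat{k // 3}" for k in range(30)}
train_ids, test_ids = patient_level_split(list(patient_of), patient_of)
print(len(train_ids), len(test_ids))  # 24 6
```

Calling the same function on the training images with train_fraction=0.9 and a different seed per repetition would be one way to hold out roughly 10% of the patients for validation; this usage is an assumption, not a description of the original code.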

2.5. Experimental Design

In order to investigate the impact of a domain shift in the data distribution and the potential benefit of the CL methods outlined in Sections 2.1, 2.2 and 2.3, a set of networks was adapted first to ChestX-ray14 and subsequently to MIMIC-CXR. To this end, a pre-trained DenseNet121 (Huang et al., 2017) was selected as a starting point, as it is one of the most commonly employed network architectures in the X-ray domain. To account for the changed number of labels and the multi-label classification task, the last layer was replaced by a randomly initialized linear layer followed by a sigmoid activation function. For the first and second adaptation steps, a similar hyper-parameter setup was employed: binary cross-entropy was used as the loss function, and for all training scenarios except LWF, the computation of the loss (training and validation) was restricted to the labels of the current domain. Stochastic gradient descent with momentum was used as the update rule, with an initial learning rate of 0.01, a momentum of 0.9 and a mini-batch size of 16. For the adaptation to ChestX-ray14, an L2 weight decay of 0.0001 was employed, whereas for the MIMIC-CXR task weight decay was disabled. After each epoch, the learning rate was reduced by a factor of 10 if the validation loss did not improve. During training, the images in a mini-batch were subject to data augmentation with a probability of 90%. Our data augmentation included common strategies such as scaling (±15%), rotation around the image center (±5°), translation relative to the image extent (±10%), as well as mirroring along the midsagittal plane (50% chance). Finally, all images were rescaled to 224 × 224 pixels to match the input size of the DenseNet121 architecture. After training, the network with the lowest validation loss was used to process the test dataset.
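Two parts of this setup are easy to misread: the epoch-wise learning-rate rule and the stochastic augmentation. The following pure-Python sketch mirrors the stated rule (divide the learning rate by 10 whenever the validation loss does not improve) and draws augmentation parameters in the stated ranges. It assumes that augmentation is decided per image, that parameters are sampled uniformly, and that all transformations are applied jointly; the class and function names are hypothetical.

```python
import random


class ReduceLrOnNoImprovement:
    """Divide the learning rate by 10 whenever the epoch's validation loss does not improve."""

    def __init__(self, lr: float = 0.01, factor: float = 0.1):
        self.lr = lr
        self.factor = factor
        self.best = float("inf")

    def step(self, val_loss: float) -> float:
        if val_loss < self.best:
            self.best = val_loss
        else:
            self.lr *= self.factor
        return self.lr


def sample_augmentation(rng: random.Random, p_apply: float = 0.9):
    """Draw random augmentation parameters, or None if the image stays unchanged."""
    if rng.random() >= p_apply:
        return None
    return {
        "scale": rng.uniform(0.85, 1.15),        # +/-15 % scaling
        "rotation_deg": rng.uniform(-5.0, 5.0),  # rotation about the image center
        "shift_x": rng.uniform(-0.10, 0.10),     # fraction of the image width
        "shift_y": rng.uniform(-0.10, 0.10),     # fraction of the image height
        "mirror": rng.random() < 0.5,            # flip along the midsagittal plane
    }


scheduler = ReduceLrOnNoImprovement()
for epoch, val_loss in enumerate([0.40, 0.35, 0.36, 0.34]):
    lr = scheduler.step(val_loss)
    # In epoch 2 the loss does not improve, so the learning rate drops by a factor of 10.
    print(f"epoch {epoch}: val_loss={val_loss:.2f} -> lr={lr:g}")

print(sample_augmentation(random.Random(0)))
```

In PyTorch, torch.optim.lr_scheduler.ReduceLROnPlateau with mode="min", factor=0.1 and patience=0 behaves similarly, except that by default it only counts a relative improvement larger than 1e-4 as an improvement.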

The ChestX-ray14 model was adapted to the MIMIC-CXR dataset using four different training strategies:

1. A standard fine-tuning of the networks using the MIMIC-CXR data only.

2. A JT setup where 20%, ..., 100% of the ChestX-ray14 data was included in the adaptation process in addition to the MIMIC-CXR data, cf. Section 2.1.

3. Fine-tuning on the MIMIC-CXR data using (binary) EWC regularization with a Gaussian prior distribution on the model weights and an impact of $\lambda = 0.001$, cf. Section 2.2. The mean of the prior was set to the parameter vector of the ChestX-ray14 model. As the inverse covariance matrix in the EWC objective (4), we chose the binarized diagonal empirical Fisher matrix calculated over the ChestX-ray14 training samples with a sensitivity threshold of $\rho = 0.001$.

4. Fine-tuning on the MIMIC-CXR data using LWF regularization with an impact parameter $\lambda = 2.0$, cf. Section 2.3. The LWF penalty was only applied to the 7 labels not present in the MIMIC-CXR dataset, i.e., soft targets were generated for the labels Emphysema, Fibrosis, Hernia, Infiltration, Mass, Nodule and Pleural Thickening. Hence, in the LWF setting all 21 labels from both domains are considered during the model adaptation to the MIMIC-CXR domain. The validation loss is only computed on the domain-specific validation data containing 14 labels.

All experiments were repeated 5 times with resampled validation sets (using 10% of all patients). Our experiments were conducted using PyTorch 1.1 on a machine with one Nvidia GeForce RTX 2080 Ti graphics card.

Our quantitative results, in terms of average AUC values for each finding along with their standard deviations, are summarized in Table 1. The upper row gives the model performance on the ChestX-ray14 dataset, while the bottom row corresponds to the performance on the MIMIC-CXR dataset. The left column (Initial) indicates the performance after the initial training on ChestX-ray14, whereas the right columns (JT-0%, JT-20%, ..., LWF) contain the results after the model adaptation to MIMIC-CXR. When applying the models trained on ChestX-ray14 directly to the MIMIC-CXR data, a decreased performance for the classes Cardiomegaly, Edema, Pneumonia and Pneumothorax can be observed. This indicates that the source domain training data is not sufficiently representative of the target domain data distribution. The strongest decrease is observed for Cardiomegaly, with a drop in mean AUC from 0.8806 to 0.7603. For the classes Atelectasis, Consolidation and Effusion, the performance on the target domain is comparable or even slightly superior, see the lower left quadrant of Table 1. As a consequence of the domain shift, the average AUC across all labels decreases from 0.8106 to 0.7833, making model adaptation unavoidable. The lower right quadrant shows that all CL methods achieve a strong in-domain performance on the MIMIC-CXR data, with average AUC values across all findings ranging from 0.8190 to 0.8257. In particular, this indicates that both regularization approaches (LWF and EWC) still leave the model enough flexibility to adjust to the new domain.

However, a simple adaptation to MIMIC-CXR without any CL strategy (JT-0%) leads to a decrease of the mean AUCs on the ChestX-ray14 domain for all classes except Infiltration, Pneumonia and Pneumothorax. The effect of catastrophic forgetting becomes more evident in Figure 1, which depicts the (averaged) Forward Transfer (FWT) and Backward Transfer (BWT) for all findings. These concepts were introduced by (Lopez-Paz and Ranzato, 2017) to measure the knowledge transfer across a sequence of tasks.² The BWT measures the change in model performance on a task $T_i$ after adapting to a new task $T_{i+1}$. In detail, for each individual label, the BWT is computed by subtracting the AUC on task $T_i$ before adapting the model to $T_{i+1}$ from the AUC on task $T_i$ after this adaptation. A negative BWT is often associated with catastrophic forgetting; conversely, a positive BWT is obtained if the performance on the previous task increases. Similarly, the FWT measures the effect of learning a task $T_i$ on the performance on a future task $T_{i+1}$ which was not seen during training. In detail, for each individual label, the FWT is computed by subtracting 0.5 (the AUC of a random classifier) from the AUC that the task $T_i$ model achieves on task $T_{i+1}$ without being adapted to it. While the ChestX-ray14 models achieve a moderate FWT on MIMIC-CXR, the low BWT indicates a considerable drop in performance on ChestX-ray14 after the adaptation (JT-0%).

Figure 1: Left: Backward Transfer (BWT) on ChestX-ray14 after adaptation using different Continual Learning (CL) techniques. Right: Forward Transfer (FWT) of a ChestX-ray14 model on MIMIC-CXR. Bars indicate min, mean and max.

Integrating data from ChestX-ray14 into the training on the new domain mitigates this effect (JT-20%, ..., JT-100%). We observe that the BWT is positively correlated with the amount of additional samples from ChestX-ray14. Not surprisingly, the best model performance is achieved on the combined dataset containing all training samples from both domains (JT-100%). As argued above, in real-world scenarios access to old training data might be limited or not possible at all. Consequently, the regularization-based methods LWF and EWC, which do not rely on storing data from previous tasks or domains, are of high practical relevance. In our experiments, LWF outperformed the EWC approach and achieved a performance on the original domain between that of JT-60% and JT-80% (and superior to the original model) without accessing any data from ChestX-ray14.
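The per-label transfer measures reduce to simple differences of AUC values. In the following sketch, the FWT example reuses the Cardiomegaly AUC of 0.7603 reported above for the unadapted ChestX-ray14 model on MIMIC-CXR, while the BWT inputs are made-up placeholder values, not results from Table 1; the function names are hypothetical. The min_mean_max helper corresponds to the min, mean and max bars of Figure 1.

```python
from typing import Dict, Tuple


def backward_transfer(auc_before: Dict[str, float], auc_after: Dict[str, float]) -> Dict[str, float]:
    """Per-label BWT on the old task: AUC after adapting to the new task minus AUC before."""
    return {label: auc_after[label] - auc_before[label] for label in auc_before}


def forward_transfer(auc_unadapted: Dict[str, float], chance: float = 0.5) -> Dict[str, float]:
    """Per-label FWT: AUC of the unadapted old-task model on the new task minus chance level."""
    return {label: auc - chance for label, auc in auc_unadapted.items()}


def min_mean_max(per_label: Dict[str, float]) -> Tuple[float, float, float]:
    """Summary statistics over labels, as shown by the bars in Figure 1."""
    values = list(per_label.values())
    return min(values), sum(values) / len(values), max(values)


# FWT for Cardiomegaly, using the reported mean AUC of 0.7603 on MIMIC-CXR.
print(forward_transfer({"Cardiomegaly": 0.7603}))

# Hypothetical ChestX-ray14 AUCs before/after adaptation (placeholders, not paper results).
bwt = backward_transfer({"Mass": 0.80, "Nodule": 0.75}, {"Mass": 0.76, "Nodule": 0.76})
print(min_mean_max(bwt))
```

A negative entry in bwt signals forgetting on that label, and a positive entry signals that adapting to the new domain also helped on the old one.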

In this paper, we investigated the applicability of different Continual Learning methods for domain adaptation in chest X-ray classification. To this end, a DenseNet121 was trained on ChestX-ray14 and subsequently fine-tuned on MIMIC-CXR using different Continual Learning strategies (JT, EWC, LWF) in order to adapt to the new domain without severe performance degradation on the original data. These datasets were chosen as distinct domains to simulate a realistic domain shift as encountered in clinical practice. Our quantitative evaluation, including the measurement of Backward and Forward Transfer, confirmed that employing these methods improves the overall model performance compared to a simple continuation of the model training on the new domain. The best performance was achieved by JT-100%, i.e., training the model on the entire combined datasets from both domains. However, in real-world scenarios, e.g., when adapting models that are already deployed in the clinic, it is questionable whether the data used for training the original model is always accessible, for legal and privacy reasons. Hence, the EWC and LWF methods, which do not rely on old training samples, are of high practical relevance. Our experiments indicate that these regularization techniques indeed allow a model adaptation to the target domain while preserving a performance on the original domain that is still close to the JT baseline.


Rich Caruana. Multitask learning. Machine Learning, 28(1):41–75, 1997.

Gao Huang, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q Weinberger. Densely connected convolutional networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 4700–4708, 2017.

Jeremy Irvin, Pranav Rajpurkar, Michael Ko, Yifan Yu, Silviana Ciurea-Ilcus, Chris Chute, Henrik Marklund, Behzad Haghgoo, Robyn Ball, Katie Shpanskaya, et al. CheXpert: A large chest radiograph dataset with uncertainty labels and expert comparison. arXiv preprint arXiv:1901.07031, 2019.

A. Johnson, T. Pollard, R. Mark, S. Berkowitz, and S. Horng. MIMIC-CXR database. PhysioNet, 2019.

James Kirkpatrick, Razvan Pascanu, Neil Rabinowitz, Joel Veness, Guillaume Desjardins, Andrei A Rusu, Kieran Milan, John Quan, Tiago Ramalho, Agnieszka Grabska-Barwinska, et al. Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences, 114(13):3521–3526, 2017.

Zhizhong Li and Derek Hoiem. Learning without forgetting. IEEE Transactions on Pattern Analysis and Machine Intelligence, 40(12):2935–2947, 2017.

David Lopez-Paz and Marc'Aurelio Ranzato. Gradient episodic memory for continual learning. In Advances in Neural Information Processing Systems, pages 6467–6476, 2017.

Anna Majkowska, Sid Mittal, David F Steiner, Joshua J Reicher, Scott Mayer McKinney, Gavin E Duggan, Krish Eswaran, Po-Hsuan Cameron Chen, Yun Liu, Sreenivasa Raju Kalidindi, et al. Chest radiograph interpretation with deep learning models: Assessment with radiologist-adjudicated reference standards and population-adjusted evaluation. Radiology, page 191293, 2019.

Michael McCloskey and Neal J Cohen. Catastrophic interference in connectionist networks: The sequential learning problem. In Psychology of Learning and Motivation, volume 24, pages 109–165. Elsevier, 1989.

German I Parisi, Ronald Kemker, Jose L Part, Christopher Kanan, and Stefan Wermter. Continual lifelong learning with neural networks: A review. Neural Networks, 2019.

Pranav Rajpurkar, Jeremy Irvin, Kaylie Zhu, Brandon Yang, Hershel Mehta, Tony Duan, Daisy Ding, Aarti Bagul, Curtis Langlotz, Katie Shpanskaya, et al. CheXNet: Radiologist-level pneumonia detection on chest X-rays with deep learning. arXiv preprint arXiv:1711.05225, 2017.

U.S. Food & Drug Administration. Proposed Regulatory Framework for Modifications to Artificial Intelligence/Machine Learning (AI/ML)-Based Software as a Medical Device (SaMD). 2019.

Xiaosong Wang, Yifan Peng, Le Lu, Zhiyong Lu, Mohammadhadi Bagheri, and Ronald M Summers. ChestX-ray8: Hospital-scale chest X-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 2097–2106, 2017.

Li Yao, Jordan Prosky, Ben Covington, and Kevin Lyman. A strong baseline for domain adaptation and generalization in medical imaging. arXiv preprint arXiv:1904.01638, 2019.

Yundong Zhang, Hang Wu, Huiye Liu, Li Tong, and May D Wang. Mitigating the effect of dataset bias on training deep models for chest X-rays. arXiv preprint arXiv:1910.06745, 2019.

