Multiple sclerosis (MS) is an inflammatory demyelinating disease of the brain and spine characterized by the presence of hyperintense lesions on T2-weighted magnetic resonance imaging (MRI) [11]. Lesion volume measured on MRI is a clinically important marker of disease progression, which can be used to monitor and guide treatment. Currently, MS lesion segmentation is often performed manually or requires extensive manual editing, which is tedious, time-consuming and highly operator-dependent. In the brain, MS lesions can vary greatly in location, size, shape, and conspicuity on T2-weighted images (Fig. 1). Therefore, a fast and fully automated lesion segmentation tool is highly desired to improve accuracy and reproducibility while saving time in the clinic [9].
Fig. 1: Example T1, T2, T2-FLAIR images and corresponding mask traced by a human expert and marked in red.
In recent MS lesion segmentation algorithms, deep neural networks [1,6,13], which learn the mapping between images and lesion masks, have achieved the best performance. Patch-based methods [1,6], which extract small patches to classify the central voxel, suffer from slow training and ignore global brain structure information. Two-dimensional methods [6, 12] based on in-plane segmentation do not utilize contextual information along the slice direction. Three-dimensional convolutional neural networks (CNN) [4,13] have been proposed to capture image dependencies between consecutive slices. However, because they must trade off the convolution kernel size against the number of pooling layers, these methods can only capture a limited local receptive field and short-range dependencies.
kU-Net [3], which combines U-Net [12] with a recurrent neural network (RNN), is an effective approach to capture dependencies between different slices along the axial direction. kU-Net can be further improved [13] by exploiting multi-modal MRI images and applying a long short-term memory (LSTM) network. However, these methods only capture the axial connections and ignore the sagittal and coronal directions in 3D images. Additionally, RNN and LSTM are inherently limited to short-range dependencies [2, 10, 14]. Non-local neural networks [15] incorporate a powerful self-attention mechanism [14] to aggregate contextual information for every single pixel from all other pixels. The rich dependency information captured by this method has been shown to benefit various applications [5, 8] in image and video processing. However, this non-local method needs to compute very large attention maps (e.g. a 3D MRI image volume of size 512 × 512 × 256 requires an attention map of size 67108864 × 67108864), which greatly increases computational and memory demands and also makes the model prone to overfitting when samples are limited. Currently, leveraging the contextual and domain-specific MRI information for automatic MS lesion segmentation remains a non-trivial task.
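To make the scale of this attention map concrete, the short calculation below counts its entries for the 512 × 512 × 256 example and converts the count to storage. The 4-byte (float32) entry size is an assumption made for illustration; the text does not specify a storage type.

```python
# Voxel count of the example volume quoted in the text.
H, W, D = 512, 512, 256
n_voxels = H * W * D                  # 67,108,864 voxels

# A non-local attention map holds one weight per ordered voxel pair.
n_entries = n_voxels ** 2

BYTES_PER_ENTRY = 4                   # assumption: float32 storage
size_pib = n_entries * BYTES_PER_ENTRY / 2 ** 50   # pebibytes

print(n_voxels, n_entries, size_pib)  # size_pib is 16.0
```

Under this assumption a single full-resolution attention map would occupy 16 PiB, which is why the non-local block is only practical on heavily downsampled feature maps.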
The contributions of our RSANet are threefold. First, unlike methods [3,13] that use RNN or LSTM to capture slice-wise dependencies, and therefore inherit their drawbacks [2,10,14] in capturing long-range dependencies, we propose a novel slice-wise attention module, called the SA block (see Fig. 2), which computes the response at a slice as a weighted sum of the features from all slices along the same direction. Second, we propose a recurrent slice-wise attention module, named the RSA block, to capture long-range dependencies from all voxels. The input features passed into the RSA block are recurrently fed into SA blocks along the sagittal, coronal and axial directions (see Fig. 4). All three SA blocks share the same parameters, and each SA block aggregates information from the previous SA blocks, so that the final output captures voxel-wise dense dependencies. Third, RSANet is memory- and computation-friendly. Compared with the non-local method [15], our method reduces GPU memory consumption by at least 28× and the number of floating point operations by at least 100× in computing the attention maps. We then demonstrate the efficiency and effectiveness of our RSANet on MS lesion segmentation in a dataset of 43 patients.
Fig. 2: Details of the SA block.
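As a rough illustration of why slice-wise attention is cheaper, one can compare attention-map entry counts for a 512 × 512 × 256 volume. The sketch assumes that a slice-wise attention map has one entry per pair of slices along its direction and that one such map is needed per direction. It ignores feature channels, downsampling inside the network and all other activations, so it is not the measured saving reported above.

```python
# Volume dimensions used as the running example in the text.
dims = (512, 512, 256)
n_voxels = dims[0] * dims[1] * dims[2]

# Non-local attention: one weight per pair of voxels.
non_local_entries = n_voxels ** 2

# Slice-wise attention (assumption): one D x D map per slice direction,
# where D is the number of slices along that direction.
slice_wise_entries = sum(d * d for d in dims)

print(non_local_entries, slice_wise_entries)
print(non_local_entries // slice_wise_entries)  # order-of-magnitude ratio only
```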
In this section, we discuss the details of our proposed RSANet for MS lesion segmentation. We first present the slice-wise attention (SA) block, a basic building block of RSANet. We then show how to construct the recurrent slice-wise attention (RSA) block from SA blocks to capture the voxel-wise long-range dependencies.
2.1 Slice-wise Attention Block
Global brain structure and dependencies between lesion locations and brain areas are important features in MS, giving rise to sequential slice dependencies on MRI [3, 13]. In order to capture the long-range dependencies between different slices, we propose a slice-wise attention block, which is self-contained and can be inserted into any layer of an existing CNN architecture.
As shown in Fig. 2, taking the SA block along the axial direction as an example, given a local feature map $M \in \mathbb{R}^{C \times H \times W \times D}$, where $D$ is the number of slices along the axial direction, we first apply a set of tensor transformations to $M$. Specifically, we reshape-transpose $M$ to $M_1 \in \mathbb{R}^{D \times CHW}$, and reshape-transpose $M$ to $M_2 \in \mathbb{R}^{CHW \times D}$ ($CHW$ means the product of $C$, $H$, $W$). We then apply a matrix multiplication between $M_1$ and $M_2$, followed by a softmax operation on the result, to get the slice-wise attention map $A \in \mathbb{R}^{D \times D}$ as follows:

$$A_{ij} = \frac{\exp\left(M_1[i,:] \cdot M_2[:,j]\right)}{\sum_{k=1}^{D} \exp\left(M_1[i,:] \cdot M_2[:,k]\right)} \qquad (1)$$

where $A_{ij}$ measures the impact of the $j$-th slice on the $i$-th slice, and each row of the attention map $A$ holds the weights used to aggregate the impact of all other slices. Also, $M_1[i,:]$ denotes the $i$-th row of $M_1$, and $M_2[:,j]$ denotes the $j$-th column of $M_2$. We then attend $M_1$ to $A$ by another matrix multiplication, followed by a reshape-transpose operation (denoted as RT() in equations) and an element-wise sum with the original $M$ to get the final output $E \in \mathbb{R}^{C \times H \times W \times D}$:

$$E = \alpha \cdot \mathrm{RT}(A M_1) + M \qquad (2)$$

where $\alpha$ is a scaling parameter that is updated with the other parameters through back-propagation. Through Eqn. (1) and Eqn. (2), we can see that the output of the SA block is a linear combination of the input feature map and the features aggregated as a weighted sum over all slices, which captures the long-range dependencies between slices. Unlike previous work [3, 13], which uses RNN and LSTM with their inherent short-range limitations [2, 10], we model slice dependencies by exploring spatial contextual information across all slices. Moreover, with the recurrent module introduced in the next section, we can efficiently fuse the information from the three slice directions.
Fig. 3: Details of the RSA block. The input is the feature map M. The RSA block takes M and recurrently produces intermediate feature maps. Finally, the RSA block outputs E, where each voxel is a weighted sum of all other voxels.
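A minimal NumPy sketch of the slice-wise attention described in this section is given below. It assumes a feature map of shape (C, H, W, D), uses one reshaped matrix (D × CHW) and its transpose as the two factors of the attention product, and omits any learned convolutions. The function name `slice_attention`, the `axis` argument and the fixed `alpha` value are illustrative choices rather than part of the method description.

```python
import numpy as np

def slice_attention(M, alpha, axis=-1):
    """Attend slices of M along `axis`; returns (output, attention map)."""
    moved = np.moveaxis(M, axis, 0)            # slices first: (D, ...)
    D = moved.shape[0]
    M1 = moved.reshape(D, -1)                  # D x CHW
    logits = M1 @ M1.T                         # D x D slice similarities
    logits = logits - logits.max(axis=1, keepdims=True)  # stable softmax
    A = np.exp(logits)
    A = A / A.sum(axis=1, keepdims=True)       # each row sums to one
    attended = (A @ M1).reshape(moved.shape)   # weighted sum of slices
    attended = np.moveaxis(attended, 0, axis)  # back to the input layout
    return alpha * attended + M, A             # scaled residual sum

rng = np.random.default_rng(0)
M = rng.standard_normal((2, 3, 4, 5))          # toy (C, H, W, D) feature map
E, A = slice_attention(M, alpha=0.5)
print(E.shape, A.shape)
```

Setting `alpha` to zero returns the input unchanged, which matches the residual form of the SA block output.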
2.2 Recurrent Slice-wise Attention Block
Dense contextual dependency information helps regularize the gradient information propagation through the whole network, and in our case, this regularization could help the deep network understand the complicated distributions of MS lesions as well as the relationship between brain structure and lesion masks. Therefore, it is imperative to capture the long-range dependencies among all voxels while avoiding the prohibitive memory and computation cost of the non-local block [15]. We propose our recurrent slice-wise attention (RSA) block, which combines the long-range dependencies along three slice directions in a recurrent manner.
The overall RSA block is shown in Fig. 3. The RSA block consists of three SA blocks with different slice directions, i.e. sagittal, coronal and axial. The SA blocks in the RSA block share the same convolutional kernel parameters, but their weight parameter is updated independently. In the first recurrent loop, the RSA block takes the input feature M extracted by the previous convolutional layers and outputs an intermediate feature map in which the image slices are sagittally attended. We then apply the other two SA blocks along the coronal and axial directions and finally get the output E, which is attended densely among all voxels. The input, the intermediate feature maps and the output all have the same shape, and thus the RSA block can be put anywhere in the convolutional network.
Fig. 4: An example of information propagation in RSA block.
A single SA block can only capture the slice-wise dependencies along one direction, so we recurrently apply SA blocks along the three slice directions to make up for this deficiency and obtain global voxel-wise contextual information. Details of the information propagation in our RSA block are shown in Fig. 4. Suppose the input of the RSA block is $M \in \mathbb{R}^{3 \times 3 \times 3}$, and the red cubes denote the first slice of M along the axial direction (green for the second slice and yellow for the third slice). We show how information from adjacent voxels is aggregated at the central voxel in the RSA block; similar procedures apply to all the other voxels. In the first SA block, $M^{s}[1,1,1] = \mathrm{SA1}(M[1,0,1],\, M[1,2,1]) + M[1,1,1]$, where $M^{s}$ is the output of the first SA block and SA1() denotes the weighted sum of M[1, 0, 1] and M[1, 2, 1], with the weights obtained from the attention map of this SA block. Then in the second SA block, $M^{c}[1,1,1] = \mathrm{SA2}(M^{s}[1,1,0],\, M^{s}[1,1,2]) + M^{s}[1,1,1]$, where $M^{c}$ is the output of the second SA block and SA2() is the same function as SA1() but acts on different inputs and weights. Finally, we obtain $E[1,1,1] = \mathrm{SA3}(M^{c}[0,1,1],\, M^{c}[2,1,1]) + M^{c}[1,1,1]$.
In general, our RSA block helps gather global contextual information through a weighted sum over all voxels, while dramatically reducing the computational cost and memory usage of computing the attention map.
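The propagation pattern of Fig. 4 can be checked numerically with a small sketch. As a simplifying assumption, each SA block is replaced by a pass with uniform attention weights (a plain mean over slices plus the residual), so that only the connectivity is shown and not the learned weighting. The helper `uniform_slice_pass` is hypothetical, and the mapping of array axes to anatomical directions is left unspecified. Starting from a single non-zero voxel in a 3 × 3 × 3 volume, the set of voxels it reaches grows from a line to a plane to the full volume over the three passes.

```python
import numpy as np

def uniform_slice_pass(x, axis, alpha=1.0):
    """Stand-in for one SA block with attention weights fixed to 1/D."""
    aggregated = x.mean(axis=axis, keepdims=True)  # equal-weight sum over slices
    return alpha * aggregated + x                  # broadcast, then residual sum

n = 3
x = np.zeros((n, n, n))
x[1, 1, 1] = 1.0                  # a single active voxel at the centre

reach = []
for axis in (0, 1, 2):            # one pass per slice direction
    x = uniform_slice_pass(x, axis)
    reach.append(int(np.count_nonzero(x)))

print(reach)                      # [3, 9, 27]
```

After the third pass every voxel carries a contribution from the starting voxel, which is the dense voxel-wise dependency that the recurrent arrangement is designed to provide.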
We demonstrate the performance of the proposed RSANet on MRI images acquired with three different contrasts (T1, T2, and T2-FLAIR) on a GE 3T scanner. A total of 43 MS patients were included in this study. The size of each MRI image dataset varies from 230 × 320 × 44 to 260 × 320 × 60 with a voxel size of 0 . Images were linearly co-registered using the FLIRT command of the FSL neuroimaging toolbox. MS lesions were manually segmented by a neuroradiologist with six years of experience.
3.1 Training, Testing and Implementation Details
The proposed RSANet was implemented in PyTorch, with 3D U-Net [4,12] as the backbone network.
Loss Function. The data imbalance problem is critical in MS lesion segmentation, as the ROIs usually compose a tiny portion of the MRI images. In our dataset, the rate of MS lesion is only 510
. Therefore we adopted an exponentially weighted CrossEntropy as our loss function. The weight for MS lesion ROIs is
the portion of MS lesion in the MRI images.
Training parameters. We randomly split our dataset into two subset, 20 images for training and the rest 23 for testing. We evaluated each method on five random splits and take the average as the result to compare performance. As the image size varies from patient to patient, we performed random crop with fixed cropping size on the original images to make sure they were in the same size for training. ADAM [7] with the initial learning rate of 1and weight decay of 1
was used to train each method. Each method was trained with a batch size of 4 on a single Titan XP GPU, and training would stop after 800 iterations.
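The evaluation protocol above (five random splits of the 43 subjects into 20 training and 23 testing images, with results averaged over splits) can be sketched as follows. The helper names, the fixed seed and the `evaluate` callback are hypothetical and stand in for the actual training and testing code.

```python
import random

def make_splits(n_subjects=43, n_train=20, n_splits=5, seed=0):
    """Return `n_splits` random (train_ids, test_ids) partitions."""
    rng = random.Random(seed)          # assumption: seeded for repeatability
    splits = []
    for _ in range(n_splits):
        ids = list(range(n_subjects))
        rng.shuffle(ids)
        splits.append((sorted(ids[:n_train]), sorted(ids[n_train:])))
    return splits

def average_over_splits(splits, evaluate):
    """Average a per-split score; `evaluate` trains and tests one split."""
    scores = [evaluate(train, test) for train, test in splits]
    return sum(scores) / len(scores)

splits = make_splits()
# Placeholder score: the test-set size, used only to exercise the code path.
print(average_over_splits(splits, lambda train, test: float(len(test))))
```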
Implementation Details. As both the non-local block and the RSA block can be put anywhere in the network, we tried three different placements and compared their performance. The non-local block can only be put in the layers close to the bottom of 3D U-Net [4,12], as it would cause serious memory problems in the higher layers. Therefore, for a fair comparison, we chose three placements for both blocks: 1) the single bottom layer, denoted NCL-010 and RSA-010; 2) the encoder and decoder of the second bottom layer, denoted NCL-101 and RSA-101; 3) all three places mentioned above, denoted NCL-111 and RSA-111.
3.2 Quantitative and Qualitative Results
We compared our models with several advanced approaches, including 3D U-Net [4,12] and different settings of non-local neural networks [15]. For a fair comparison, we obtained these methods from public implementations and adjusted their parameters to get the best performance. We used the Dice similarity coefficient (Dice) and Intersection over Union (IoU) as our evaluation metrics. As the number of MS lesions varies greatly from patient to patient, a simple average of Dice and IoU over all samples would bias the evaluation of the real performance. Accordingly, we propose a new evaluation method, in which all voxels from all samples are pooled to compute "voxel average" Dice and IoU, as opposed to the conventional sample average.
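The difference between the two averaging schemes can be illustrated with a short sketch. Each sample is summarized by its true positive, false positive and false negative voxel counts; the counts below are made up for illustration and are not taken from the experiments, and returning 1.0 for an empty denominator is an assumed convention.

```python
def dice(tp, fp, fn):
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 1.0   # assumption: empty case scores 1

def iou(tp, fp, fn):
    denom = tp + fp + fn
    return tp / denom if denom else 1.0       # assumption: empty case scores 1

def sample_average(counts, metric):
    """Conventional average: score each sample, then take the mean."""
    return sum(metric(*c) for c in counts) / len(counts)

def voxel_average(counts, metric):
    """Pool voxel counts over all samples, then score once."""
    tp = sum(c[0] for c in counts)
    fp = sum(c[1] for c in counts)
    fn = sum(c[2] for c in counts)
    return metric(tp, fp, fn)

# Hypothetical (tp, fp, fn) counts: one high-burden and one low-burden patient.
counts = [(900, 50, 50), (5, 10, 10)]
print(sample_average(counts, dice), voxel_average(counts, dice))
print(sample_average(counts, iou), voxel_average(counts, iou))
```

In this toy case the low-burden patient pulls the sample average down sharply, while the pooled score is dominated by the total number of lesion voxels.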
Quantitative Results. As shown in Table 1, when 3D U-Net was equipped with the non-local block, it outperformed the original 3D U-Net in all metrics regardless of where the non-local block was placed. In particular, NCL-010 outperformed 3D U-Net by 0.83% in Voxel Avg. Dice and 1.0% in Voxel Avg. IoU, which is consistent with the results reported for the non-local method [15]. RSANet outperformed both 3D U-Net and the non-local network at all non-local positions in all metrics (Table 1). Furthermore, the sample average Dice and IoU scores of RSANet increased with the number of RSA blocks in the network. This property is beneficial, as our RSA blocks barely increase the number of floating point operations or the memory usage.
Table 1: Quantitative comparison of MS lesion segmentation
Fig. 5: Example segmentation result. From left to right are ground truth label, results of RSA-111, NCL-010 and 3D U-Net.
Qualitative Results. We chose one slice from a testing image and compared the segmentation results of different models with the ground truth label. As can be seen from Fig. 5, since neither 3D U-Net nor the non-local network can efficiently capture the long-range dependencies between MS lesions and brain structure, both suffer from over-segmentation.
We presented a novel recurrent slice-wise attention network, which incorporates three slice-wise attention blocks recurrently. Our proposed method can capture the long-range dependencies within the MRI images of MS patients and thereby exploits the contextual information between the brain structure and lesion masks. Our method not only achieves high accuracy on MS lesion segmentation, but also dramatically reduces the computational cost and GPU memory usage. Experimental results showed that our method outperformed other state-of-the-art methods. Our method can be put anywhere in a deep network and thus has potential for other 3D medical image segmentation tasks.
1. Akkus, Z., Galimzianova, A., Hoogi, A., Rubin, D.L., Erickson, B.J.: Deep learning for brain mri segmentation: state of the art and future directions. Journal of digital imaging 30(4), 449–459 (2017)
2. Bahdanau, D., Cho, K., Bengio, Y.: Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473 (2014)
3. Chen, J., Yang, L., Zhang, Y., Alber, M., Chen, D.Z.: Combining fully convolutional and recurrent neural networks for 3d biomedical image segmentation. In: Advances in neural information processing systems. pp. 3036–3044 (2016)
4. Çiçek, Ö., Abdulkadir, A., Lienkamp, S.S., Brox, T., Ronneberger, O.: 3d u-net: learning dense volumetric segmentation from sparse annotation. In: International conference on medical image computing and computer-assisted intervention. pp. 424–432. Springer (2016)
5. Du, Y., Yuan, C., Li, B., Zhao, L., Li, Y., Hu, W.: Interaction-aware spatio-temporal pyramid attention networks for action classification. In: Proceedings of the European Conference on Computer Vision (ECCV). pp. 373–389 (2018)
6. Kamnitsas, K., Ledig, C., Newcombe, V.F., Simpson, J.P., Kane, A.D., Menon, D.K., Rueckert, D., Glocker, B.: Efficient multi-scale 3d cnn with fully connected crf for accurate brain lesion segmentation. Medical image analysis 36, 61–78 (2017)
7. Kingma, D.P., Ba, J.: Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980 (2014)
8. Li, G., He, X., Zhang, W., Chang, H., Dong, L., Lin, L.: Non-locally enhanced encoder-decoder network for single image de-raining. In: 2018 ACM Multimedia Conference on Multimedia Conference. pp. 1056–1064. ACM (2018)
9. Lladó, X., Oliver, A., Cabezas, M., Freixenet, J., Vilanova, J.C., Quiles, A., Valls, L., Ramió-Torrentà, L., Rovira, À.: Segmentation of multiple sclerosis lesions in brain mri: a review of automated approaches. Information Sciences 186(1), 164–185 (2012)
10. Luong, M.T., Pham, H., Manning, C.D.: Effective approaches to attention-based neural machine translation. arXiv preprint arXiv:1508.04025 (2015)
11. Milo, R., Kahana, E.: Multiple sclerosis: geoepidemiology, genetics and the environment. Autoimmunity reviews 9(5), A387–A394 (2010)
12. Ronneberger, O., Fischer, P., Brox, T.: U-net: Convolutional networks for biomedical image segmentation. In: International Conference on Medical image computing and computer-assisted intervention. pp. 234–241. Springer (2015)
13. Tseng, K.L., Lin, Y.L., Hsu, W., Huang, C.Y.: Joint sequence learning and cross-modality convolution for 3d biomedical segmentation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 6393–6400 (2017)
14. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, Ł., Polosukhin, I.: Attention is all you need. In: Advances in Neural Information Processing Systems. pp. 5998–6008 (2017)
15. Wang, X., Girshick, R., Gupta, A., He, K.: Non-local neural networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 7794–7803 (2018)