A general framework for ensemble distribution distillation

2020·Arxiv

Abstract

Ensembles of neural networks have been shown to give better performance than single networks, both in terms of predictions and uncertainty estimation. Additionally, ensembles allow the uncertainty to be decomposed into aleatoric (data) and epistemic (model) components, giving a more complete picture of the predictive uncertainty. Ensemble distillation is the process of compressing an ensemble into a single model, often resulting in a leaner model that still outperforms the individual ensemble members. Unfortunately, standard distillation erases the natural uncertainty decomposition of the ensemble. We present a general framework for distilling both regression and classification ensembles in a way that preserves the decomposition. We demonstrate the desired behaviour of our framework and show that its predictive performance is on par with standard distillation.

1 Introduction

Recently, there has been a surge of effort in modelling and estimating the uncertainty in deep neural networks (e.g. Malinin et al., 2019; Kendall and Gal, 2017; Guo et al., 2017; Widmann et al., 2019). For a wide range of applications, from autonomous vehicles to medical image analysis, reliable uncertainty estimates are vital. One step towards understanding the predictive uncertainty is to decompose the total uncertainty into its different components. Specifically, the total uncertainty in a prediction is considered to stem from epistemic uncertainty about the true model and from inherent, aleatoric, noise in the data. Decomposing the total uncertainty into its epistemic and aleatoric components not only provides a more complete picture of the uncertainty quantification, but is also beneficial in applications such as active learning and reinforcement learning. Being able to discriminate between noise and model uncertainty enables more efficient strategies for choosing new candidates for annotation and actions to explore, respectively.

Figure 1: Schematic view of the general distribution distillation. Here the output data is modelled with $q(y; z)$. The ensemble produces several plausible predictive distributions (left). The distilled model mimics this by learning a distribution over the parameters that captures the epistemic uncertainty in the model (right).

An ensemble of neural networks consists of several, slightly different networks that are averaged when making predictions. Using an ensemble usually improves performance and makes predictions more robust (Lakshminarayanan et al., 2017). Recent work has shown that ensembles also consistently provide good uncertainty estimates. This was demonstrated in a recent benchmark by Ovadia et al. (2019), where the authors argue that this is due to the ensemble's capability to represent model diversity, i.e. epistemic uncertainty. They write:

“Post-hoc calibration [...] leads to well-calibrated uncertainty on i.i.d. test and small values of skew, but is significantly outperformed by methods that take epistemic uncertainty into account as the skew increases.”

The epistemic uncertainty is naturally characterised in an ensemble as the spread of the individual ensemble members. Indeed, since they are trained in an identical manner, disagreement on the distribution of a certain prediction means that the model is uncertain about that prediction.

An obvious drawback of ensembles, however, is their inefficiency. Even though the ensemble members’ predictions can be computed in parallel, there is still a considerable increase in memory usage compared to employing a single model. It is therefore natural to consider some form of model compression to obtain a framework that is more efficient at test time. In the aforementioned benchmark, Ovadia et al. (2019) conclude:

“Reducing the computational and memory costs, while retaining the same performance under dataset shift, would also be a key research challenge.”

The present article is a step towards this important goal.

Ensemble distillation is a compression procedure where a smaller, distilled, network learns to approximate the predictions of a large ensemble of networks. The end result is often a model that is considerably smaller than the full ensemble but still better performing and more robust than a single network trained on the same data (Hinton et al., 2015).

The drawback of standard ensemble distillation (as done, e.g., by Hinton et al. (2015)) is that it only considers the mean prediction of the ensemble, thereby overlooking the spread in the predictions among the individual ensemble members. As a consequence, the information that the ensemble provides in terms of epistemic uncertainty is lost in the distillation process.

In order for the distilled network to also capture the spread of the ensemble, we propose to learn a distribution over the ensemble predictions. Contrary to standard distillation, the distilled network will not perform the same task as the ensemble members. Instead, its training objective is to predict the parameters of some distribution which describes the ensemble. A schematic illustration of this is given in fig. 1. A special case of this has recently been proposed by Malinin et al. (2019) for classification problems, using a Dirichlet distribution to model the ensemble predictions.

In this paper we present a general framework that enables ensemble distribution distillation of both classification and regression models, as well as other prediction tasks. This is accomplished by modelling and distilling the distribution over some intermediate (typically unconstrained) variable of the ensemble networks. Not only is our framework more generally applicable than previous work, but it also allows for greater flexibility in emulating the ensemble in ways that make it easier to represent the epistemic uncertainty. We discuss the differences between our proposed framework and the method by Malinin et al. (2019) in more detail in section 3.2 and section 5.

2 Background

Probabilistic predictive models In a general supervised learning problem with a set of pairs of input and corresponding target output,

$$\mathcal{D} = \{(x_n, y_n)\}_{n=1}^{N}, \tag{1}$$

the objective is to train a model to find some underlying relation between x and y, enabling predictions of the output given future, unseen data $x^\star$.

In this paper, we focus on approximating the true conditional probability distribution of the output given the input, p(y|x, D), with $q(y; f_\theta(x))$ from some family of distributions parameterised by $z = f_\theta(x)$. In this context, $f_\theta$ is a neural network with parameters $\theta$ that maps x to a parameter vector z for q(y; z). For instance, in a classification problem z is typically a probability vector, such that q(y; z) = Cat(y; z). Similarly, while regression problems are often formulated as directly estimating the output y from x, a probabilistic model is obtained by modelling the probability of y|x using some parametric family of distributions. A common choice is to use a normal distribution where z corresponds to the conditional mean and (co-)variance of y|x, i.e., $q(y; z) = \mathcal{N}(y; \mu, \Sigma)$ with $z = (\mu, \Sigma)$.

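The network parameters $\theta$ are optimised in order to maximise the likelihood of the data $\mathcal{D}$ with respect to $\theta$. In practice we minimise the negative logarithm of the likelihood (NLL),

$$\mathcal{L}(\theta) = -\sum_{n=1}^{N} \log q\big(y_n; f_\theta(x_n)\big), \tag{2}$$

possibly with some weight regularisation.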

Uncertainty quantification The uncertainty in a model's prediction can be characterised using the estimated conditional probability $q(y; f_\theta(x))$. However, when reasoning about the uncertainty it is useful to distinguish between epistemic uncertainty that arises from a lack of knowledge about the model and its parameters $\theta$, and aleatoric uncertainty that results from intrinsic noise in the data (Kiureghian and Ditlevsen, 2009). For instance, we expect our model to have high epistemic uncertainty on a new data point, $x^\star$, originating from outside the intended input distribution. For a fixed (learned) value of $\theta$, however, the model will only capture aleatoric uncertainty.

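Conceptually, we can address this limitation by taking a Bayesian approach, learning a posterior distribution over the model parameters, $p(\theta \mid \mathcal{D})$, and expressing the predictive distribution as

$$p(y \mid x, \mathcal{D}) = \int p(y \mid x, \theta)\, p(\theta \mid \mathcal{D})\, d\theta, \qquad p(y \mid x, \theta) = q\big(y; f_\theta(x)\big). \tag{3}$$

More specifically, we can use this approach to define the different types of uncertainty as:

$$\begin{aligned}
\text{Total:} &\quad I\big[p(y \mid x, \mathcal{D})\big], && (4a)\\
\text{Aleatoric:} &\quad \mathbb{E}_{p(\theta \mid \mathcal{D})}\Big[I\big[p(y \mid x, \theta)\big]\Big], && (4b)\\
\text{Epistemic:} &\quad I\big[p(y \mid x, \mathcal{D})\big] - \mathbb{E}_{p(\theta \mid \mathcal{D})}\Big[I\big[p(y \mid x, \theta)\big]\Big], && (4c)
\end{aligned}$$

where I is some uncertainty measure, such as variance, entropy or differential entropy. In the particular case when I[p] is the variance of p, the law of total variance gives the epistemic uncertainty as $\operatorname{Var}_{p(\theta \mid \mathcal{D})}\big(\mathbb{E}_{p(y \mid x, \theta)}[y]\big)$.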
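As a small worked check of the law-of-total-variance statement above, the following Python sketch assumes that the posterior is represented by a finite set of M equally weighted parameter samples (as an ensemble will provide) whose predictive distributions are Gaussian. The function name `variance_decomposition` and the example numbers are hypothetical, chosen only for illustration.

```python
import numpy as np

def variance_decomposition(means, variances):
    """Law-of-total-variance split for M equally weighted Gaussian predictions at one x.

    means[m], variances[m]: mean and variance of p(y | x, theta_m) (hypothetical inputs).
    """
    means = np.asarray(means, dtype=float)
    variances = np.asarray(variances, dtype=float)
    aleatoric = variances.mean()   # E_theta[ Var(y | x, theta) ]
    epistemic = means.var()        # Var_theta( E[y | x, theta] ), 1/M normalisation
    total = aleatoric + epistemic  # equals the variance of the equally weighted mixture
    return total, aleatoric, epistemic

# Hypothetical members that agree on the noise level but disagree on the mean.
total, alea, epi = variance_decomposition([0.0, 1.0, 2.0], [0.5, 0.5, 0.5])
```

Here all of the disagreement between the means ends up in the epistemic term, while the shared noise level is reported as aleatoric.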

Ensembles Computing the posterior distribution over model parameters is unfortunately intractable in most cases when $f_\theta$ is given by a deep neural network. Although many approximate Bayesian methods have been proposed (e.g. Blundell et al., 2015; Gal and Ghahramani, 2016), a simple alternative is to use an ensemble of networks. This has been found to have very competitive empirical performance (Ovadia et al., 2019; Lakshminarayanan et al., 2017).

To construct an ensemble, we train M networks independently. Some measures are taken to ensure diversity in the ensemble; random initialisation of the same network architecture and randomly sampled mini-batches are commonly considered enough. This results in M identically distributed models $\{f_{\theta^{(m)}}\}_{m=1}^{M}$. While ensembles were originally motivated primarily as a way to improve performance and robustness compared to a single model, they also provide a natural estimate of the epistemic uncertainty. Specifically, we can use the spread of the ensemble,

$$p(\theta \mid \mathcal{D}) \approx \frac{1}{M}\sum_{m=1}^{M} \delta\big(\theta - \theta^{(m)}\big),$$

as an approximation of the posterior in (4) to compute the different types of uncertainties. This decomposition of uncertainty with ensembles has been extensively explored, e.g., by Kendall and Gal (2017); Lakshminarayanan et al. (2017); Malinin and Gales (2018); Malinin et al. (2019).
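With entropy as the measure I, the ensemble approximation turns (4b) and (4c) into simple averages over members. The sketch below is a minimal NumPy illustration under that assumption, for a single input x and a classification ensemble; `ensemble_entropy_decomposition` is a hypothetical helper name.

```python
import numpy as np

def entropy(p, axis=-1):
    """Shannon entropy in nats; zero-probability entries contribute 0."""
    p = np.asarray(p, dtype=float)
    safe = np.where(p > 0, p, 1.0)          # log(1) = 0 for the p = 0 entries
    return -np.sum(p * np.log(safe), axis=axis)

def ensemble_entropy_decomposition(member_probs):
    """Entropy-based uncertainty decomposition for one input x.

    member_probs: (M, K) class-probability vectors, one per ensemble member,
    treated as equally weighted samples from p(theta | D).
    """
    member_probs = np.asarray(member_probs, dtype=float)
    total = entropy(member_probs.mean(axis=0))         # entropy of the averaged prediction
    aleatoric = entropy(member_probs, axis=-1).mean()  # average member entropy, as in (4b)
    epistemic = total - aleatoric                      # difference as in (4c)
    return total, aleatoric, epistemic

# Two confident members that disagree: the uncertainty is purely epistemic.
t, a, e = ensemble_entropy_decomposition([[1.0, 0.0], [0.0, 1.0]])
```

With entropy, the epistemic term is the mutual information between y and the member index, which is non-negative by concavity of the entropy.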

Ensemble distillation Ensembles provide powerful representations of the predictive distribution, but can be cumbersome to work with in practice since the memory usage and computational cost at test time scale linearly with M. In such cases model compression can provide a remedy (Buciluǎ et al., 2006). Distillation refers to a process of compressing the ensemble into a single model (Hinton et al., 2015). The idea is to train a new, distilled, model to mimic the predictions made by the ensemble (in some way), after which the ensemble itself can be discarded. The use of ensemble distillation is most prevalent in classification, where the ensemble members each predict a probability vector over classes, $z^{(m)} = f_{\theta^{(m)}}(x)$. The distilled model is also trained as a classifier, but with the average ensemble probability vector as a "soft target",

$$\bar{z}(x) = \frac{1}{M}\sum_{m=1}^{M} z^{(m)}(x)$$

(or with a combination of hard and soft targets). Distillation of regression models has, to the best of our knowledge, received comparatively little attention.

3 Distribution distillation

In this section we focus on distillation methods that preserve the uncertainty quantification of the ensemble. First we discuss an interpretation of “vanilla” distillation as a KL minimization problem. We then propose a general framework for distilling the distribution over the ensemble in a way that preserves the possibility of uncertainty decomposition.

3.1 Distillation as KL minimization

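The approach mentioned above for distilling ensembles of classification models, namely cross-entropy training with respect to soft targets from the ensemble average, is equivalent to interpreting the ensemble as a mixture of categorical distributions and minimising the KL divergence between the mixture and the distilled model,

$$\mathrm{KL}\Big(\frac{1}{M}\sum_{m=1}^{M}\mathrm{Cat}\big(y; z^{(m)}\big) \,\Big\|\, \mathrm{Cat}\big(y; f_\phi(x)\big)\Big), \tag{5}$$

where $f_\phi$ denotes the distilled network.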
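The equivalence claimed above holds because a mixture of categoricals is itself the categorical distribution with the averaged probability vector, so the soft-target cross-entropy and the KL in (5) differ only by the entropy of that average, which does not depend on the distilled model. A small NumPy check of this identity follows; the member and student probabilities are hypothetical and strictly positive so that all logarithms are finite.

```python
import numpy as np

def softmax(logits):
    # numerically stable softmax for a 1-D logit vector
    e = np.exp(logits - np.max(logits))
    return e / e.sum()

def soft_target_cross_entropy(member_probs, student_probs):
    # cross-entropy between the ensemble-average soft target and the student
    target = np.mean(member_probs, axis=0)
    return -np.sum(target * np.log(student_probs))

def kl_mixture_to_student(member_probs, student_probs):
    # KL((1/M) sum_m Cat(z^(m)) || Cat(student)); the mixture equals Cat(mean of z^(m))
    target = np.mean(member_probs, axis=0)
    return np.sum(target * (np.log(target) - np.log(student_probs)))

members = np.array([[0.7, 0.2, 0.1],
                    [0.5, 0.3, 0.2]])
soft_target = members.mean(axis=0)
target_entropy = -np.sum(soft_target * np.log(soft_target))

for logits in (np.array([1.0, 0.0, -1.0]), np.array([0.0, 2.0, 0.5])):
    student = softmax(logits)
    gap = soft_target_cross_entropy(members, student) - kl_mixture_to_student(members, student)
    # the gap is the same for every student: the entropy of the soft target
    assert np.isclose(gap, target_entropy)
```

Since the gap is constant in the student, both objectives have the same minimiser and the same gradients with respect to $\phi$.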

Distillation of regression models is not as straightforward. With ensemble members predicting an estimate $\hat{y}^{(m)}$ of some continuous variable y, the most intuitive form of distillation would perhaps be to train on the average prediction $\bar{y} = \frac{1}{M}\sum_{m=1}^{M}\hat{y}^{(m)}$. A version more faithful to the classification setting is to let the distilled model output the parameters of some distribution. We then minimise the KL divergence between the mixture of the predictive distributions described by the ensemble and the distilled model. For instance, if both the ensemble members and the distilled model are assumed to be Gaussian, we get

$$\mathrm{KL}\Big(\frac{1}{M}\sum_{m=1}^{M}\mathcal{N}\big(y; \mu_m, \sigma_m^2\big) \,\Big\|\, \mathcal{N}\big(y; \mu_\phi, \sigma_\phi^2\big)\Big), \tag{6}$$

where $(\mu_m, \sigma_m^2) = f_{\theta^{(m)}}(x)$ and $(\mu_\phi, \sigma_\phi^2) = f_\phi(x)$. An explicit expression for this KL divergence is given in the supplementary material. Recent works have also used the KL-divergence interpretation: Englesson and Azizpour (2019) for classification and Tran et al. (2020) for both classification and regression.
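Because the entropy of the mixture in (6) does not depend on the distilled model, minimising (6) is equivalent to minimising the expected Gaussian negative log-likelihood under the mixture, and that minimiser is obtained by moment matching. The sketch below is our own illustration of this standard result for a single input, not the explicit expression from the supplementary material; all names and numbers are hypothetical.

```python
import numpy as np

def mixture_cross_entropy(mu_m, var_m, mu, var):
    """E_mixture[-log N(y; mu, var)] for an equally weighted mixture of N(mu_m, var_m).

    This equals KL(mixture || N(mu, var)) plus the (constant) entropy of the mixture.
    """
    mu_m, var_m = np.asarray(mu_m, dtype=float), np.asarray(var_m, dtype=float)
    second_moment = np.mean(var_m + (mu_m - mu) ** 2)   # E_mixture[(y - mu)^2]
    return 0.5 * np.log(2 * np.pi * var) + second_moment / (2 * var)

def moment_match(mu_m, var_m):
    # Minimiser of the objective above: the mean and variance of the mixture.
    mu_m, var_m = np.asarray(mu_m, dtype=float), np.asarray(var_m, dtype=float)
    mu = mu_m.mean()
    var = np.mean(var_m + mu_m ** 2) - mu ** 2   # = mean(var_m) + Var(mu_m)
    return mu, var

mu_m, var_m = [0.0, 1.0, 2.0], [0.5, 0.2, 0.3]   # hypothetical ensemble outputs at one x
mu_star, var_star = moment_match(mu_m, var_m)
```

The moment-matched variance is exactly the total (aleatoric plus epistemic) variance of the ensemble, which is why mixture distillation retains the total uncertainty but not its decomposition.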

We call this approach mixture distillation since the distilled network compresses the ensemble mixture distribution with a parametric distribution (e.g., from the same family as those of the individual members). This is a general approach; all ensembles of probabilistic predictive models can be distilled by it.

3.2 A general framework for distribution distillation

The mixture distillation method captures the total uncertainty of the model. However, a clear drawback of this objective is that information about the spread of the ensemble predictions is lost in the distillation. Therefore, we lose one very compelling property of the ensemble, namely the possibility of decomposing the total predictive uncertainty into epistemic and aleatoric components.

To address this limitation we propose a new framework for distillation that is different from both simple average distillation and mixture distillation. In our formulation the distilled network predicts parameters for a higher-order distribution v, which in turn approximates the distribution over the parameters produced by the ensemble. The distilled network, like the ensemble members, is trained by minimising an NLL, but where we use the outputs of the ensemble as targets:

$$\mathcal{L}(\phi) = -\mathbb{E}_{p(x)}\left[\frac{1}{M}\sum_{m=1}^{M} \log v_\phi\big(z^{(m)} \mid x\big)\right], \tag{7}$$

where $z^{(m)} = f_{\theta^{(m)}}(x)$ and $v_\phi$ is parameterised by the outputs of the distilled network. Note that the expectation is taken w.r.t. the marginal distribution over the inputs, and the samples used to approximate this distribution are possibly different from the labelled training data used in (2). Put differently, the distillation procedure itself does not require labelled data, and we can use a larger unlabelled dataset than when training the individual ensemble members, or even out-of-distribution (OOD) data, when performing the distillation. Instead, the predictions made by the ensemble produce pseudo-labels for the distillation procedure.
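For a single input x, the inner sum of the distillation objective above is simply an average NLL of the M ensemble outputs under the distilled distribution. The sketch below assumes, purely for illustration, that v(z | x) is a diagonal Gaussian parameterised by a mean and a log-variance produced by the distilled network; the function name is hypothetical, and no ground-truth label y enters the computation.

```python
import numpy as np

def distribution_distillation_nll(ensemble_z, mu, log_var):
    """Average NLL of the ensemble outputs under a diagonal Gaussian v(z | x), for one input x.

    ensemble_z: (M, D) pseudo-labels z^(m) = f_theta_m(x) from the M ensemble members.
    mu, log_var: (D,) hypothetical outputs of the distilled network at x.
    """
    ensemble_z = np.asarray(ensemble_z, dtype=float)
    var = np.exp(log_var)
    # log N(z^(m); mu, diag(var)), summed over the D output dimensions
    log_v = -0.5 * np.sum(np.log(2 * np.pi * var) + (ensemble_z - mu) ** 2 / var, axis=-1)
    return -log_v.mean()  # average over the M ensemble members; no labels y are used
```

In a training loop, this quantity would be averaged over a batch of inputs drawn from the distillation set and minimised with respect to the distilled network's parameters. For a fixed x, its minimiser is the sample mean and (1/M-normalised) sample variance of the ensemble outputs.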

A key property of the proposed framework is that it is very generic and applies, e.g., to both classification and regression problems. This is in contrast with the recent work by Malinin et al. (2019), who also consider distribution distillation. They propose a method similar to ours, but it is restricted to the classification setting. Specifically, their model builds on so-called prior networks (Malinin and Gales, 2018), where a Dirichlet distribution is used to model the spread of the class probability vectors produced by the ensemble.

The generality of our framework is related to the fact that there is a freedom in choosing the parameterisation of the output distribution q(y; z) used for the individual ensemble members. Assumptions about the problem lead to a choice of ensemble member distribution family, which in turn influences the choice of distribution family for the distilled model.

For instance, in a classification problem with K classes we can interpret z as a probability vector, such that q(y; z) = Cat(y; z), as discussed above. This implies that the ensemble members produce outputs on the simplex. A possible choice for the distribution over z would then be the Dirichlet distribution,

$$v(z \mid x) = \mathrm{Dir}\big(z; \alpha_\phi(x)\big), \tag{8}$$

that is, we model the probability vectors produced by the ensemble members as Dirichlet distributed. Here, the distilled model outputs the concentration parameters $\alpha_\phi(x)$ of the Dirichlet distribution. This particular choice of distributions recovers the method proposed by Malinin et al. (2019).
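For concreteness, the Dirichlet choice (8) turns the distillation objective into an average negative Dirichlet log-density of the ensemble's probability vectors. A minimal sketch using only NumPy and `math.lgamma`, assuming strictly positive probability vectors; the function names and the concentration values are hypothetical.

```python
import numpy as np
from math import lgamma

def dirichlet_log_pdf(p, alpha):
    # log Dir(p; alpha) = log Gamma(sum alpha) - sum_k log Gamma(alpha_k) + sum_k (alpha_k - 1) log p_k
    p = np.asarray(p, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    log_norm = lgamma(alpha.sum()) - sum(lgamma(a) for a in alpha)
    return log_norm + np.sum((alpha - 1.0) * np.log(p))

def dirichlet_distillation_nll(member_probs, alpha):
    # Average negative log-density of the members' probability vectors under Dir(alpha).
    return -np.mean([dirichlet_log_pdf(p, alpha) for p in member_probs])

members = np.array([[0.72, 0.18, 0.10],
                    [0.68, 0.22, 0.10]])     # hypothetical ensemble outputs at one x
nll_concentrated = dirichlet_distillation_nll(members, [70.0, 20.0, 10.0])
nll_flat = dirichlet_distillation_nll(members, [1.0, 1.0, 1.0])
```

A concentration vector centred on the members and with a large total precision explains an ensemble that agrees closely much better than the flat Dirichlet does; a disagreeing ensemble would instead favour a smaller total concentration.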

However, for the same classification problem we could instead design the ensemble members to output z as logits, resulting in

$$q(y = k; z) = \frac{\exp(z_k)}{\sum_{j=1}^{K}\exp(z_j)}, \tag{9}$$

for $k = 1, \dots, K$. Here we have used the Kth class as a reference ($z_K = 0$) to obtain a unique parameterisation. This design choice allows an unconstrained network output $z \in \mathbb{R}^{K-1}$, and we can pick $v(z \mid x)$ to be a distribution with infinite support, e.g. the normal distribution

$$v(z \mid x) = \mathcal{N}\big(z; \mu_\phi(x), \Sigma_\phi(x)\big), \tag{10}$$

where the distilled model outputs the mean and (co-)variance of the distribution over logits. Assuming a normal distribution for z implies a logit-normal distribution over the resulting class probability vector, in contrast to the Dirichlet distribution used in (8).
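The reference-class parameterisation in (9) and the logit-normal distribution implied by (10) can be illustrated in a few lines. This is a sketch assuming a diagonal covariance (as in the CIFAR-10 experiments later) and NumPy's `default_rng`; the helper names and parameter values are hypothetical.

```python
import numpy as np

def probs_from_reduced_logits(z):
    """Map reduced logits z in R^{K-1} (last axis) to K class probabilities with z_K = 0."""
    z = np.asarray(z, dtype=float)
    full = np.concatenate([z, np.zeros(z.shape[:-1] + (1,))], axis=-1)  # append reference logit
    full = full - full.max(axis=-1, keepdims=True)                      # numerical stability
    e = np.exp(full)
    return e / e.sum(axis=-1, keepdims=True)

def sample_logit_normal(mu, sigma, n_samples, rng):
    """Draw probability vectors whose reduced logits follow N(mu, diag(sigma**2))."""
    z = mu + sigma * rng.standard_normal((n_samples, len(mu)))
    return probs_from_reduced_logits(z)

rng = np.random.default_rng(0)
samples = sample_logit_normal(np.array([1.0, -1.0]), np.array([0.5, 0.5]), 1000, rng)
```

Every draw lies strictly inside the simplex, so the normal distribution over logits never needs to handle the boundary constraints that a distribution placed directly on probability vectors must respect.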

The proposed framework enables a large set of possible distributional assumptions for regression problems as well. With $q(y; z)$ being a normal distribution and $z = (\mu, \Sigma)$ corresponding to the standard mean-variance parameterisation of that distribution, we can let the output of the distilled network be the parameters of, say, a normal–inverse–Wishart distribution:

$$v(z \mid x) = \mathrm{NIW}\big(\mu, \Sigma;\, m_\phi(x), \kappa_\phi(x), \Psi_\phi(x), \nu_\phi(x)\big). \tag{11}$$

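Alternatively, as in the classification case, we can reparameterise the output distribution q(y; z) such that z has unbounded support, and then model it using a normal distribution. As a concrete example, for a one-dimensional regression problem q can be parametrised by

$$q(y; z) = \mathcal{N}\big(y; z_\mu, \sigma^2(z_\sigma)\big), \tag{12}$$

where $z_\mu$ and $z_\sigma$ are unconstrained parameters (similar to the logits in the classification case), but where $z_\sigma$ is mapped through a positive transformation $\sigma^2(\cdot)$ to obtain a positive variance. If both parameters are outputs from the network, i.e., $z = (z_\mu, z_\sigma) = f_\theta(x)$, this results in a heteroscedastic regression model. We can then use a bivariate normal distribution over z for the distillation model,

$$v(z \mid x) = \mathcal{N}\big(z; \mu_\phi(x), \Sigma_\phi(x)\big), \tag{13}$$

where $\mu_\phi(x) \in \mathbb{R}^2$ and $\Sigma_\phi(x) \in \mathbb{R}^{2 \times 2}$ are outputs of the distilled network. We illustrate this particular setting of the distillation model in fig. 1.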
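In the heteroscedastic setting of (12), one draw from the bivariate normal v(z | x) yields one plausible Gaussian predictive distribution, mirroring one ensemble member. The sketch below assumes that the positive transformation is the exponential, $\sigma^2(z_\sigma) = \exp(z_\sigma)$ (one common choice, not necessarily the one used by the authors), and uses hypothetical outputs $\mu_\phi(x)$, $\Sigma_\phi(x)$ of the distilled network at a single x.

```python
import numpy as np

def sample_member_like_params(mu_phi, cov_phi, n_samples, rng):
    """Draw z = (z_mu, z_sigma) ~ N(mu_phi, cov_phi) and map each draw to (mean, variance) of q(y; z).

    Assumption: the variance transform is exp(z_sigma).
    """
    z = rng.multivariate_normal(mu_phi, cov_phi, size=n_samples)   # shape (S, 2)
    return z[:, 0], np.exp(z[:, 1])

rng = np.random.default_rng(0)
mu_phi = np.array([1.0, np.log(0.25)])          # hypothetical distilled-network outputs at one x
cov_phi = np.array([[0.09, 0.0], [0.0, 0.01]])
means, variances = sample_member_like_params(mu_phi, cov_phi, 20_000, rng)

aleatoric = variances.mean()   # Monte Carlo estimate of E_v[Var_q(y)]
epistemic = means.var()        # Monte Carlo estimate of Var_v(E_q[y]) = Var(z_mu)
```

With these numbers the epistemic term should be close to the variance 0.09 of $z_\mu$, and the aleatoric term close to the log-normal mean $E[\exp(z_\sigma)] = 0.25\exp(0.005)$.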

3.3 Predictions and uncertainty quantification

The proposed distribution distillation framework allows us to learn a higher-order distribution over the parameters of the individual ensemble members' predictive distributions. The advantage of this is that we end up with a distilled network which models not only the ensemble predictions but also its epistemic uncertainty, encoded in the "spread" of the distribution $v(z \mid x)$.

The distilled network can still be used for the original problem of making predictions through the marginal predictive distribution,

$$p(y \mid x) = \int q(y; z)\, v(z \mid x)\, dz, \tag{14}$$

which has a closed-form expression for some combinations of modelling distributions. If it does not, we can sample from v (and possibly q) to obtain various estimates derived from the predictive distribution. Even if we need to resort to sampling from v at test time, this is typically a cheap operation, since we only need to propagate x through the distilled network once (compared to M independent propagations for an ensemble). Furthermore, the savings in terms of memory requirements are unaffected by the sampling.

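The distilled model can readily be used for computing the total, aleatoric, or epistemic uncertainty. Similarly to (4) we have,

$$\begin{aligned}
\text{Total:} &\quad I\big[p(y \mid x)\big], && (15a)\\
\text{Aleatoric:} &\quad \mathbb{E}_{v(z \mid x)}\Big[I\big[q(y; z)\big]\Big]. && (15b)
\end{aligned}$$

As discussed above, if the involved expectations are intractable we can approximate them by sampling. Let $\tilde{z}^{(1)}, \dots, \tilde{z}^{(S)}$ be independent draws from the distilled distribution $v(z \mid x)$. Then

$$p(y \mid x) \approx \frac{1}{S}\sum_{s=1}^{S} q\big(y; \tilde{z}^{(s)}\big), \qquad \mathbb{E}_{v(z \mid x)}\Big[I\big[q(y; z)\big]\Big] \approx \frac{1}{S}\sum_{s=1}^{S} I\big[q\big(y; \tilde{z}^{(s)}\big)\big]. \tag{16}$$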

The epistemic uncertainty is given by the difference as in (4c).
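A minimal Monte Carlo sketch of this sampling approximation for the classification case follows. It assumes a diagonal Gaussian over the K−1 reduced logits with the reference logit fixed to zero as in (9), entropy as the uncertainty measure, and S = 100 samples as in the CIFAR-10 experiments; `mc_entropy_uncertainty` is a hypothetical name and `mu`, `sigma` stand in for distilled-network outputs.

```python
import numpy as np

def entropy(p, axis=-1):
    # Shannon entropy in nats; zero-probability entries contribute 0
    safe = np.where(p > 0, p, 1.0)
    return -np.sum(p * np.log(safe), axis=axis)

def mc_entropy_uncertainty(mu, sigma, n_samples=100, rng=None):
    """Monte Carlo total/aleatoric/epistemic entropy for one input x.

    mu, sigma: (K-1,) mean and standard deviation of a diagonal Gaussian v(z | x)
    over the reduced logits (hypothetical distilled-network outputs).
    """
    rng = np.random.default_rng() if rng is None else rng
    z = mu + sigma * rng.standard_normal((n_samples, len(mu)))   # draws from v(z | x)
    z = np.concatenate([z, np.zeros((n_samples, 1))], axis=1)    # reference logit z_K = 0
    z = z - z.max(axis=1, keepdims=True)                          # stable softmax
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)      # q(y; z) per draw, shape (S, K)
    total = entropy(probs.mean(axis=0))         # entropy of the sample-average predictive
    aleatoric = entropy(probs, axis=1).mean()   # sample average of per-draw entropies
    return total, aleatoric, total - aleatoric  # epistemic as the difference, cf. (4c)
```

Only one forward pass through the distilled network is needed to obtain `mu` and `sigma`; the S draws are cheap vector operations on its output.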

4 Experiments

We illustrate the feasibility of our proposed framework and evaluate it in both regression and classification settings. It is worth pointing out that the purpose of the distillation is to compress the ensemble to reduce its computational costs and memory usage. We expect that this compression comes at the price of a performance drop. Hence, the purpose of the illustration is not to show that the distilled model outperforms the ensemble, but rather that it has comparable performance at a fraction of the computational cost and memory requirement.

4.1 Regression

Regression is an under-explored topic in distillation. Here, we demonstrate how our framework can be used in that setting. First, we present regression distillation on a toy dataset and then we illustrate its performance on some real-world datasets.

Sparsification plots and AUSE Sparsification plots (Bruhn and Weickert, 2006; Kondermann et al., 2008) visualise the quality of the total uncertainty estimated by a regression model that outputs both a point estimate $\hat{y}$ and a total uncertainty I[p(y|x)].

Figure 2: Mean prediction and uncertainty estimation on heteroscedastic toy data. (a) Ensemble, trained on data on the interval . (b) Distribution distillation (our framework). (c) Mixture distillation. Both distilled networks are trained only on ensemble predictions on x sampled uniformly on . Our framework preserves the aleatoric and epistemic uncertainty, whereas the mixture distillation is only able to represent the total uncertainty.

The regression estimates are ordered from most to least estimated uncertainty, where uncertain estimates are expected to have a larger error. The average error is calculated for a sequence of subsets, where each new subset removes a larger fraction of the most uncertain estimates. Ideally, larger uncertainties should correspond to larger errors (on average) and removing points with the most uncertain predictions should therefore reduce the average error.

To get a comparable score, errors are normalised to one and measured relative to an oracle, which orders the estimates by the actual error. The difference between the oracle and model sparsification curves is called the sparsification error (SE). The area under the SE curve (AUSE) is a single value measuring the quality of the uncertainty estimates (Ilg et al., 2018).
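The sparsification procedure and the AUSE can be sketched as follows. This is an illustrative NumPy implementation under our own assumptions: curves are evaluated on a fixed grid of removed fractions, "normalised to one" is read as dividing each curve by the mean error on the full set, and the area is computed with the trapezoidal rule. It is not necessarily the exact protocol of Ilg et al. (2018), and the function names are hypothetical.

```python
import numpy as np

def sparsification_curve(errors, order, fractions):
    # Mean error of the estimates that remain after removing the first
    # `fraction` of `order` (i.e. the most uncertain or largest-error points).
    n = len(errors)
    return np.array([errors[order[int(round(f * n)):]].mean() for f in fractions])

def ause(errors, uncertainties, n_steps=20):
    errors = np.asarray(errors, dtype=float)           # assumed not all zero
    fractions = np.linspace(0.0, 0.95, n_steps)         # keep at least 5 % of the points
    by_model = np.argsort(-np.asarray(uncertainties))   # most uncertain first
    by_oracle = np.argsort(-errors)                     # largest actual error first
    model = sparsification_curve(errors, by_model, fractions)
    oracle = sparsification_curve(errors, by_oracle, fractions)
    model, oracle = model / model[0], oracle / oracle[0]   # normalise: full set -> 1
    se = model - oracle                                    # sparsification error
    # trapezoidal rule for the area under the SE curve
    return float(np.sum(0.5 * (se[1:] + se[:-1]) * np.diff(fractions)))
```

Since removing the largest errors first always leaves the smallest possible mean error, the oracle curve lies below any model curve and the AUSE is non-negative; it is zero when the uncertainties rank the points exactly like the errors.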

Regression toy example We demonstrate our distillation framework’s ability to preserve uncertainty decomposition with a one-dimensional toy example. The example data set is the same as used by Gustafsson et al. (2019) and is a sinusoidal curve with heteroscedastic noise

An ensemble with M = 10 members, each member with a single hidden layer, predicts M normal distributions parametrised as in (12). The ensemble is trained on N = 1000 pairs $(x_n, y_n)$ with $x_n$ sampled uniformly on a fixed training interval, and evaluated on data sampled uniformly on a larger interval in order to illustrate ensemble behaviour on OOD data.

The ensemble's aleatoric and epistemic uncertainties are calculated according to (4b) and (4c), respectively, using variance as the uncertainty measure. Together with the average of the members' mean predictions, $\bar{\mu}$, the aleatoric and epistemic uncertainties are shown in fig. 2.

The ensemble is distilled to a single network which parametrises a normal distribution over the ensemble parameters z. The distilled network has 2 hidden layers with 10 neurons each. Training is done on the ensemble predictions on inputs drawn from and uses no ground truth output values y. The distilled network is evaluated in the same way as the ensemble and the result is shown in fig. 2(b). This toy example gives an indication that our framework successfully distills the ensemble while retaining its rich uncertainty description.

As a comparison, we also train a network with mixture distillation. We use the same architecture as for the distribution distillation above, but optimise the KL divergence in (6). Figure 2(c) shows the distilled mean and total uncertainty; the decomposition into aleatoric and epistemic components is, however, no longer available.

Sparsification error plots in fig. 3 confirm that both distilled networks are able to capture the total uncertainty of the ensemble.

UCI data We use the UCI data (Dua and Graff, 2017) and perform an experiment with the same setup as described by Hernández-Lobato and Adams (2015).

We distill an ensemble of M = 10 networks. Individual ensemble members have a single hidden layer with 50 hidden neurons. The distilled model has a single hidden layer of 75 neurons and is trained solely on the predictions of the ensemble.

We measure RMSE, NLL and AUSE for both models and compare the results in table 1. Each data set is split into 5 train-test folds, for which both models are re-trained and tested. The ensemble consistently outperforms the distilled model, which is expected since the objective of the distillation is to mimic the ensemble. Indeed, it is natural to expect a certain degradation of predictive performance when compressing the ensemble. Still, the distilled model performs well in all three metrics, with confidence intervals computed over independent replications largely overlapping those of the ensemble.

Figure 3: Sparsification plots for the toy data set for three models: ensemble, our distilled network, and a mixture distilled network. Note that the SE is based on the models' individual estimates and cannot be compared quantitatively.

Table 1: Results on regression benchmark datasets comparing RMSE, NLL and AUSE for the ensemble and our distillation. Lower is better for all three metrics. As expected, the ensemble is better in most cases, but our distilled model follows it rather closely.

4.2 Classification

We evaluate the given framework on the task of classifying images from the CIFAR-10 data set (Krizhevsky, 2009). Based on (9) and (10), we let the distilled model predict the parameters of a normal distribution over the ensemble logits z. We investigate the distilled model's ability to adapt to the ensemble by comparing the mean, variance and uncertainty estimates of the two models. In addition, we illustrate our model's performance on OOD data in terms of accuracy and expected calibration error (see below) by reproducing experiments from the benchmark by Ovadia et al. (2019). We distill ensembles of size M = 10, also from Ovadia et al. (2019).

The distilled model is constructed based on an 18-layer ResNet architecture (He et al., 2015). The output of the last layer of the network is modified to match the parameters of a Gaussian distribution over the ensemble member logits as in (10). We restrict the network to output a diagonal covariance matrix and parameterise the diagonal elements according to (9). As in (9), we set $z_K = 0$ and let the distilled model predict a distribution over the logits of the remaining classes.

Expected Calibration Error Following Ovadia et al. (2019), we use the expected calibration error (ECE) for assessing the validity of the uncertainty estimates of our model.

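ECE evaluates how well the average confidences of the predictive model match the corresponding accuracy, reflecting how well-calibrated the model is (Guo et al., 2017). Given a model $p(y \mid x)$, the ECE is calculated over buckets $\{B_s\}_{s=1}^{S}$ of the set of N observations as

$$\mathrm{ECE} = \sum_{s=1}^{S}\frac{|B_s|}{N}\,\big|\mathrm{acc}(B_s) - \mathrm{conf}(B_s)\big|, \tag{17}$$

with

$$\mathrm{acc}(B_s) = \frac{1}{|B_s|}\sum_{n \in B_s}\mathbb{1}(\hat{y}_n = y_n), \qquad \mathrm{conf}(B_s) = \frac{1}{|B_s|}\sum_{n \in B_s} p(\hat{y}_n \mid x_n), \tag{18}$$

where $\hat{y}_n = \arg\max_{y} p(y \mid x_n)$ are the model predictions and $\mathbb{1}$ is an indicator function. We let the bucket boundaries be the quartiles, plus the minimum and maximum values.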
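A compact sketch of this ECE computation is given below. It assumes that the quartile boundaries are taken over the predicted confidences (our reading of the bucket choice above), giving four buckets; the function name and the toy inputs in the usage note are hypothetical.

```python
import numpy as np

def expected_calibration_error(probs, labels):
    """ECE with four buckets bounded by the min, quartiles and max of the confidences."""
    probs = np.asarray(probs, dtype=float)            # (N, K) predicted class probabilities
    labels = np.asarray(labels)
    preds = probs.argmax(axis=1)                      # y_hat_n
    conf = probs.max(axis=1)                          # p(y_hat_n | x_n)
    correct = (preds == labels).astype(float)         # indicator 1(y_hat_n = y_n)
    edges = np.quantile(conf, [0.0, 0.25, 0.5, 0.75, 1.0])
    # bucket index in 0..3; the maximum confidence falls in the last bucket
    bucket = np.clip(np.searchsorted(edges, conf, side="right") - 1, 0, 3)
    n = len(conf)
    ece = 0.0
    for s in range(4):
        mask = bucket == s
        if mask.any():
            ece += mask.sum() / n * abs(correct[mask].mean() - conf[mask].mean())
    return ece
```

For example, a model that always predicts class 0 with probability 0.9 on a balanced two-class set is right half of the time, so its ECE is |0.5 − 0.9| = 0.4.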

Comparing logit distributions We investigate how well the distilled model manages to mimic the ensemble by comparing its parameter predictions to the sample mean and variance over the ensemble logits on CIFAR-10 test data. In fig. 4 we display histograms of the average of the mean and variance over classes given a data point x, where the histogram is computed over all data points in the test set.

Figure 4: Histograms of the mean (left) and variance (right) on CIFAR-10 in-distribution test data, as predicted by our distilled model and as computed over the logits of the ensemble used to train it. The mean and variance are displayed as the average over the classes for each data point.

Figure 5: Histograms of aleatoric (left) and epistemic (right) uncertainty obtained on CIFAR-10 test data with our distilled model and the ensemble used to train it.

These histograms can give us an idea of how well the distilled model represents the ensemble. From what can be observed in fig. 4, the histograms over the mean values match very well, indicating that the distilled model manages to capture the mean of the ensemble logits. The variance, which represents the epistemic uncertainty, also has a reasonable fit, although the variances predicted by the distilled model seem to be somewhat overestimated.

Comparing uncertainty estimates A decomposition of uncertainty based on the predictions of the distilled model is performed in order to analyse the model’s ability to capture aleatoric and epistemic uncertainty. For this, we use an entropy measure, and base our calculations on the general uncertainty measures in (15a)-(15b) and (4c), using 100 samples from the distribution over z. In fig. 5, we compare the uncertainty estimates to measures obtained from the ensemble.

As with the variance in logit space (fig. 4), the distilled model tends to overestimate the uncertainties. It would not be surprising if the higher entropy of the distilled model in output space were a direct reflection of the overestimated variance over z.

Naturally, there are several possible explanations for why the distilled model has not fully captured the true spread of the ensemble. In the experiments, we select a network architecture for the distilled model similar to that of the individual ensemble members. However, predicting both the mean and the spread of the ensemble might be a more complex task than the original one of predicting class probabilities, and could therefore require a more complex network architecture. Moreover, the flexibility of the distilled model is affected by the distribution family of v(z|x). Since the given framework does not limit us to the specific design choices made in these experiments, another distribution family might be more suitable for describing the ensemble.

Also relevant for the performance of the distilled model is the size of the ensemble. An ensemble of size 10 gives only 10 samples from the distribution over the logits for a given data point x. It could be hypothesised that this sample size is too small for the distilled model to correctly estimate the covariance matrix of a multivariate normal distribution. In addition, the learned representation might not generalise well to the full in-distribution data. With this in mind, it would be of interest to investigate how increasing the ensemble size affects the model's ability to capture the spread of the ensemble. Malinin et al. (2019) note that increasing the ensemble size has quickly diminishing returns for a Dirichlet distribution, but it is not clear whether this is a general property.

Out-of-distribution detection The distilled model’s performance on out-of-distribution data is evaluated based on experiments first conducted by Ovadia et al. (2019). The experiments are performed on data sets consisting of 16 corruptions applied with five levels of severity to the original CIFAR-10 test images (Hendrycks and Dietterich, 2019). The corruptions include distortions such as changes to the contrast of the images and addition of Gaussian noise (see supplementary material).

We compare the performance of the distilled model to the performance of the models constructed in Ovadia et al. (2019). In addition to this, we train a distilled model with mixture distillation using the same architecture as for the distribution distilled model but adapted to the KL-divergence objective in (5), and add it to the experiments. The accuracy and ECE obtained with each model over all corrupted data sets and over five repeats are displayed in fig. 6.

Figure 6: Model accuracy and ECE across out-of-distribution data consisting of corrupted CIFAR-10 data sets, including 16 different corruptions applied at an intensity scale ranging from 1 to 5. Each box displays the minimum, maximum and median together with the first and third quartiles of the accuracy and ECE, respectively, for a given model. Results on the original CIFAR-10 test images are also shown.

Although it is the aspiration, we cannot expect the distilled model to extrapolate its predictions so as to behave equivalently to the ensemble on previously unseen data. Similarly to what was noted by Malinin et al. (2019), the behaviour of the ensemble might differ outside of the in-distribution data, requiring a more flexible representation of the ensemble than the one chosen. In spite of this, the distilled model seems to perform comparably to the ensemble on OOD data. Moreover, in terms of ECE, it places itself among the best-performing models on the corrupted data. This should serve as a proof of concept, indicating that the distilled model learns to capture certain aspects of the ensemble that make it better at estimating confidence.

We observe that the distilled model does not perform as well as the ensemble in terms of accuracy, especially on in-distribution data. This indicates a trade-off between cost, in terms of computation and memory, and model performance. It is consistent with the observation by Ovadia et al. (2019) that the best-performing models also tended to be the most computationally costly. Among the best-performing models we also find the dropout model, which bases its predictions on sampling by applying dropout at test time. While the dropout model requires less memory than the ensemble, it is still expected to be more computationally expensive at test time than our distilled model, since it requires several forward passes through the network per prediction.

Compared with our model, the mixture distillation model has a slightly higher accuracy on the in-distribution test data. This is reasonable, since the mixture distillation model only has to match the mean of the ensemble, while our model has the additional task of capturing the spread of the ensemble. However, the ability gained through this second objective proves valuable when estimating uncertainty in the predictions. This stresses the importance of also aiming to capture the epistemic uncertainty of the ensemble.
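Capturing the spread of the ensemble is what makes an epistemic term available at all. A common entropy-based decomposition, sketched below, splits the total uncertainty (entropy of the mean prediction) into an aleatoric part (mean entropy of the members) and an epistemic part (their difference, the mutual information). This is a standard construction used for illustration, not necessarily the exact estimator behind the experiments; for the distilled model, the rows of `member_probs` would be replaced by soft-max outputs of samples drawn from its predicted distribution.

```python
import numpy as np

def uncertainty_decomposition(member_probs, eps=1e-12):
    """Entropy-based split of an ensemble's predictive uncertainty.

    member_probs: (M, K) array, one categorical distribution per member (or sample).
    Returns (total, aleatoric, epistemic) in nats.
    """
    p = np.asarray(member_probs, dtype=float)
    mean_p = p.mean(axis=0)                                     # mixture prediction
    total = -np.sum(mean_p * np.log(mean_p + eps))              # entropy of the mean
    aleatoric = -np.mean(np.sum(p * np.log(p + eps), axis=1))   # mean member entropy
    epistemic = total - aleatoric                               # member disagreement
    return total, aleatoric, epistemic
```

Members that agree on a 50/50 prediction yield purely aleatoric uncertainty, whereas two confident members that disagree yield almost purely epistemic uncertainty; a mixture-distilled model only sees the mean and cannot tell these two cases apart.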

An additional advantage of the distilled model is that it only requires ensemble predictions as training targets. As a result, once the ensemble is trained, we can in principle train the distilled model on any data, labelled or not, as long as we can model it with the same distribution. This opens up the possibility of expanding the training data with out-of-distribution data points, which could potentially boost the performance of the distilled model.

5 Discussion and extensions

In this section, we bring up possible extensions to the presented framework and the experiments conducted. Specifically, we discuss the addition of regularisation to the training objective and the choice of parametrisation of the ensemble output in a classification setting.

Loss function alternatives It is desirable to have a tuning mechanism for fine-grained control over the distilled network's performance on its two tasks: making predictions and representing uncertainty. Minimising the NLL in equation (7) can potentially find an approximation that does both well. However, it has no natural way of including whatever annotated data is available.

The predictive distribution in (14) can be used to design a loss term which takes the labelled data into account

Then the trade-off between the accuracy of the distilled model and its ability to model uncertainty can be controlled by adding this term to the total loss:

with the weight of the added term acting as a tuning parameter. This trade-off between the two tasks of the distilled model can also be used without labelled data, with the aim of giving the mean prediction of the ensemble greater importance during training.
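A minimal sketch of such a combined objective, under explicit assumptions: the distilled network is assumed to output a Gaussian with mean `mu` and diagonal variance `var` over the K logits (consistent with the diagonal covariance in Appendix B.2.1), the distillation term is the NLL of the ensemble members' logit vectors under that Gaussian, and the predictive distribution is assumed to be the expected soft-max, estimated here by Monte Carlo. All function names, the sample count and the weight `lam` are hypothetical illustrations, not the authors' implementation.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)      # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def distillation_nll(mu, var, ensemble_logits):
    """Mean NLL of the members' logit vectors under N(mu, diag(var))."""
    z = np.asarray(ensemble_logits, dtype=float)                 # (M, K)
    sq = (z - mu) ** 2 / var
    nll_per_member = 0.5 * np.sum(np.log(2 * np.pi * var) + sq, axis=1)
    return nll_per_member.mean()

def label_nll(mu, var, label, n_samples=1000, rng=None):
    """-log p(y | x), with p(y | x) = E_z[softmax(z)_y] estimated by Monte Carlo."""
    rng = np.random.default_rng(0) if rng is None else rng
    z = mu + np.sqrt(var) * rng.standard_normal((n_samples, mu.shape[0]))
    return -np.log(softmax(z)[:, label].mean())

def total_loss(mu, var, ensemble_logits, label, lam):
    """Distillation NLL plus a lam-weighted NLL on the annotated label (hypothetical)."""
    return distillation_nll(mu, var, ensemble_logits) + lam * label_nll(mu, var, label)
```

Setting `lam = 0` recovers pure distribution distillation; increasing it shifts emphasis towards fitting the labels.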

Training with out-of-distribution data The defined framework allows the distilled model to be trained on data for which labels are not available. Hence, the performance of the distilled model on out-of-distribution data can be addressed by expanding the training set to include data points that are not in-distribution. In practice, however, it can be difficult to construct a relevant OOD data set for training, and we have not considered this option here.

Logits matching vs temperature annealing for classification In the classification setting we have primarily considered parametrising q(y; z) with z given as logits rather than as probability vectors. Using a nonstandard parametrisation can seem contrived, but it has potential advantages:

If we instead work with the standard parametrisation of q in terms of the probability vector, the distillation process can have difficulties distinguishing small (but potentially important) differences in the class probabilities. This is because the soft-max function

$$\sigma(z)_k = \frac{\exp(z_k)}{\sum_{j=1}^{K} \exp(z_j)}, \qquad k = 1, \dots, K,$$

squashes the unbounded logits into the interval (0, 1), which can place predictions close together on the probability simplex even though they are well separated in logit space. That distillation methods can have a hard time capturing small differences in class probabilities has been recognised before: Hinton et al. (2015) propose temperature annealing to address this issue. The idea is to divide the logits of both the distilled model and the target ensemble by a positive constant before applying the soft-max function; this constant is referred to as the temperature of the soft-max.
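A small numerical illustration of the squashing effect and of how a temperature above one counteracts it. The two logit vectors and the temperature T = 5 are arbitrary choices made for this sketch.

```python
import numpy as np

def softmax(z, temperature=1.0):
    """Soft-max of logits z after dividing them by a positive temperature."""
    z = np.asarray(z, dtype=float) / temperature
    z = z - z.max()                  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

# Two logit vectors that are well separated in logit space.
z_a = np.array([10.0, 0.0, 0.0])
z_b = np.array([15.0, 0.0, 0.0])

for T in (1.0, 5.0):
    gap = np.abs(softmax(z_a, T) - softmax(z_b, T)).max()
    print(f"T={T}: max probability difference = {gap:.2e}")
```

At T = 1 both vectors map to almost the same point near a vertex of the simplex (the largest difference is on the order of 1e-4), while at T = 5 the first-class probabilities differ by roughly 0.12, which a distillation loss defined on probabilities can readily detect.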

This modification of the original distribution is used expressly to make differences in logit space easier to distinguish, both for ordinary distillation (Hinton et al., 2015) and for distribution distillation (Malinin et al., 2019). Performing the distillation directly in logit space is another remedy for the same issue, one that avoids the need to specify a temperature or annealing schedule. How much this flexibility matters in practice is left for further study.

6 Conclusions

We have proposed a general framework for compressing an ensemble of neural networks while still maintaining the rich description of predictive uncertainty that is one of the main advantages of ensembles. Specifically, the compressed model can be used to estimate both epistemic and aleatoric uncertainty. Contrary to previous work, our framework applies to both regression and classification models. We have demonstrated that this framework can result in compressed models with performance that is highly competitive with the state-of-the-art. Furthermore, compared to using a full ensemble, or other methods that are able to capture epistemic uncertainty (e.g. Monte Carlo dropout and variational inference), our distilled model is simple and efficient to use at test time and has a favourable storage cost.

ACKNOWLEDGEMENTS

This work was supported by the Wallenberg AI, Autonomous Systems and Software Program (WASP) funded by the Knut and Alice Wallenberg Foundation.

References

C. Blundell, J. Cornebise, K. Kavukcuoglu, and D. Wierstra. Weight uncertainty in neural networks. In F. Bach and D. Blei, editors, Proceedings of the 32nd International Conference on Machine Learning, volume 37 of Proceedings of Machine Learning Research, Jul 2015.

A. Bruhn and J. Weickert. A confidence measure for variational optic flow methods. In Geometric Properties for Incomplete Data, pages 283–298. Jan 2006.

C. Buciluǎ, R. Caruana, and A. Niculescu-Mizil. Model compression. In Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining, 2006.

D. Dua and C. Graff. UCI machine learning repository, 2017. URL http://archive.ics.uci.edu/ml.

E. Englesson and H. Azizpour. Efficient Evaluation-Time Uncertainty Estimation by Improved Distillation. arXiv e-prints, art. arXiv:1906.05419, Jun 2019.

Y. Gal and Z. Ghahramani. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International conference on machine learning, 2016.

C. Guo, G. Pleiss, Y. Sun, and K. Q. Weinberger. On Calibration of Modern Neural Networks. In Proceedings of the 34th International Conference on Machine Learning, Jun 2017.

F. K. Gustafsson, M. Danelljan, and T. B. Schön. Evaluating Scalable Bayesian Deep Learning Methods for Robust Computer Vision. arXiv e-prints, art. arXiv:1906.01620, Jun 2019.

K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition, Dec 2015.

D. Hendrycks and T. Dietterich. Benchmarking neural network robustness to common corruptions and perturbations. In 7th International Conference on Learning Representations, ICLR 2019, 2019.

J. M. Hernández-Lobato and R. P. Adams. Probabilistic Backpropagation for Scalable Learning of Bayesian Neural Networks. In International Conference on Machine Learning, 2015.

G. Hinton, O. Vinyals, and J. Dean. Distilling the Knowledge in a Neural Network. arXiv e-prints, art. arXiv:1503.02531, Mar 2015.

E. Ilg, Ö. Çiçek, S. Galesso, A. Klein, O. Makansi, F. Hutter, and T. Brox. Uncertainty Estimates and Multi-Hypotheses Networks for Optical Flow. In Proceedings of the European Conference on Computer Vision, 2018.

A. Kendall and Y. Gal. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in neural information processing systems, 2017.

D. P. Kingma and J. Ba Lei. Adam: A method for stochastic optimization. In 3rd International Conference for Learning Representations, 2015.

A. Der Kiureghian and O. Ditlevsen. Aleatory or epistemic? Does it matter? Structural Safety, 31(2):105–112, 2009.

C. Kondermann, R. Mester, and C. Garbe. A statistical confidence measure for optical flows. In European Conference on Computer Vision, 2008.

A. Krizhevsky. Learning multiple layers of features from tiny images. 2009.

B. Lakshminarayanan, A. Pritzel, and C. Blundell. Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles. In Advances in neural information processing systems, 2017.

A. Malinin and M. Gales. Predictive Uncertainty Estimation via Prior Networks. In Advances in Neural Information Processing Systems, 2018.

A. Malinin, B. Mlodozeniec, and M. Gales. Ensemble Distribution Distillation. arXiv e-prints, art. arXiv:1905.00076, Apr 2019.

Y. Ovadia, E. Fertig, J. Ren, Z. Nado, D. Sculley, S. Nowozin, J. V. Dillon, B. Lakshminarayanan, and J. Snoek. Can you trust your model’s uncertainty? evaluating predictive uncertainty under dataset shift. In Advances in Neural Information Processing Systems, 2019.

L. Tran, B. S. Veeling, K. Roth, J. Swiatkowski, J. V. Dillon, J. Snoek, S. Mandt, T. Salimans, S. Nowozin, and R. Jenatton. Hydra: Preserving Ensemble Diversity for Model Distillation. arXiv e-prints, art. arXiv:2001.04694, Jan 2020.

D. Widmann, F. Lindsten, and D. Zachariah. Calibration tests in multi-class classification: A unifying framework. In Advances in Neural Information Processing Systems. 2019.

A Mixture distillation

An ensemble can be distilled in such a way that only the estimation of total uncertainty is preserved.

An equally weighted mixture model p is constructed from the ensemble output parameters and the distilled model q is optimised to produce the parameters of a single distribution, similar to the mixture.

The similarity is measured with the KL-divergence. Since the expectation in the divergence is taken with respect to the mixture model, only one term depends on the parameters of q:

A.1 Categorical

For an ensemble with members proposing categorical distributions, the distilled model predicts a categorical distribution that minimises the KL-divergence between it and the categorical mixture. The categorical mixture is represented by the average probability vector, the so-called soft target

where H is the cross-entropy.
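A short numerical check of the categorical case. Because KL(p̄ ‖ q) = H(p̄, q) − H(p̄), minimising the KL-divergence over q is the same as minimising the cross-entropy against the soft target, and the minimum is reached at q = p̄ (Gibbs' inequality). The member probabilities and candidate distributions below are made up for illustration.

```python
import numpy as np

def soft_target(member_probs):
    """Average of the members' probability vectors (the categorical mixture)."""
    return np.asarray(member_probs, dtype=float).mean(axis=0)

def cross_entropy(p, q, eps=1e-12):
    """H(p, q) = -sum_k p_k log q_k."""
    return -np.sum(p * np.log(q + eps))

members = np.array([[0.7, 0.2, 0.1],
                    [0.5, 0.3, 0.2],
                    [0.9, 0.05, 0.05]])
p_bar = soft_target(members)

# Only the cross-entropy term of KL(p_bar || q) depends on q; it is
# smallest when q equals the soft target itself.
candidates = [p_bar, np.array([0.6, 0.3, 0.1]), np.array([1/3, 1/3, 1/3])]
losses = [cross_entropy(p_bar, q) for q in candidates]
```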

A.2 Gaussian

For an ensemble with members proposing Gaussian distributions, the distilled model also predicts a Gaussian distribution that minimises the KL-divergence between it and the Gaussian mixture:

The logarithm of the distilled distribution yields the following terms:

where only the first term depends on y. The denominator in the first term can in turn be expanded to:

With y a stochastic variable distributed according to , the expectations of these terms are

In total the expectation of all the terms is

where the middle term is not dependent on j and the last one sums to 0. Finally:
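The minimiser of KL(p ‖ q) over Gaussians q is the well-known moment-matching solution: the distilled mean is the average of the member means, and the distilled variance is the average of σ_j² + μ_j² minus the squared mean (the law of total variance). The sketch below states this standard result and checks numerically that it maximises the expected log-density E_p[log q(y)], the only term of the KL-divergence that depends on q. Function names are our own.

```python
import numpy as np

def moment_match(means, variances):
    """Gaussian N(mu, s2) minimising KL(mixture || N) for an equal-weight Gaussian mixture."""
    means = np.asarray(means, dtype=float)
    variances = np.asarray(variances, dtype=float)
    mu = means.mean()
    s2 = np.mean(variances + means ** 2) - mu ** 2   # law of total variance
    return mu, s2

def expected_log_q(m, s2, means, variances):
    """E_p[log N(y; m, s2)] with y drawn from the equal-weight mixture p."""
    means = np.asarray(means, dtype=float)
    variances = np.asarray(variances, dtype=float)
    ey = means.mean()                                # E_p[y]
    ey2 = np.mean(variances + means ** 2)            # E_p[y^2]
    return -0.5 * np.log(2 * np.pi * s2) - (ey2 - 2 * m * ey + m ** 2) / (2 * s2)
```

For two members N(0, 1) and N(2, 1) this gives μ = 1 and σ² = 2: the member variance 1 plus the spread of the member means, which is also 1.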

B Training details

B.1 Regression

In the last layer of the models, the output that predicts the variance parameter is transformed to the positive real axis with

For the UCI data in section 4.1, the training was sensitive to initialisation and occasionally diverged.

B.1.1 Toy example

Each ensemble member is trained for 150 epochs with batch size 32. We use the Adam optimization algorithm (Kingma and Ba Lei, 2015) with learning rate . The distilled model has two hidden layers with 10 neurons in each. We train it for 30 epochs with the same optimizer as for the ensemble.

B.1.2 UCI data

For the ensemble training we use the same setup as in Lakshminarayanan et al. (2017). We use the Adam optimization algorithm (Kingma and Ba Lei, 2015) with learning rate . The distilled model has a single hidden layer with 75 neurons. We train it for 30 epochs with the same optimizer as for the ensemble.

B.2 Classification

The distilled model is trained for 100 epochs using the Adam optimization algorithm (Kingma and Ba Lei, 2015). The learning rate is set as , where is the initial learning rate, k is the step index, and c = 0.8. A step is taken every 20th epoch.
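The schedule formula itself is missing from the text. Assuming it is the usual exponential step decay, in which the initial learning rate is multiplied by c once per step (lr0 * c**k), it could be sketched as below. The function name is hypothetical, and the initial learning rate is left as a parameter because its value is not given here.

```python
def step_decay_lr(epoch, lr0, c=0.8, step_every=20):
    """Learning rate lr0 * c**k, where k = number of completed 20-epoch steps (assumed form)."""
    k = epoch // step_every
    return lr0 * c ** k

# Learning-rate multiplier at the start of each 20-epoch block over 100 epochs.
multipliers = [step_decay_lr(e, 1.0) for e in range(0, 100, 20)]
```

Over 100 epochs this takes four decay steps, so the final learning rate is 0.8⁴ ≈ 0.41 times the initial one under this assumption.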

The CIFAR-10 image data is scaled to the range [0.0, 1.0] prior to training. Augmentation is used in the form of random horizontal flips and random crops.
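A minimal sketch of this preprocessing for a single H×W×C image. The flip probability of 0.5 and the crop-after-zero-padding scheme with a padding of 4 pixels are common CIFAR-10 choices assumed here; the text does not specify the crop procedure. Function names are hypothetical.

```python
import numpy as np

def scale_to_unit(images_uint8):
    """Map 8-bit pixel values to floats in [0.0, 1.0]."""
    return np.asarray(images_uint8, dtype=np.float32) / 255.0

def augment(image, rng, pad=4):
    """Random horizontal flip, then a random crop of the original size after zero-padding."""
    h, w, _ = image.shape
    if rng.random() < 0.5:
        image = image[:, ::-1, :]                        # horizontal flip
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="constant")
    top = rng.integers(0, 2 * pad + 1)                   # crop offsets in [0, 2 * pad]
    left = rng.integers(0, 2 * pad + 1)
    return padded[top:top + h, left:left + w, :]
```

Zero-padding keeps the output in [0.0, 1.0] and preserves the 32×32×3 input shape, so augmented images can be batched directly.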

A similar training regime is used for the mixture distillation model using regular cross-entropy loss with the mean of the ensemble soft-max output as the target.

B.2.1 Numerical stability of variance estimation

For numerical stability during training, we make the network output extra positive constants (with K = 10), c, and parametrise the diagonal elements of the covariance matrix according to:

where u is the untransformed output of the network. During the test phase, to avoid numerical issues, we let if u > 10.

B.2.2 OOD data

The corrupted CIFAR-10 data (Hendrycks and Dietterich, 2019) used for the CIFAR-10 out-of-distribution experiments includes the following 16 corruptions:

• Brightness

• Contrast

• Defocus blur

• Elastic transform (stretch/contract regions of image)

• Fog

• Frost

• Gaussian blur

• Gaussian noise

• Glass blur

• Impulse noise ("salt-and-pepper" noise, colour analogue)

• Pixelate

• Saturate

• Shot noise (Poisson noise)

• Spatter

• Speckle noise

• Zoom blur

The corruptions are applied to the CIFAR-10 test data set of 10,000 data points on a severity scale ranging from 1 to 5.