Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization

2019·arXiv

ABSTRACT

1 INTRODUCTION

Machine learning models are typically trained to minimize the average loss on a training set, with the goal of achieving high accuracy on an independent and identically distributed (i.i.d.) test set. However, models that are highly accurate on average can still consistently fail on rare and atypical examples (Hovy & Søgaard, 2015; Blodgett et al., 2016; Tatman, 2017; Hashimoto et al., 2018; Duchi et al., 2019). Such models are problematic when they violate equity considerations (Jurgens et al., 2017; Buolamwini & Gebru, 2018) or rely on spurious correlations: misleading heuristics that work for most training examples but do not always hold. For example, in natural language inference (NLI)—determining if two sentences agree or contradict—the presence of negation words like ‘never’ is strongly correlated with contradiction due to artifacts in crowdsourced training data (Gururangan et al., 2018; McCoy et al., 2019). A model that learns this spurious correlation would be accurate on average on an i.i.d. test set but suffer high error on groups of data where the correlation does not hold (e.g., the group of contradictory sentences with no negation words).

To avoid learning models that rely on spurious correlations and therefore suffer high loss on some groups of data, we instead train models to minimize the worst-case loss over groups in the training data. The choice of how to group the training data allows us to use our prior knowledge of spurious correlations, e.g., by grouping together contradictory sentences with no negation words in the NLI example above. This training procedure is an instance of distributionally robust optimization (DRO), which optimizes for the worst-case loss over potential test distributions (Ben-Tal et al., 2013; Duchi et al., 2016). Existing work on DRO has focused on models that cannot approach zero training loss, such as generative models (Oren et al., 2019) or convex predictive models with limited capacity (Maurer & Pontil, 2009; Shafieezadeh-Abadeh et al., 2015; Namkoong & Duchi, 2017; Duchi & Namkoong, 2018; Hashimoto et al., 2018).

Figure 1: Representative training and test examples for the datasets we consider. The correlation between the label y and the spurious attribute a at training time does not hold at test time.

We study group DRO in the context of overparameterized neural networks in three applications (Figure 1)—natural language inference with the MultiNLI dataset (Williams et al., 2018), facial attribute recognition with CelebA (Liu et al., 2015), and bird photograph recognition with our modified version of the CUB dataset (Wah et al., 2011). The problem with applying DRO to overparameterized models is that if a model achieves zero training loss, then it is optimal on both the worst-case (DRO) and the average training objectives (Zhang et al., 2017; Wen et al., 2014). In the vanishing-training-loss regime, we indeed find that group DRO models do no better than standard models trained to minimize average loss via empirical risk minimization (ERM): both models have high average test accuracies and worst-group training accuracies, but low worst-group test accuracies (Section 3.1). In other words, the generalization gap is small on average but large for the worst group.

In contrast, we show that strongly-regularized group DRO models that do not attain vanishing training loss can significantly outperform both regularized and unregularized ERM models. We consider $\ell_2$ penalties, early stopping (Section 3.2), and group adjustments that minimize a risk measure which accounts for the differences in generalization gaps between groups (Section 3.3). Across the three applications, regularized group DRO improves worst-case test accuracies by 10–40 percentage points while maintaining high average test accuracies. These results give a new perspective on generalization in neural networks: regularization might not be important for good average performance (e.g., models can “train longer and generalize better” on average (Hoffer et al., 2017)) but it appears important for good worst-case performance.

Finally, to carry out the experiments, we introduce a new stochastic optimizer for group DRO that is stable and scales to large models and datasets. We derive convergence guarantees for our algorithm in the convex case and empirically show that it behaves well in our non-convex models (Section 5).

2 SETUP

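Consider predicting labels $y \in \mathcal{Y}$ from input features $x \in \mathcal{X}$. Given a model family $\Theta$, loss $\ell : \Theta \times (\mathcal{X} \times \mathcal{Y}) \to \mathbb{R}_+$, and training data drawn from some distribution $P$, the standard goal is to find a model $\theta \in \Theta$ that minimizes the expected loss $\mathbb{E}_P[\ell(\theta; (x, y))]$ under the same distribution $P$. The standard training procedure for this goal is empirical risk minimization (ERM):

$$\hat{\theta}_{\mathrm{ERM}} := \arg\min_{\theta \in \Theta} \mathbb{E}_{(x,y) \sim \hat{P}}[\ell(\theta; (x, y))], \tag{1}$$

where $\hat{P}$ is the empirical distribution over the training data.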
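To make objective (1) concrete, here is a minimal Python sketch of ERM over a deliberately tiny, hypothetical model family: scalar threshold classifiers that predict label 1 iff $x \geq \theta$, scored with the 0-1 loss. The helper names `empirical_risk`, `zero_one_loss`, and `erm` are illustrative assumptions, not from the paper; the point is only that the expectation under $\hat{P}$ is a plain average over training points and that ERM picks the candidate with the lowest average.

```python
from typing import Callable, Sequence, Tuple

# Hypothetical example type: a scalar feature x and a binary label y.
Example = Tuple[float, int]


def empirical_risk(loss: Callable[[Example], float], data: Sequence[Example]) -> float:
    """Expected loss under the empirical distribution P-hat, i.e. the mean per-example loss."""
    return sum(loss(ex) for ex in data) / len(data)


def zero_one_loss(theta: float) -> Callable[[Example], float]:
    """0-1 loss of the threshold classifier that predicts 1 iff x >= theta."""
    return lambda ex: float((ex[0] >= theta) != bool(ex[1]))


def erm(thetas: Sequence[float], data: Sequence[Example]) -> float:
    """Return the candidate theta with the smallest empirical risk (first one on ties)."""
    return min(thetas, key=lambda theta: empirical_risk(zero_one_loss(theta), data))


data = [(0.1, 0), (0.4, 0), (0.6, 1), (0.9, 1)]
print(erm([0.0, 0.5, 1.0], data))  # 0.5: this threshold separates the two classes perfectly
```

With a real network, $\Theta$ is continuous and the argmin is approximated with minibatch SGD rather than by enumeration.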

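In distributionally robust optimization (DRO) (Ben-Tal et al., 2013; Duchi et al., 2016), we aim instead to minimize the worst-case expected loss over an uncertainty set of distributions $\mathcal{Q}$:

$$\min_{\theta \in \Theta} \Big\{ \mathcal{R}(\theta) := \sup_{Q \in \mathcal{Q}} \mathbb{E}_{(x,y) \sim Q}[\ell(\theta; (x, y))] \Big\}. \tag{2}$$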

The uncertainty set Q encodes the possible test distributions that we want our model to perform well on. Choosing a general family Q, such as a divergence ball around the training distribution, confers robustness to a wide set of distributional shifts, but can also lead to overly pessimistic models which optimize for implausible worst-case distributions (Duchi et al., 2019).

To construct a realistic set of possible test distributions without being overly conservative, we leverage prior knowledge of spurious correlations to define groups over the training data and then define the uncertainty set $\mathcal{Q}$ in terms of these groups. Concretely, we adopt the group DRO setting (Hu et al., 2018; Oren et al., 2019) where the training distribution $P$ is assumed to be a mixture of $m$ groups $P_g$ indexed by $\mathcal{G} = \{1, 2, \ldots, m\}$.¹ We define the uncertainty set $\mathcal{Q}$ as any mixture of these groups, i.e., $\mathcal{Q} := \{\sum_{g=1}^{m} q_g P_g : q \in \Delta_m\}$, where $\Delta_m$ is the $(m-1)$-dimensional probability simplex; this choice of $\mathcal{Q}$ allows us to learn models that are robust to group shifts. Because the optimum of a linear program is attained at a vertex, the worst-case risk (2) is equivalent to a maximum over the expected loss of each group,

$$\mathcal{R}(\theta) = \max_{g \in \mathcal{G}} \mathbb{E}_{(x,y) \sim P_g}[\ell(\theta; (x, y))]. \tag{3}$$
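The vertex argument can be checked numerically. The sketch below assumes we already know the expected loss $r_g$ of each group (the numbers are made up for illustration) and treats the mixture risk as the linear function $\sum_g q_g r_g$ of $q$. Random points of the simplex never exceed the largest group risk, which a one-hot $q$ attains. The function names are hypothetical.

```python
import random
from typing import List, Sequence


def mixture_risk(q: Sequence[float], group_risks: Sequence[float]) -> float:
    """Expected loss under the mixture sum_g q_g P_g; it is linear in q."""
    return sum(qg * rg for qg, rg in zip(q, group_risks))


def random_simplex_point(m: int, rng: random.Random) -> List[float]:
    """Sample q uniformly from the probability simplex by normalizing i.i.d. Exp(1) draws."""
    w = [rng.expovariate(1.0) for _ in range(m)]
    total = sum(w)
    return [wi / total for wi in w]


def worst_case_risk(group_risks: Sequence[float]) -> float:
    """Supremum of the linear mixture risk over the simplex, attained at a vertex (one-hot q)."""
    return max(group_risks)


risks = [0.05, 0.10, 0.62]  # made-up per-group expected losses
rng = random.Random(0)
samples = [mixture_risk(random_simplex_point(len(risks), rng), risks) for _ in range(1000)]
print(max(samples) <= worst_case_risk(risks) + 1e-12)  # True: no mixture beats the worst group
```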

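We assume that we know which group each training point comes from, i.e., the training data comprises $(x, y, g)$ triplets, though we do not assume we observe $g$ at test time, so the model cannot use $g$ directly. Instead, we learn a group DRO model minimizing the empirical worst-group risk $\hat{\mathcal{R}}(\theta)$:

$$\hat{\theta}_{\mathrm{DRO}} := \arg\min_{\theta \in \Theta} \Big\{ \hat{\mathcal{R}}(\theta) := \max_{g \in \mathcal{G}} \mathbb{E}_{(x,y) \sim \hat{P}_g}[\ell(\theta; (x, y))] \Big\}, \tag{4}$$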

where each group $\hat{P}_g$ is an empirical distribution over all training points $(x, y, g')$ with $g' = g$ (or equivalently, a subset of training examples drawn from $P_g$). Group DRO learns models with good worst-group training loss across groups. This need not imply good worst-group test loss because of the worst-group generalization gap $\delta := \mathcal{R}(\theta) - \hat{\mathcal{R}}(\theta)$. We will show that for overparameterized neural networks, $\delta$ is large unless we apply sufficient regularization.
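The empirical worst-group risk in (4) needs only per-example losses and group ids. The sketch below assumes these losses have already been computed for some fixed model; the helper names are hypothetical and the losses are made up. Held-out losses serve as a stand-in for the population risk, so the printed difference estimates the worst-group generalization gap, which is large here even though training loss is near zero.

```python
from collections import defaultdict
from typing import Dict, Iterable, Tuple


def group_risks(losses_and_groups: Iterable[Tuple[float, int]]) -> Dict[int, float]:
    """Mean loss within each group g, i.e. the expected loss under the empirical P-hat_g."""
    totals: Dict[int, float] = defaultdict(float)
    counts: Dict[int, int] = defaultdict(int)
    for loss, g in losses_and_groups:
        totals[g] += loss
        counts[g] += 1
    return {g: totals[g] / counts[g] for g in totals}


def worst_group_risk(losses_and_groups: Iterable[Tuple[float, int]]) -> float:
    """Maximum over groups of the per-group mean loss."""
    return max(group_risks(losses_and_groups).values())


# Illustrative (made-up) losses: near zero on training data, large on one test group.
train = [(0.0, 0), (0.01, 0), (0.02, 1)]
test = [(0.2, 0), (0.4, 0), (1.0, 1), (3.0, 1)]
gap = worst_group_risk(test) - worst_group_risk(train)
print(round(gap, 2))  # 1.98
```

Note that the average test loss here is 1.15, well below the worst-group value of 2.0, which is the pattern the paper describes: a small average gap can hide a large worst-group gap.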

2.1 APPLICATIONS

In the rest of this paper, we study three applications that share a similar structure (Figure 1): each data point $(x, y)$ has some input attribute $a(x) \in \mathcal{A}$ that is spuriously correlated with the label $y$, and we use this prior knowledge to form $m = |\mathcal{A}| \times |\mathcal{Y}|$ groups, one for each value of $(a, y)$. We expect that models that learn the correlation between $a$ and $y$ in the training data would do poorly on groups for which the correlation does not hold and hence do worse on the worst-group loss $\mathcal{R}(\theta)$.
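Forming groups from $(a, y)$ pairs is a simple enumeration. The sketch below uses a hypothetical helper `make_group_index` and the attribute and label names from the Waterbirds description; it assigns one integer id per pair, so the number of groups is $|\mathcal{A}| \times |\mathcal{Y}|$.

```python
from itertools import product
from typing import Dict, Sequence, Tuple


def make_group_index(attributes: Sequence[str], labels: Sequence[str]) -> Dict[Tuple[str, str], int]:
    """Assign one integer group id to every (a, y) pair, giving |A| * |Y| groups."""
    return {pair: g for g, pair in enumerate(product(attributes, labels))}


waterbirds = make_group_index(["water background", "land background"], ["waterbird", "landbird"])
print(len(waterbirds))  # 4
print(waterbirds[("land background", "waterbird")])  # 2, the smallest Waterbirds group
```

With two attribute values and three labels, the same helper yields the six MultiNLI groups.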

Object recognition with correlated backgrounds (Waterbirds dataset). Object recognition models can spuriously rely on the image background instead of learning to recognize the actual object (Ribeiro et al., 2016). We study this by constructing a new dataset, Waterbirds, which combines bird photographs from the Caltech-UCSD Birds-200-2011 (CUB) dataset (Wah et al., 2011) with image backgrounds from the Places dataset (Zhou et al., 2017). We label each bird as one of Y = {waterbird, landbird} and place it on one of A = {water background, land background}, with waterbirds (landbirds) more frequently appearing against a water (land) background (Appendix C.1). There are n = 4795 training examples and 56 in the smallest group (waterbirds on land).

Object recognition with correlated demographics (CelebA dataset). Object recognition models (and other ML models more generally) can also learn spurious associations between the label and demographic information like gender and ethnicity (Buolamwini & Gebru, 2018). We examine this on the CelebA celebrity face dataset (Liu et al., 2015), using hair color (Y = {blond, dark}) as the target and gender (A = {male, female}) as the spurious attribute. There are n = 162770 training examples in the CelebA dataset, with 1387 in the smallest group (blond-haired males).

Natural language inference (MultiNLI dataset). In natural language inference, the task is to determine if a given hypothesis is entailed by, neutral with, or contradicts a given premise. Prior work has shown that crowdsourced training datasets for this task have significant annotation artifacts, such as the spurious correlation between contradictions and the presence of the negation words nobody, no, never, and nothing (Gururangan et al., 2018). We divide the MultiNLI dataset (Williams et al., 2018) into m = 6 groups, one for each pair of labels Y = {entailed, neutral, contradictory} and spurious attributes A = {no negation, negation}. There are n = 206175 examples in our training set, with 1521 examples in the smallest group (entailment with negations); see Appendix C.1 for more details on dataset construction and the training/test split.
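One way to compute the spurious attribute for MultiNLI is sketched below. It assumes the attribute is read off the hypothesis sentence and uses a simple lowercase alphabetic tokenizer; both choices are illustrative assumptions rather than the paper's exact preprocessing (Appendix C.1 describes the actual construction). The word list is the one given above, and the helper names are hypothetical.

```python
import re
from typing import Tuple

# Negation words listed in the text (Gururangan et al., 2018).
NEGATION_WORDS = {"nobody", "no", "never", "nothing"}


def has_negation(hypothesis: str) -> bool:
    """True if any lowercase alphabetic token of the hypothesis is a listed negation word."""
    tokens = re.findall(r"[a-z]+", hypothesis.lower())
    return any(token in NEGATION_WORDS for token in tokens)


def nli_group(hypothesis: str, label: str) -> Tuple[str, str]:
    """Return the (a, y) pair that identifies the example's group."""
    attribute = "negation" if has_negation(hypothesis) else "no negation"
    return attribute, label


print(nli_group("Nobody is sleeping.", "contradictory"))  # ('negation', 'contradictory')
print(nli_group("A man is sleeping.", "entailed"))  # ('no negation', 'entailed')
```

Whole-token matching avoids false positives such as "nowhere" or "knot", at the cost of missing other negations like "not" that are outside the listed set.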

3 COMPARISON BETWEEN GROUP DRO AND ERM

To study the behavior of group DRO vs. ERM in the overparameterized setting, we fine-tuned ResNet50 models (He et al., 2016) on Waterbirds and CelebA and a BERT model (Devlin et al., 2019) on MultiNLI. These are standard models for image classification and natural language inference which achieve high average test accuracies on their respective tasks.

We train the ERM (1) and group DRO (4) models using standard (minibatch) stochastic gradient descent and the (minibatch) stochastic algorithm introduced in Section 5, respectively. We tune the learning rate for ERM and use the same setting for DRO (Appendix C.2). For each model, we measure its average (in-distribution) accuracy over training and test sets drawn from the same distribution, as well as its accuracy on the worst-performing group.
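Both evaluation metrics can be computed from predictions, labels, and group ids, as in the sketch below (a hypothetical helper). The group ids are used only for evaluation, consistent with the assumption that the model does not observe $g$ at test time. Average accuracy is taken over all examples, and worst-group accuracy is the minimum of the per-group accuracies.

```python
from collections import defaultdict
from typing import Dict, Sequence, Tuple


def accuracy_report(
    preds: Sequence[int], labels: Sequence[int], groups: Sequence[int]
) -> Tuple[float, float, Dict[int, float]]:
    """Return (average accuracy, worst-group accuracy, per-group accuracies)."""
    if not (len(preds) == len(labels) == len(groups)):
        raise ValueError("preds, labels, and groups must have the same length")
    correct: Dict[int, int] = defaultdict(int)
    total: Dict[int, int] = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        correct[g] += int(p == y)
        total[g] += 1
    per_group = {g: correct[g] / total[g] for g in total}
    average = sum(correct.values()) / len(labels)
    return average, min(per_group.values()), per_group


avg, worst, per_group = accuracy_report([1, 1, 0, 0, 1], [1, 1, 0, 1, 0], [0, 0, 1, 1, 1])
print(avg, round(worst, 3))  # 0.6 0.333
```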

3.1 ERM AND DRO HAVE POOR WORST-GROUP ACCURACY IN THE OVERPARAMETERIZED REGIME

Overparameterized neural networks can perfectly fit the training data and still generalize well on average (Zhang et al., 2017). We start by showing that these overparameterized models do not generalize well on the worst-case group when they are trained to convergence using standard regularization and hyperparameter settings (He et al., 2016; Devlin et al., 2019), regardless of whether they are trained with ERM or group DRO.²

ERM. As expected, ERM models attain near-perfect worst-group training accuracies of at least 99.9% on all three datasets and also obtain high average test accuracies (97.3%, 94.8%, and 82.5% on Waterbirds, CelebA, and MultiNLI). However, they perform poorly on the worst-case group at test time with worst-group accuracies of 60.0%, 41.1%, and 65.7% respectively (Table 1, Figur