The typical goal in machine learning is to minimize the average error on a test set that is independent and identically distributed (i.i.d.) to the training set. A large body of prior work has shown that overparameterization—increasing model size beyond the point of zero training error—improves average test error in a variety of settings, both empirically (with neural networks, e.g., Nakkiran et al. (2019)) and theoretically (with linear and random projection models, e.g., Belkin et al. (2019); Mei & Montanari (2019)).
However, recent work has also demonstrated that models with low average error can still fail on particular groups of data points (Blodgett et al., 2016; Hashimoto et al., 2018; Buolamwini & Gebru, 2018). This problem of high worst-group error arises especially in the presence of spurious correlations, such as strong associations between label and background in image classification (McCoy et al., 2019; Sagawa et al., 2020). To mitigate this problem, common approaches reduce the worst-group training loss, e.g., through distributionally robust optimization (DRO) or simply upweighting the minority groups. Sagawa et al. (2020) showed that these approaches improve worst-group error on strongly regularized neural networks but fail to help standard neural networks that can achieve zero training error, suggesting that increasing model capacity by reducing regularization (and perhaps by increasing overparameterization as well) can exacerbate spurious correlations.
*Equal contribution. ¹Stanford University. Correspondence to: Shiori Sagawa <ssagawa@cs.stanford.edu>, Aditi Raghunathan <aditir@stanford.edu>, Pang Wei Koh <pangwei@cs.stanford.edu>.
Proceedings of the International Conference on Machine Learning, Online, PMLR 119, 2020. Copyright 2020 by the author(s).
Figure 1. Top: Overparameterization hurts test error on the worst group when models are trained with the reweighted objective that upweights minority groups (Equation 3). Without reweighting, models have poor worst-group error regardless of model size (Appendix A.1). Bottom: Consider data points comprising a core feature $x_{\text{core}}$ (x-axis) and a spurious feature $x_{\text{spu}}$ (y-axis). The label $y$ is highly correlated with $x_{\text{spu}}$, except on two minority groups (crosses). Underparameterized models use the core feature (left), but overparameterized models use the spurious feature and memorize the minority points (right).
In this paper, we investigate why overparameterization exacerbates spurious correlations under the above approach of upweighting minority groups. We first confirm on two image datasets (Figure 2) that directly increasing overparameterization (i.e., increasing model size) indeed hurts worst-group error, leading to models that are highly inaccurate on the minority groups where the spurious correlation does not hold (Section 3). In contrast, their underparameterized counterparts obtain much better worst-group error, but do worse on average. We also confirm that models trained via empirical risk minimization (i.e., without upweighting the minority) have poor worst-group test error regardless of whether they are under- or overparameterized. Through simulations on a synthetic setting, we further identify two properties of the training data that modulate the effect of overparameterization: (i) the relative sizes of the majority versus minority groups, and (ii) how informative the spurious features are relative to the core features (Section 4).
Figure 2. We consider two image datasets, CelebA and Waterbirds, where the label y is correlated with a spurious attribute a in a majority of the training data. The % beside each group shows its frequency in the training data. To measure how robust a model is to the spurious attribute, we divide the data into groups based on (y, a) and record the highest error incurred by a group. Figure adapted from Sagawa et al. (2020).
Why does overparameterization exacerbate spurious correlations? Underparameterized models do not rely on spurious features because that would incur high training error on the (upweighted) minority groups where the spurious correlation does not hold. In contrast, overparameterized models can always obtain zero training error by memorizing training examples, and instead rely on their inductive bias to pick a solution (which features to use and which examples to memorize) out of all solutions with zero training error. Our results suggest an intuitive story of why overparameterization can hurt: overparameterized models can have an inductive bias towards “memorizing” fewer examples (Figure 1). If (i) the majority groups are sufficiently large and (ii) the spurious features are more informative than the core features for these groups, then overparameterized models could choose to use the spurious features because doing so entails less memorization, and therefore suffer high worst-group test error. We test this intuition through simulations and formalize it in a theoretical analysis (Section 5).
Our analysis also leads to the counterintuitive result that on overparameterized models, subsampling the majority groups is much more effective at improving worst-group error than upweighting the minority groups. Indeed, an overparameterized model trained on a subset of <5% of the data performs similarly (on average and on the worst group) to an underparameterized model trained on all the data (Section 6). This suggests a possible tension between using overparameterized models and using all the data; average error benefits from both, but improving worst-group error seems to rely on using only one but not both.
Spurious correlation setup. We adopt the setting studied in Sagawa et al. (2020), where each example comprises the input features $x \in \mathcal{X}$, a label (core attribute) $y \in \mathcal{Y}$, and a spurious attribute $a \in \mathcal{A}$. Each example belongs to a group $g \in \mathcal{G} = \mathcal{Y} \times \mathcal{A}$, where $g = (y, a)$. Importantly, the spurious attribute $a$ is correlated with the label $y$ in the training set. We focus on the binary setting in which $\mathcal{Y} = \{1, -1\}$ and $\mathcal{A} = \{1, -1\}$.
Applications. We study two image classification tasks (Figure 2). In the first task, the label is spuriously correlated with demographics: specifically, we use the CelebA dataset (Liu et al., 2015) to classify hair color between the labels Y = {blonde, non-blonde}, which are correlated with the gender A = {female, male}. In the second task, the label is spuriously correlated with image background. We use the Waterbirds dataset (based on datasets from Wah et al. (2011); Zhou et al. (2017) and modified by Sagawa et al. (2020)) to classify between the labels Y = {waterbird, landbird}, which are spuriously correlated with the image background A = {water background, land background}. See Appendix A.5 for more dataset details.
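Objectives and metrics. We evaluate a model $w$ by its worst-group error,
$$\mathrm{Err}_{\text{wg}}(w) := \max_{g \in \mathcal{G}} \; \mathbb{E}_{x, y \mid g}\left[\ell_{0\text{-}1}(w; (x, y))\right], \qquad (1)$$
where $\ell_{0\text{-}1}$ is the 0-1 loss. In other words, we measure the error (% of examples that are incorrectly labeled) in each group, and then record the highest error across all groups. The standard approach to training models is empirical risk minimization (ERM): given a loss function $\ell$, find the model $w$ that minimizes the average training loss
$$\hat{w}_{\text{erm}} := \arg\min_{w} \; \frac{1}{n} \sum_{i=1}^{n} \ell(w; (x^{(i)}, y^{(i)})). \qquad (2)$$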
However, in line with Sagawa et al. (2020), we find that models trained via ERM have poor worst-group test error regardless of whether they are under- or overparameterized (Appendix A.1). To achieve low worst-group test error, prior work proposed modified objectives that focus on the worst-group loss, such as group distributionally robust optimization (group DRO) which directly optimizes for the worst-group training loss (Hu et al., 2018; Sagawa et al., 2020) or reweighting (Shimodaira, 2000; Byrd & Lipton, 2019). Sagawa et al. (2020) showed that both approaches can help worst-group loss, though group DRO is typically more effective. For simplicity, we focus on the well-studied reweighting approach, which optimizes
$$\hat{w}_{\text{rw}} := \arg\min_{w} \; \sum_{i=1}^{n} \frac{1}{p_{g^{(i)}}} \, \ell(w; (x^{(i)}, y^{(i)})), \qquad (3)$$
where $p_g$ is the fraction of training examples in group $g$. The intuition behind reweighting is that it makes each group contribute the same weight to the training objective: that is, minority groups are upweighted, while majority groups are downweighted. Note that this approach requires the groups $g$ to be specified at training time, though not at test time.
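To make Equations (1) and (3) concrete, here is a minimal NumPy sketch for a linear model, assuming labels and attributes take values in $\{-1, +1\}$. The function names (`group_ids`, `group_weights`, `worst_group_error`, `reweighted_logistic_loss`) are hypothetical illustrations, not code from the paper. It computes the empirical worst-group error and the reweighted logistic objective.

```python
import numpy as np


def group_ids(y, a):
    """Map binary (y, a) pairs with values in {-1, +1} to integer group ids 0..3."""
    return 2 * (y > 0).astype(int) + (a > 0).astype(int)


def group_weights(groups):
    """Per-example weight 1 / p_g, where p_g is the fraction of examples in group g."""
    _, inverse, counts = np.unique(groups, return_inverse=True, return_counts=True)
    return len(groups) / counts[inverse]


def worst_group_error(y_true, y_pred, groups):
    """Empirical version of Equation (1): the largest 0-1 error over all groups."""
    return max(np.mean(y_pred[groups == g] != y_true[groups == g])
               for g in np.unique(groups))


def reweighted_logistic_loss(w, X, y, groups):
    """Objective of Equation (3) for a linear model, using the logistic loss."""
    margins = y * (X @ w)
    return np.sum(group_weights(groups) * np.logaddexp(0.0, -margins))


# Toy data: three majority points with y = a = 1, and one point in each other group.
y = np.array([1, 1, 1, -1, -1, 1])
a = np.array([1, 1, 1, -1, 1, -1])
g = group_ids(y, a)
y_pred = np.array([1, 1, -1, -1, -1, -1])
print(worst_group_error(y, y_pred, g))  # 1.0 (the lone y=1, a=-1 point is wrong)
print(np.mean(y_pred != y))             # the average error is only 2/6
```

In this sketch, the weights of each group sum to $n$, which is the sense in which reweighting makes every group contribute equally to the objective.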
Sagawa et al. (2020) observed that decreasing regularization hurts worst-group error. Though increasing overparameterization and reducing regularization can have different effects (Zhang et al., 2017; Mei & Montanari, 2019), this suggests that overparameterization might similarly exacerbate spurious correlations. Here, we show that directly increasing overparameterization (model size) indeed hurts worst-group error even though it improves average error.
Models. We study the CelebA and Waterbirds datasets described above. For CelebA, we train a ResNet10 model (He et al., 2016), varying model size by increasing the network width from 1 to 96, as in Nakkiran et al. (2019). For Waterbirds, we use logistic regression over random projections, as in Mei & Montanari (2019). Specifically, let $x \in \mathbb{R}^d$ be the input features, which we obtain by passing the input image through a pre-trained, fixed ResNet-18 model. We train an unregularized logistic regression model over the feature representation $\mathrm{ReLU}(Wx) \in \mathbb{R}^m$, where $W \in \mathbb{R}^{m \times d}$ is a random matrix with each row sampled uniformly from the unit sphere $\mathbb{S}^{d-1}$. We vary model size by increasing the number of projections $m$ from 1 to 10,000. We train each model by minimizing the reweighted objective (Equation (3)). For more details, see Appendix A.5.
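The random-features model can be sketched in a few lines of NumPy. This is an illustration under assumptions, not the paper's training code: `X` stands in for the fixed ResNet-18 features, the optimizer is plain full-batch gradient descent with a step size of $1/L$ (where $L$ upper-bounds the curvature of the weighted logistic loss), and the function names and step count are hypothetical. Rows of $W$ are drawn uniformly from the unit sphere by normalizing Gaussian vectors.

```python
import numpy as np


def random_relu_features(X, m, rng):
    """Return ReLU(W x) for each row x of X, where W in R^{m x d} has unit-sphere rows."""
    d = X.shape[1]
    W = rng.standard_normal((m, d))
    W /= np.linalg.norm(W, axis=1, keepdims=True)  # normalized Gaussians are uniform on S^{d-1}
    return np.maximum(X @ W.T, 0.0), W


def fit_reweighted_logreg(Phi, y, groups, steps=2000):
    """Unregularized logistic regression on Phi, minimizing the reweighted loss by gradient descent."""
    _, inverse, counts = np.unique(groups, return_inverse=True, return_counts=True)
    weights = len(y) / counts[inverse]   # 1 / p_g for each example, as in Equation (3)
    weights = weights / weights.sum()    # rescaling the objective by a constant keeps its minimizers
    # The Hessian is at most 0.25 * max(weights) * Phi^T Phi, so 1/L is a safe step size.
    lr = 1.0 / (0.25 * weights.max() * np.linalg.norm(Phi, 2) ** 2)
    w = np.zeros(Phi.shape[1])
    for _ in range(steps):
        margins = y * (Phi @ w)
        sig_neg = 0.5 * (1.0 - np.tanh(margins / 2.0))  # sigmoid(-margin), computed stably
        grad = -Phi.T @ (weights * y * sig_neg)
        w -= lr * grad
    return w


rng = np.random.default_rng(0)
X = rng.standard_normal((100, 2))
y = np.where(X[:, 0] > 0, 1, -1)
Phi, W = random_relu_features(X, m=200, rng=rng)
w = fit_reweighted_logreg(Phi, y, groups=y)
print(np.mean(np.sign(Phi @ w) == y))  # training accuracy
```

Model size is controlled entirely by `m`; with `m` larger than the number of training points, the features typically make the training data linearly separable, which is the overparameterized regime discussed below.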
Results. Overparameterization improves average test error across both datasets, in line with prior work (Belkin et al., 2019; Nakkiran et al., 2019) (Figure 3). However, in stark contrast, overparameterization hurts worst-group error: the best worst-group test error is achieved by an underparameterized model with non-zero training error. On CelebA, the smallest model (width 1) has 12.4% worst-group training error but comparatively low worst-group test error of 25.6%. As width increases, training error goes to zero but worst-group test error gets worse, reaching >60% for overparameterized models with zero training error. Similarly, on Waterbirds, an underparameterized model with 90 random features and worst-group training error of 17.7% obtains the best worst-group test error of 26.6%, while overparameterized models with zero training error yield worst-group test error of 42.4% at best.
In Appendix A.2, we also confirm that stronger regularization improves worst-group error but hurts average error in overparameterized models, while it has little effect on both worst-group and average error in underparameterized models. However, we focus on understanding the effect of overparameterization in the remainder of the paper.
Discussion. Why does overparameterization hurt worst-group test error? We make two observations. First, in the overparameterized regime, the smallest groups incur the highest test error (blonde males in CelebA and waterbirds on land background in Waterbirds), despite having zero training error. In other words, overparameterized models perfectly fit the minority points at training time, but seem to do so by using patterns that do not generalize. We informally refer to this behavior as “memorizing” the minority points.
Second, underparameterized models do obtain low worst-group error by learning patterns that generalize to both majority and minority groups. Therefore, overparameterized models should also be able to learn these patterns while attaining zero training error (e.g., by memorizing the training points that the underparameterized model cannot fit). Despite this, overparameterized models seem to learn patterns that generalize well on the majority but do not work on the minority (such as the spurious attributes a in Figure 2).
What makes overparameterized models memorize the minority instead of learning patterns that generalize well on both majority and minority groups? We study this question in the next two sections: in Section 4, we use simulations to understand properties of the data distribution that give rise to this trend, and in Section 5 we analyze a simplified linear setting and show how the inductive bias of models towards memorizing fewer points can lead to overparameterized models choosing to use spurious correlations.
The discussion in Section 3 suggests two properties of the training distribution that modulate the effect of overparameterization on worst-group error. Intuitively, overparameterized models should be more incentivized to use the spurious features and memorize the minority groups if (i) the proportion of the majority group, $p_{\text{maj}}$, is higher, and (ii) the ratio of how informative the spurious features are relative to the core features, $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$, is higher. In this section, we use simulations to confirm these intuitions and probe how $p_{\text{maj}}$ and $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$ affect worst-group error in overparameterized models.
Figure 3. Increasing overparameterization (i.e., increasing model size) hurts the worst-group test error even though it improves the average test error. Here, we show results for models trained on the reweighted objective for CelebA (left) and Waterbirds (right).
4.1. Synthetic experiment setup
Data distribution. We construct a synthetic dataset that replicates the empirical trends in Section 3. As in Section 2, the label $y \in \{1, -1\}$ is spuriously correlated with a spurious attribute $a \in \{1, -1\}$. We divide our training data into four groups accordingly: two majority groups with $a = y$, each of size $n_{\text{maj}}/2$, and two minority groups with $a \neq y$, each of size $n_{\text{min}}/2$. We define $n = n_{\text{maj}} + n_{\text{min}}$ as the total number of training points, and $p_{\text{maj}} = n_{\text{maj}}/n$ as the fraction of majority examples. The higher $p_{\text{maj}}$ is, the more strongly $a$ is correlated with $y$ in the training data.
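Each $(y, a)$ group has its own distribution over input features $x = [x_{\text{core}}, x_{\text{spu}}] \in \mathbb{R}^{2d}$, comprising core features $x_{\text{core}} \in \mathbb{R}^d$ generated from the label/core attribute $y$, and spurious features $x_{\text{spu}} \in \mathbb{R}^d$ generated from the spurious attribute $a$:
$$x_{\text{core}} \mid y \sim \mathcal{N}(y \mathbf{1}, \sigma^2_{\text{core}} I_d), \qquad x_{\text{spu}} \mid a \sim \mathcal{N}(a \mathbf{1}, \sigma^2_{\text{spu}} I_d). \qquad (4)$$
The core and spurious features are both noisy and encode their respective attributes at different signal-to-noise ratios. We define the spurious-core information ratio (SCR) as $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$. The higher the SCR, the more signal there is about the spurious attribute in the spurious features, relative to the signal about the label in the core features.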
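The sampling procedure of Equation (4) can be written as a short NumPy sketch. The function name `make_synthetic` is hypothetical, and the noise levels in the usage lines are illustrative assumptions rather than the settings used in the experiments; only $n = 3000$ and $d = 100$ match the setup described below.

```python
import numpy as np


def make_synthetic(n, p_maj, d, sigma2_core, sigma2_spu, rng):
    """Sample (x, y, a) with x = [x_core, x_spu] in R^{2d}, following Equation (4)."""
    n_maj = int(round(p_maj * n))
    n_min = n - n_maj
    # Two majority groups (a = y) and two minority groups (a != y), split evenly.
    spec = [(1, 1, n_maj // 2), (-1, -1, n_maj - n_maj // 2),
            (1, -1, n_min // 2), (-1, 1, n_min - n_min // 2)]
    y = np.concatenate([np.full(size, yy) for yy, _, size in spec])
    a = np.concatenate([np.full(size, aa) for _, aa, size in spec])
    x_core = y[:, None] + np.sqrt(sigma2_core) * rng.standard_normal((n, d))  # mean y * 1
    x_spu = a[:, None] + np.sqrt(sigma2_spu) * rng.standard_normal((n, d))    # mean a * 1
    return np.hstack([x_core, x_spu]), y, a


rng = np.random.default_rng(0)
X, y, a = make_synthetic(n=3000, p_maj=0.9, d=100, sigma2_core=4.0, sigma2_spu=1.0, rng=rng)
print(X.shape, np.mean(a == y))  # (3000, 200) 0.9
```

With these illustrative values the SCR is $4.0 / 1.0 = 4$; raising `sigma2_core` or lowering `sigma2_spu` makes the spurious features relatively more informative.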
Compared to the image datasets we studied in Section 3, this synthetic dataset offers two key simplifications. First, the only differences between groups stem from their differences in $(y, a)$, which isolates the effect of flipping the spurious attribute $a$. In contrast, in real datasets, groups can differ in other ways, e.g., more label noise in one group. Second, the relative difficulty of estimating $y$ versus $a$ is completely governed by the noise levels $\sigma^2_{\text{core}}$ and $\sigma^2_{\text{spu}}$. In contrast, real datasets have additional complications, e.g., estimating $y$ might involve a more complex function of the input $x$ than estimating $a$, and there might be an inductive bias towards learning a simpler model over a more complex one.
Figure 4. Overparameterization hurts worst-group test error but improves average test error on synthetic data, reproducing the trends we observe in real data.
Figure 5. Overparameterized models have poor worst-group performance on the synthetic data because they rely on spurious features. Left: removing the spurious feature (green) eliminates the detrimental effect of overparameterization. Right: overparameterized models do well on the majority groups where the spurious features match the label, but poorly on the minority groups.
In all of the experiments below, we fix the total number of training points $n$ to 3000 and set $d = 100$ (so each input $x$ has $2d = 200$ dimensions). Unless otherwise specified, we set the majority fraction $p_{\text{maj}}$ and the noise levels $\sigma^2_{\text{core}}$ and $\sigma^2_{\text{spu}}$ to values that encourage the model to use the spurious features over the core features.
Model. To avoid the complexities of optimizing neural networks, we follow the same random features setup we used for Waterbirds in Section 3: unregularized logistic regression using the reweighted objective on the random feature representation $\mathrm{ReLU}(Wx) \in \mathbb{R}^m$, where $W \in \mathbb{R}^{m \times 2d}$ is a random matrix (Mei & Montanari, 2019).
4.2. Observations on synthetic dataset
The synthetic dataset replicates the trends we observe on real datasets. Figure 4 shows how average and worst-group error change with the number of parameters/random projections m. This matches the trends we obtained on CelebA and Waterbirds in Section 3. The best worst-group test error of 28.5% is achieved by an underparameterized model, whereas highly overparameterized models achieve high worst-group test error that plateaus at around 55%. In contrast, the average test error is better for overparameterized models than for underparameterized models.
Overparameterized models use spurious features. Figure 5-Right shows that overparameterized models have high test error on minority groups ($a \neq y$) despite zero training error, but perform very well on the majority groups ($a = y$). Since the only difference between the minority and majority groups in the synthetic dataset is the relative signs of the core and spurious attributes, this suggests overparameterized models are using spurious features and simply memorizing the minority groups to get zero training error, consistent with our discussion in Section 3. In contrast, the underparameterized model has low training and test errors across all groups, suggesting that it relies mainly on core features.
These results imply that the degradation in the worst-group test error is due to the spurious features. We confirm that overparameterization no longer hurts when we “remove” the spurious features by replacing them with noise centered around zero (i.e., we replace the mean of $x_{\text{spu}}$ by 0). In this case, the best worst-group test error is now obtained by an overparameterized model, as shown in Figure 5-Left.
4.3. Distributional properties
What properties of the training data make overparameterization hurt worst-group error? We study (i) $p_{\text{maj}}$, which controls the relative size of majority to minority groups, and (ii) $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$, the relative informativeness of spurious to core features. In the synthetic dataset, overparameterization hurts worst-group test error only when both are sufficiently high. In contrast, overparameterization helps average test error regardless; see Appendix A.3.
Effect of the majority fraction $p_{\text{maj}}$. We observe that increasing $p_{\text{maj}}$, which controls the relative size of the majority versus minority groups, makes overparameterization hurt worst-group error more (Figure 6). When the groups are perfectly balanced with $p_{\text{maj}} = 1/2$, overparameterization no longer hurts the worst-group test error, with overparameterized models achieving better worst-group test error than all underparameterized models. This suggests that group imbalance can be a key factor inducing the detrimental effect of overparameterization.
Effect of the spurious-core information ratio $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$. Next, we characterize the effect of $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$, which measures the relative informativeness of the spurious versus core features. A high $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$ means that the spurious features are more informative. We vary $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$ by changing $\sigma^2_{\text{spu}}$ while keeping $\sigma^2_{\text{core}}$ fixed, since this does not change the best possible worst-group test error (with a model that uses only the core features). Figure 6 shows that the higher $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$, the more overparameterization hurts. As $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$ increases, the spurious features become more informative, and overparameterized models rely more on them than on the core features; underparameterized models outperform overparameterized models only for sufficiently large $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$. Note that increasing $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$ does not significantly affect the worst-group test error in the underparameterized regime, since the core features are unaffected. In contrast, increasing the majority fraction $p_{\text{maj}}$ hurts the worst-group test error in both underparameterized and overparameterized models.
Figure 6. The higher the majority fraction $p_{\text{maj}}$ and the spurious-core information ratio $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$, the more overparameterization hurts the worst-group test error. With sufficiently low values of either, overparameterization switches to helping worst-group test error.
4.4. An intuitive story
We return to the question of what makes overparameterized models memorize the minority instead of learning patterns that generalize on both majority and minority groups. The simulation results above show that of all overparameterized models that achieve zero training error, the inductive bias of the model class and training algorithm favors models that use spurious features which generalize only for the majority groups, instead of learning to use core features that also generalize well on the minority groups.
What is the nature of this inductive bias? Consider a model that predicts the label $y$ by returning its estimate of the spurious attribute $a$, taking advantage of the fact that $y$ and $a$ are correlated in the training data. To achieve zero training error, it will need to memorize the points in the minority group, e.g., by exploiting variations due to noise in the features $x$. On the other hand, consider a model that predicts $y$ by returning a direct estimate of $y$ based on the core features $x_{\text{core}}$. Because $x_{\text{core}}$ provides a noisier estimate of $y$ than $x_{\text{spu}}$ does for $a$, this model will need to memorize all points for which $x_{\text{core}}$ gives an inaccurate prediction of $y$ due to noise. Since the estimators of the core and spurious attributes are equally easy to learn, the main difference between these two models is the number of examples to be memorized.
We therefore hypothesize that the inductive bias favors memorizing as few points as possible. This is consistent with the results above: the model uses $x_{\text{spu}}$ and memorizes the minority points only when the fraction of minority points is small (high majority fraction $p_{\text{maj}}$). Similarly, the model uses $x_{\text{spu}}$ over $x_{\text{core}}$ to fit the majority points only when the spurious features are less noisy (high $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$) and therefore require less memorization to obtain zero training error than the core features. In the next section, we make this intuition formal by analyzing a related but simpler linear setting.
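The counting argument above can be made concrete with a back-of-the-envelope calculation. This sketch assumes a deliberately simple estimator that does not appear in the paper: each attribute is estimated by the sign of the average of its $d$ feature coordinates, which under Equation (4) is wrong with probability $\Phi(-\sqrt{d}/\sigma)$. It then counts the expected number of training points that each strategy must memorize; the function names are hypothetical.

```python
from math import erf, sqrt


def normal_cdf(z):
    """Standard Gaussian CDF Phi(z)."""
    return 0.5 * (1.0 + erf(z / sqrt(2.0)))


def expected_memorized(n, p_maj, d, sigma2_core, sigma2_spu):
    """Expected number of points each strategy must memorize (simplified estimator)."""
    n_maj, n_min = p_maj * n, (1.0 - p_maj) * n
    err_spu = normal_cdf(-sqrt(d / sigma2_spu))    # P(the estimate of a is wrong)
    err_core = normal_cdf(-sqrt(d / sigma2_core))  # P(the estimate of y is wrong)
    # Predicting y via the estimate of a: majority points fail when a is mis-estimated,
    # while minority points (a != y) fail unless a happens to be mis-estimated.
    use_spu = n_maj * err_spu + n_min * (1.0 - err_spu)
    # Predicting y via the estimate of y: points in any group fail when y is mis-estimated.
    use_core = n * err_core
    return use_spu, use_core


for p_maj in (0.5, 0.9):
    spu, core = expected_memorized(n=3000, p_maj=p_maj, d=100, sigma2_core=100.0, sigma2_spu=1.0)
    print(f"p_maj={p_maj}: spurious {spu:.0f}, core {core:.0f}")
# p_maj=0.5: spurious 1500, core 476
# p_maj=0.9: spurious 300, core 476
```

Under these illustrative noise levels, using $x_{\text{spu}}$ requires less memorization only when $p_{\text{maj}}$ is high; lowering $\sigma^2_{\text{core}}$ (and hence the SCR) makes the core-feature strategy cheaper instead.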
In this section, we show how the inductive bias against memorization leads to overparameterization exacerbating spurious correlations. Our analysis explicates the effect of the inductive bias and the importance of the data parameters discussed in Section 4.
The synthetic setting discussed in Section 4 is difficult to analyze because of the non-linear random projections, so we introduce a linear explicit-memorization setting that allows us to precisely define the concept of memorization. For clarity, we refer to the previous synthetic setting in Section 4 as the implicit-memorization setting. In Appendix A.4, we show empirically that models in these two settings behave similarly in the overparameterized regime, though they differ in the underparameterized regime.
In the previous implicit-memorization setting, we varied model size and memorization capacity by varying the number of random projections of the input. In the new explicit-memorization setting, we instead use linear models that act directly on the input and introduce explicit “noise features” that can be used to memorize. We vary the memorization capacity by varying the number of explicit noise features.
5.1. Explicit-memorization setup
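Training data. We consider input features $x = [x_{\text{core}}, x_{\text{spu}}, x_{\text{noise}}] \in \mathbb{R}^{2+N}$, where the core feature $x_{\text{core}} \in \mathbb{R}$ and the spurious feature $x_{\text{spu}} \in \mathbb{R}$ are scalars. As in the implicit-memorization setup, they are generated based on the label and the spurious attribute, respectively:
$$x_{\text{core}} \mid y \sim \mathcal{N}(y, \sigma^2_{\text{core}}), \qquad x_{\text{spu}} \mid a \sim \mathcal{N}(a, \sigma^2_{\text{spu}}).$$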
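The “noise” features $x_{\text{noise}} \in \mathbb{R}^N$ are generated as
$$x_{\text{noise}} \sim \mathcal{N}\left(0, \frac{\sigma^2_{\text{noise}}}{N} I_N\right),$$
where $\sigma^2_{\text{noise}}$ is a constant. The scaling by $1/N$ ensures that for large $N$, the squared norm of the noise vectors, $\|x_{\text{noise}}\|_2^2$, is approximately $\sigma^2_{\text{noise}}$ with high probability. Intuitively, when $N$ is large, overparameterized models can use $x_{\text{noise}}$ to fit a training point $x$ without affecting its predictions on other points, thereby memorizing $x$. We formalize this notion of memorization later in Section 5.3.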
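The following NumPy sketch samples the explicit-memorization training data and checks the two properties used above. The function name `make_explicit` and all parameter values are illustrative assumptions. The sketch confirms that the noise vectors have squared norm close to $\sigma^2_{\text{noise}}$ and are nearly orthogonal, so putting weight on one training point's noise vector shifts that point's score while barely moving any other point.

```python
import numpy as np


def make_explicit(n_maj, n_min, N, sigma2_core, sigma2_spu, sigma2_noise, rng):
    """Sample x = [x_core, x_spu, x_noise] in R^{2+N} for the four (y, a) groups."""
    spec = [(1, 1, n_maj // 2), (-1, -1, n_maj - n_maj // 2),
            (1, -1, n_min // 2), (-1, 1, n_min - n_min // 2)]
    y = np.concatenate([np.full(size, yy) for yy, _, size in spec])
    a = np.concatenate([np.full(size, aa) for _, aa, size in spec])
    n = len(y)
    x_core = y + np.sqrt(sigma2_core) * rng.standard_normal(n)
    x_spu = a + np.sqrt(sigma2_spu) * rng.standard_normal(n)
    x_noise = np.sqrt(sigma2_noise / N) * rng.standard_normal((n, N))  # per-coordinate variance sigma2_noise / N
    return np.column_stack([x_core, x_spu, x_noise]), y, a


rng = np.random.default_rng(0)
X, y, a = make_explicit(n_maj=36, n_min=4, N=20000, sigma2_core=1.0,
                        sigma2_spu=0.01, sigma2_noise=1.0, rng=rng)
noise = X[:, 2:]
gram = noise @ noise.T
print(np.diag(gram).min(), np.diag(gram).max())  # both close to sigma2_noise = 1
off_diag = gram[~np.eye(len(y), dtype=bool)]
print(np.abs(off_diag).max())                     # small: nearly orthogonal noise vectors

# Putting weight alpha on point 0's noise vector mostly changes point 0's score.
alpha = 5.0
w = np.zeros(X.shape[1])
w[2:] = alpha * noise[0]
scores = X @ w
print(scores[0], np.abs(scores[1:]).max())
```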
As before, the training data is composed of four groups, each corresponding to a combination of the label $y \in \{1, -1\}$ and the spurious attribute $a \in \{1, -1\}$: two majority groups with $a = y$, each of size $n_{\text{maj}}/2$, and two minority groups with $a \neq y$, each of size $n_{\text{min}}/2$. Combined, there are $n = n_{\text{maj}} + n_{\text{min}}$ training examples.
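Model. We study unregularized logistic regression on the input features $x$. As before, we consider the reweighted estimator $\hat{w}_{\text{rw}}$. When the training data is linearly separable, the minimizer of the unregularized logistic loss on the training data is not well-defined. We therefore define $\hat{w}_{\text{rw}}$ in terms of the sequence of $\ell_2$-regularized models
$$\hat{w}_{\text{rw}, \lambda} := \arg\min_{w} \; \sum_{i=1}^{n} \frac{1}{p_{g^{(i)}}} \, \ell(w; (x^{(i)}, y^{(i)})) + \frac{\lambda}{2} \|w\|_2^2, \qquad (5)$$
where $\ell(w; (x, y)) = \log(1 + \exp(-y\, w^\top x))$ is the logistic loss and $p_g$ is the fraction of training examples in group $g$. Since scaling a model does not affect its 0-1 error, we define $\hat{w}_{\text{rw}}$ as the limit of this sequence, scaled to unit norm, as the regularization strength $\lambda \to 0^{+}$:
$$\hat{w}_{\text{rw}} := \lim_{\lambda \to 0^{+}} \frac{\hat{w}_{\text{rw}, \lambda}}{\|\hat{w}_{\text{rw}, \lambda}\|_2}.$$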
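In the underparameterized regime, the training data is not linearly separable, and $\hat{w}_{\text{rw}}$ is simply the minimizer of the reweighted logistic loss (Equation (3)), scaled to unit norm. In the overparameterized regime, where $N$ is large relative to $n$, the training data is linearly separable, and Rosset et al. (2004) showed that $\hat{w}_{\text{rw}}$ is the max-margin classifier
$$\hat{w}_{\text{mm}} := \arg\max_{\|w\|_2 = 1} \; \min_{i} \; y^{(i)} w^\top x^{(i)}.$$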
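The equivalence holds regardless of the reweighting by $1/p_g$: if we define the ERM estimator $\hat{w}_{\text{erm}}$ analogously to (5) without the reweighting, it is also equal to $\hat{w}_{\text{mm}}$. We will therefore analyze $\hat{w}_{\text{mm}}$ in the overparameterized regime, since it subsumes both $\hat{w}_{\text{rw}}$ and $\hat{w}_{\text{erm}}$.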
We also note that if we use gradient descent to directly optimize the unregularized logistic regression objective (either reweighted or not), the resulting solution after scaling to unit norm also converges to $\hat{w}_{\text{mm}}$ as the number of gradient steps goes to infinity (Soudry et al., 2018).
5.2. Analysis of worst-group error
We now state our main analytical result: in the explicit-memorization setting, the worst-group test error of a sufficiently overparameterized model is greater than 1/2 (worse than random) under certain settings of $p_{\text{maj}}$, $\sigma^2_{\text{core}}$, $\sigma^2_{\text{spu}}$, and $\sigma^2_{\text{noise}}$. In contrast, underparameterized models attain reasonable worst-group error even under such a setting.
Theorem 1. For any
,
and
, there exists $N_0$ such that for all $N > N_0$ (overparameterized regime), with high probability over draws of the data,
$$\mathrm{Err}_{\text{wg}}(\hat{w}_{\text{mm}}) > 1/2,$$
where $\hat{w}_{\text{mm}}$ is the max-margin classifier.
However, for N = 0 (underparameterized regime), with , and
, and in the asymptotic regime with
where $\hat{w}_{\text{rw}}$ minimizes the reweighted logistic loss.
The result in the overparameterized regime applies to the max-margin classifier $\hat{w}_{\text{mm}}$, which, as discussed above, subsumes both $\hat{w}_{\text{rw}}$ and $\hat{w}_{\text{erm}}$ when the data is linearly separable. The proof of Theorem 1 appears in Appendix B.
The conditions on $\sigma^2_{\text{core}}$ and $\sigma^2_{\text{spu}}$ in Theorem 1 above imply a high spurious-core information ratio $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$. Theorem 1 therefore provides a setting where high $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$ and high $p_{\text{maj}}$ provably make overparameterized models obtain high worst-group error, matching the trends we observed upon varying $p_{\text{maj}}$ and $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$ in the implicit-memorization setting (Figure 6). Furthermore, underparameterized models obtain reasonable worst-group error despite these conditions, mirroring the observations in earlier sections.
5.3. Overparameterization and memorization
We now sketch the key ideas in the proof of Theorem 1 (full proof in Appendix B), focusing first on the overparameterized regime. We start by establishing an inductive bias towards learning the minimum-norm model that fits the training data. We then define memorization and show how the minimum-norm inductive bias translates into a bias against memorization. Finally, we illustrate how the bias against memorization leads to learning the spurious feature and suffering high worst-group error.
Minimum-norm inductive bias. Define a separator as any model $w$ that correctly classifies all of the training points $(x, y)$ with margin $y\, w^\top x \ge 1$. Then from standard duality arguments, $\hat{w}_{\text{mm}}$ can be rewritten as $\hat{w}_{\text{mm}} = \hat{w}_{\text{mn}}/\|\hat{w}_{\text{mn}}\|_2$, the scaled version of the minimum-norm separator
$$\hat{w}_{\text{mn}} := \arg\min_{w} \; \|w\|_2 \quad \text{s.t.} \quad y^{(i)} w^\top x^{(i)} \ge 1 \ \text{ for all } i. \qquad (9)$$
Since scaling does not affect the 0-1 test error, it suffices to analyze $\hat{w}_{\text{mn}}$. Equation (9) shows that out of the set of all separators (which all perfectly fit the training data), the inductive bias favors the separator with the minimum norm. We now discuss how this minimum-norm inductive bias favors less memorization.
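Memorization. For convenience, we denote the three components of a model $w$ as $w = [w_{\text{core}}, w_{\text{spu}}, w_{\text{noise}}]$, where $w_{\text{core}} \in \mathbb{R}$, $w_{\text{spu}} \in \mathbb{R}$, and $w_{\text{noise}} \in \mathbb{R}^N$. By the representer theorem, we can decompose $w_{\text{noise}}$ as follows:
$$w_{\text{noise}} = \sum_{i=1}^{n} \alpha_i \, x^{(i)}_{\text{noise}}. \qquad (11)$$
In the overparameterized regime when $N$ is large, a model can “memorize” a training point $x^{(i)}$ via $w_{\text{noise}}$, in particular by putting a large weight $\alpha_i$ in the direction of $x^{(i)}_{\text{noise}}$ (Equation (11)):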
Definition 1 (-memorization). A model w memorizes a point
for some constant
Because the noise vectors of the training points (high-dimensional Gaussians) are nearly orthogonal for large $N$, the component $\alpha_i x^{(i)}_{\text{noise}}$ affects the prediction on $x^{(i)}$, but not on any other training or test points.
This ability to memorize plays a crucial role in making overparameterized models obtain high worst-group error. Intuitively, the minimum-norm inductive bias favors less memorization in overparameterized models. Roughly speaking, models that memorize more have larger weights $\alpha_i$ on the noise vectors $x^{(i)}_{\text{noise}}$. Since these noise vectors are nearly orthogonal and have similar norm, this translates into a larger norm $\|w\|_2$.
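Comparing using $x_{\text{spu}}$ versus using $x_{\text{core}}$. To illustrate how the inductive bias against memorization leads to high worst-group error, we consider two extreme sets of separators: (i) ones that use the spurious feature but not the core feature, denoted by $\mathcal{W}_{\text{spu}}$ (separators with $w_{\text{core}} = 0$), and (ii) ones that use the core feature but not the spurious feature, denoted by $\mathcal{W}_{\text{core}}$ (separators with $w_{\text{spu}} = 0$).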
In scenario (i), using the spurious feature alone allows models to fit the majority groups very well. Thus, models that use $x_{\text{spu}}$ only need to memorize the minority points. In Proposition 1, we construct a separator $\hat{w}_{\text{spu}} \in \mathcal{W}_{\text{spu}}$ and show that its norm only scales with the number of minority points $n_{\text{min}}$.
Conversely, in scenario (ii), using the core feature alone allows models to fit all groups equally well. However, when $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$ is high, $x_{\text{core}}$ is noisier than $x_{\text{spu}}$, so models that use $x_{\text{core}}$ still need to memorize a constant fraction of all the training points. In Proposition 2, we show that the norms of all separators $w \in \mathcal{W}_{\text{core}}$ are lower bounded by a quantity linear in the total number of training points $n$.
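When the majority fraction $p_{\text{maj}}$ is large enough that $n_{\text{min}}$ is small relative to $n$, the separator $\hat{w}_{\text{spu}}$ that uses $x_{\text{spu}}$ will have a lower norm than any separator $w \in \mathcal{W}_{\text{core}}$ that uses $x_{\text{core}}$. Since the inductive bias favors the minimum-norm separator, it prefers a separator like $\hat{w}_{\text{spu}}$, which memorizes the minority points and suffers high worst-group error, over any $w \in \mathcal{W}_{\text{core}}$.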
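Proposition 1 (Norm of models using the spurious feature). When $p_{\text{maj}}$, $\sigma^2_{\text{core}}$, $\sigma^2_{\text{spu}}$, and $\sigma^2_{\text{noise}}$ satisfy the conditions in Theorem 1, there exists $N_0$ such that for all $N > N_0$, with high probability, there exists a separator $\hat{w}_{\text{spu}} \in \mathcal{W}_{\text{spu}}$ with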
for some constants
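Proof sketch. To simplify exposition in this sketch, suppose that the noise vectors $x^{(i)}_{\text{noise}}$ are orthogonal and have constant norm $\|x^{(i)}_{\text{noise}}\|_2^2 = \sigma^2_{\text{noise}}$. We construct a separator $\hat{w}_{\text{spu}}$ that does not use the core feature $x_{\text{core}}$ as follows. Set $w_{\text{spu}}$ to some large enough constant. This is sufficient to satisfy the margin condition on the majority points: since $\sigma^2_{\text{spu}}$ is very small, w.h.p. all majority training points satisfy
$$y^{(i)} w_{\text{spu}} x^{(i)}_{\text{spu}} \ge 1.$$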
However, for the minority training points, the spurious attribute $a$ does not match the label $y$, and in order to satisfy the margin condition with a positive $w_{\text{spu}}$, these minority points have to be memorized. Since $\sigma^2_{\text{spu}}$ is very small, the decrease in the margin due to $w_{\text{spu}} x_{\text{spu}}$ is at most some constant $c$ w.h.p., where $c$ depends on $w_{\text{spu}}$. To satisfy the margin condition, it thus suffices to set $\alpha_i = y^{(i)}(1 + c)/\sigma^2_{\text{noise}}$ for each minority point $x^{(i)}$, and the bound on the norm follows. The full proof appears in Section B.2.6.
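The construction in this proof sketch can be checked numerically. The sketch below is an illustration under assumed parameter values (a tiny $\sigma^2_{\text{spu}}$ and a large $N$), using real rather than exactly orthogonal noise vectors. It sets $w_{\text{core}} = 0$ and a fixed $w_{\text{spu}}$, puts nonzero $\alpha_i$ only on minority points, and adds a small extra `slack` to absorb the cross terms between noise vectors. The function names and the constants `w_spu` and `slack` are hypothetical choices, not values from the paper. The check confirms that every training point meets the margin condition and that the squared norm grows with $n_{\text{min}}$.

```python
import numpy as np


def spurious_separator(X, y, is_min, w_spu=3.0, slack=2.0):
    """Build w = [0, w_spu, sum_i alpha_i x_noise^(i)] that memorizes only minority points."""
    x_spu, noise = X[:, 1], X[:, 2:]
    # c bounds how much w_spu * x_spu decreases the margin on the minority points.
    c = np.max(-y[is_min] * w_spu * x_spu[is_min])
    sq_norms = np.sum(noise ** 2, axis=1)
    alpha = np.where(is_min, y * (1.0 + c + slack) / sq_norms, 0.0)
    w = np.zeros(X.shape[1])
    w[1] = w_spu             # w_core stays 0, so w lies in W_spu
    w[2:] = alpha @ noise    # w_noise = sum_i alpha_i x_noise^(i), as in Equation (11)
    return w


def sample(n_maj, n_min, N, rng, sigma2_core=1.0, sigma2_spu=0.001, sigma2_noise=1.0):
    """Explicit-memorization data; returns X, labels y, and a minority-group mask."""
    y = np.concatenate([np.ones(n_maj // 2), -np.ones(n_maj - n_maj // 2),
                        np.ones(n_min // 2), -np.ones(n_min - n_min // 2)])
    a = np.concatenate([y[:n_maj], -y[n_maj:]])  # a = y on majority, a = -y on minority
    n = len(y)
    X = np.column_stack([y + np.sqrt(sigma2_core) * rng.standard_normal(n),
                         a + np.sqrt(sigma2_spu) * rng.standard_normal(n),
                         np.sqrt(sigma2_noise / N) * rng.standard_normal((n, N))])
    return X, y, a != y


rng = np.random.default_rng(0)
results = {}
for n_min in (10, 40):
    X, y, is_min = sample(n_maj=200 - n_min, n_min=n_min, N=20000, rng=rng)
    w = spurious_separator(X, y, is_min)
    margins = y * (X @ w)
    results[n_min] = (margins.min(), float(w @ w), w[0])
    print(n_min, margins.min() >= 1.0, round(float(w @ w), 1))
```

Quadrupling $n_{\text{min}}$ roughly quadruples $\|\hat{w}_{\text{spu}}\|_2^2$ in this sketch, since each memorized point contributes about $\alpha_i^2 \sigma^2_{\text{noise}}$.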
Proposition 2 (Norm of models using the core feature). When $p_{\text{maj}}$, $\sigma^2_{\text{core}}$, $\sigma^2_{\text{spu}}$, and $\sigma^2_{\text{noise}}$ satisfy the conditions in Theorem 1 and
, there exists $N_0$ such that for all $N > N_0$, with high probability, all separators $w \in \mathcal{W}_{\text{core}}$ satisfy
for some constant
Proof sketch. Any model $w \in \mathcal{W}_{\text{core}}$ has $w_{\text{spu}} = 0$ by definition. We show that a constant fraction of training points have to be memorized (in the sense of Definition 1) in order to satisfy the margin condition. We do so by first showing that the probability that a training point $x$ satisfies the margin condition without being memorized cannot be too large. For simplicity, suppose again that the noise vectors $x^{(i)}_{\text{noise}}$ are orthogonal and have constant norm $\|x^{(i)}_{\text{noise}}\|_2^2 = \sigma^2_{\text{noise}}$. Then this probability is
for small
, where $\Phi$ is the Gaussian CDF. Hence, in expectation, at least a constant fraction of points from the training distribution need to be memorized in order for $w$
to satisfy the margin condition. With high probability, this is also true on the training set consisting of n points (via the DKW inequality) and the bound on the norm follows. The full proof appears in Section B.2.7.
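A Monte Carlo sketch can illustrate why a constant fraction of points must be memorized. It keeps the orthogonality simplification above and adds a further assumption of its own: a point that is not memorized gets no help from the noise features, which corresponds to the limit of a small memorization constant. For $w \in \mathcal{W}_{\text{core}}$, the margin of such a point reduces to $y\, w_{\text{core}} x_{\text{core}}$ with $y\, x_{\text{core}} \sim \mathcal{N}(1, \sigma^2_{\text{core}})$. Under this simplification, one can check that the fraction of points meeting the margin condition stays below $\Phi(1/\sigma_{\text{core}})$ for every $w_{\text{core}}$, approaching it only as $w_{\text{core}} \to \infty$. The function name and the grid of $w_{\text{core}}$ values are hypothetical.

```python
from math import erf, sqrt

import numpy as np


def frac_fit_without_memorization(w_core, sigma2_core, n_samples, rng):
    """Fraction of points with y * w_core * x_core >= 1, where y * x_core ~ N(1, sigma2_core)."""
    z = 1.0 + np.sqrt(sigma2_core) * rng.standard_normal(n_samples)
    return np.mean(w_core * z >= 1.0)


sigma2_core = 1.0
bound = 0.5 * (1.0 + erf((1.0 / sqrt(sigma2_core)) / sqrt(2.0)))  # Phi(1 / sigma_core)
rng = np.random.default_rng(0)
fracs = {w: frac_fit_without_memorization(w, sigma2_core, 200_000, rng)
         for w in (-5.0, 0.5, 1.0, 2.0, 10.0)}
for w, f in fracs.items():
    print(f"w_core={w:>6}: fits {f:.3f} of points (bound {bound:.3f})")
```

With $\sigma^2_{\text{core}} = 1$, at least roughly $1 - \Phi(1) \approx 16\%$ of points cannot meet the margin through $x_{\text{core}}$ alone in this simplified model, no matter how $w_{\text{core}}$ is chosen.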
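In the full proof of Theorem 1 in Appendix B, we generalize the above ideas to consider all separators instead of just the separators in $\mathcal{W}_{\text{spu}} \cup \mathcal{W}_{\text{core}}$. Note the importance of both $p_{\text{maj}}$ and $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$: when $p_{\text{maj}}$ is high, models that use $x_{\text{spu}}$ only need to memorize the minority groups (Proposition 1), and when $\sigma^2_{\text{core}}/\sigma^2_{\text{spu}}$ is also high, these models end up memorizing fewer points than models that use $x_{\text{core}}$, which have to memorize a constant fraction of the entire training set (Proposition 2).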
Our results above highlight the role of the majority fraction in determining whether overparameterization hurts worst-group test error. When this fraction is large, the inductive bias favors using spurious features because doing so entails memorizing only a relatively small number of minority points, while the alternative of using core features requires memorizing a large number of majority points. This suggests that reducing the memorization cost of using core features by directly removing some majority points could induce overparameterized models to obtain low worst-group error.
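As a back-of-the-envelope illustration of this trade-off, the snippet below compares how many points each strategy must memorize: a spurious-feature model memorizes the (1 − p_maj)·n minority points, while a core-feature model memorizes a constant fraction κ·n of all points. The names `p_maj` and `kappa`, and the use of a raw count as a proxy for the norm, are simplifying assumptions; Propositions 1 and 2 compare norms, not counts.

```python
def memorization_counts(n, p_maj, kappa):
    """Approximate number of points memorized by each strategy (count proxy for norm).

    p_maj: fraction of training points in majority groups (hypothetical name).
    kappa: constant fraction of all points a core-feature model must memorize (assumed).
    """
    spurious_cost = (1.0 - p_maj) * n  # memorize only the minority points
    core_cost = kappa * n              # memorize a constant fraction of all points
    return spurious_cost, core_cost


def prefers_spurious(n, p_maj, kappa):
    """True if the spurious-feature strategy memorizes fewer points."""
    spurious_cost, core_cost = memorization_counts(n, p_maj, kappa)
    return spurious_cost < core_cost


# A skewed dataset favors the spurious feature; a group-balanced one does not.
print(prefers_spurious(10_000, p_maj=0.95, kappa=0.3))  # True
print(prefers_spurious(10_000, p_maj=0.5, kappa=0.3))   # False
```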
Here, we show that this approach of subsampling the majority group achieves good worst-group test error on the datasets studied above. Subsampling creates a new group-balanced dataset by randomly removing training points in all other groups to match the number of points from the smallest group (Japkowicz & Stephen, 2002; Haixiang et al., 2017; Buda et al., 2018). We then train a model to minimize the average loss on this subsampled dataset. For a precise description, see Appendix A.6.
Figure 7 shows that overparameterized models trained via subsampling (Equation 15) obtain low worst-group error on the CelebA, Waterbirds, and synthetic (implicit-memorization) datasets. Across all three datasets, training via subsampling makes increasing overparameterization help both average and worst-group test error. Moreover, overparameterized models trained on subsampled data are comparable to or better than the best models trained on the full dataset (i.e., underparameterized models trained with reweighting).
Figure 7. Overparameterization helps worst-group test error when training via subsampling, which involves creating a group-balanced dataset by reducing the number of majority points and minimizing average training loss on the new dataset.
Subsampling seems wasteful since it throws away a large fraction of the training data: we only use 3.4% of the full training data for CelebA, 4.6% for Waterbirds, and 10% for the synthetic dataset. However, the results above show that subsampling in overparameterized models matches or outperforms reweighting with underparameterized models. For example, on CelebA, an overparameterized model trained via subsampling obtains 11.1% average test and 15.1% worst-group test error, whereas an underparameterized model trained with reweighting obtains 11.3% average and 25.6% worst-group test error.
Subsampling vs. reweighting. Both subsampling and reweighting artificially balance the groups in the training data, and previous work on imbalanced datasets has concluded that reweighting is typically at least as effective as subsampling (Buda et al., 2018). However, we find a clear difference between subsampling and reweighting in the overparameterized regime: increasing overparameterization with reweighting increases worst-group error, while doing so with subsampling decreases worst-group error. The intuition developed in Sections 4 and 5 sheds some light on this difference. Consider an overparameterized model: as in Section 5.1, reweighting does not change the learned model, which is the max-margin classifier. However, subsampling reduces the majority fraction. Recall that the inductive bias favors spurious features when the alternative of using core features requires memorizing a large number of training points. By reducing the majority fraction, we reduce this memorization cost associated with core features, thereby inducing the model to use core features and achieve low worst-group test error.
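The effect of subsampling on the majority fraction is simple arithmetic. The sketch below uses hypothetical group counts (not the CelebA or Waterbirds statistics) for a binary task with two majority groups, where the spurious attribute matches the label, and two minority groups. The helper names are illustrative.

```python
def majority_fraction(group_counts, majority_groups):
    """Fraction of training points that fall in the majority groups."""
    total = sum(group_counts.values())
    return sum(group_counts[g] for g in majority_groups) / total


def balanced_counts(group_counts):
    """Group sizes after subsampling every group down to the smallest one."""
    smallest = min(group_counts.values())
    return {g: smallest for g in group_counts}


# Hypothetical counts keyed by (label y, spurious attribute a).
counts = {(1, 1): 4500, (-1, -1): 4500, (1, -1): 500, (-1, 1): 500}
majority = [(1, 1), (-1, -1)]
print(majority_fraction(counts, majority))                   # 0.9
print(majority_fraction(balanced_counts(counts), majority))  # 0.5
```

Reweighting leaves the training points, and hence the majority fraction, unchanged; only subsampling moves it down to 1/2 in this four-group example.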
The effect of overparameterization. The effect of overparameterization on average test error has been widely studied. In what is commonly referred to as “double descent”, increasing model size beyond the point of zero training error decreases test error, despite conventional wisdom that overfitting should increase test error. This behavior has been observed empirically (Belkin et al., 2019; Opper, 1995; Advani & Saxe, 2017; Nakkiran et al., 2019) and shown analytically in high-dimensional regression (Hastie et al., 2019; Bartlett et al., 2019; Mei & Montanari, 2019). These works focus on average test error, and their conclusions are consistent with our findings on average test error. However, our focus is on worst-group test error, particularly when the groups are defined based on spurious attributes, and in this paper we establish that worst-group test error can behave quite differently from average test error.
Increasing overparameterization can actually improve model robustness to some types of distributional shifts (Hendrycks et al., 2019; Hendrycks & Dietterich, 2019; Yang et al., 2020). In this light, our results show that the effect of overparameterization on model robustness can depend heavily on the dataset (e.g., properties such as the majority fraction and the spurious-core ratio), the type of distributional shift, and the training procedure.
Worst-group error. Prior work on improving worst-group error has focused on the underparameterized regime, with methods based on weighting/sampling (Shimodaira, 2000; Japkowicz & Stephen, 2002; Buda et al., 2018; Cui et al., 2019), distributionally robust optimization (DRO) (Ben-Tal et al., 2013; Namkoong & Duchi, 2017; Oren et al., 2019), and fair algorithms (Dwork et al., 2012; Hardt et al., 2016; Kleinberg et al., 2017). Our focus is on the overparameterized, zero-training-error regime; here, previous methods based on reweighting and DRO are ineffective (Wen et al., 2014; Byrd & Lipton, 2019; Sagawa et al., 2020). As mentioned in Section 1, Sagawa et al. (2020) demonstrated that stronger $\ell_2$-regularization can improve worst-group error on neural networks (when coupled with reweighting or group DRO). Similarly, Cao et al. (2019) showed that data-dependent regularization can improve error on rare labels. While these works focus on developing methods to improve worst-group error, our focus is on understanding the mechanisms by which overparameterization hurts worst-group error.
Our work shows that overparameterization hurts worst-group error on real datasets that contain spurious correlations. We studied the implicit- and explicit-memorization settings to provide a potential story for why this might occur: there can be an inductive bias towards solutions that do not need to memorize as many training points, and this can favor models that exploit the spurious correlations.
However, our synthetic settings make several simplifying assumptions, e.g., they suppose that the model prefers the spurious feature because it is less noisy than the core feature. This assumption need not always apply, and different assumptions might also lead to overparameterization exacerbating spurious correlations. For example, there might exist a true classifier based on the core features which has high accuracy but which is relatively more complex (e.g., high parameter norm) and therefore not favored by the training procedure. Studying the effect of overparameterization in settings such as those is important future work.
We also observed that subsampling allows overparameterized models to achieve low average and worst-group test error, despite eliminating a large fraction of training examples. In contrast, when using the full training data, only underparameterized models attain low worst-group test error under our current training methods. These observations call for future work to develop methods that can exploit both the statistical information in the full training data as well as the expressivity of overparameterized models, so as to attain good worst-group and average test error.
We are grateful to Yair Carmon, John Duchi, Tatsunori Hashimoto, Ananya Kumar, Yiping Lu, Tengyu Ma, and Jacob Steinhardt for helpful discussions and suggestions. SS was supported by a Stanford Graduate Fellowship, AR was supported by a Google PhD Fellowship and Open Philanthropy Project AI Fellowship, and PWK was supported by the Facebook Fellowship Program.
Reproducibility
Code is available at https://github.com/ssagawa/overparam_spur_corr.
Advani, M. S. and Saxe, A. M. High-dimensional dynamics of generalization error in neural networks. arXiv preprint arXiv:1710.03667, 2017.
Bartlett, P. L., Long, P. M., Lugosi, G., and Tsigler, A. Benign overfitting in linear regression. arXiv, 2019.
Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32), 2019.
Ben-Tal, A., den Hertog, D., Waegenaere, A. D., Melenberg, B., and Rennen, G. Robust solutions of optimization problems affected by uncertain probabilities. Management Science, 59:341–357, 2013.
Blodgett, S. L., Green, L., and O’Connor, B. Demographic dialectal variation in social media: A case study of African-American English. In Empirical Methods in Natural Language Processing (EMNLP), pp. 1119–1130, 2016.
Buda, M., Maki, A., and Mazurowski, M. A. A systematic study of the class imbalance problem in convolutional neural networks. Neural Networks, 106:249–259, 2018.
Buolamwini, J. and Gebru, T. Gender shades: Intersectional accuracy disparities in commercial gender classification. In Conference on Fairness, Accountability and Transparency, pp. 77–91, 2018.
Byrd, J. and Lipton, Z. What is the effect of importance weighting in deep learning? In International Conference on Machine Learning (ICML), pp. 872–881, 2019.
Cao, K., Wei, C., Gaidon, A., Arechiga, N., and Ma, T. Learning imbalanced datasets with label-distribution-aware margin loss. In Advances in Neural Information Processing Systems (NeurIPS), 2019.
Cui, Y., Jia, M., Lin, T., Song, Y., and Belongie, S. Class-balanced loss based on effective number of samples. In Computer Vision and Pattern Recognition (CVPR), pp. 9268–9277, 2019.
Dwork, C., Hardt, M., Pitassi, T., Reingold, O., and Zemel, R. Fairness through awareness. In Innovations in Theoretical Computer Science (ITCS), pp. 214–226, 2012.
Haixiang, G., Yijing, L., Shang, J., Mingyun, G., Yuanyue, H., and Bing, G. Learning from class-imbalanced data: Review of methods and applications. Expert Systems with Applications, 73:220–239, 2017.
Hardt, M., Price, E., and Srebro, N. Equality of opportunity in supervised learning. In Advances in Neural Information Processing Systems (NeurIPS), pp. 3315–3323, 2016.
Hashimoto, T. B., Srivastava, M., Namkoong, H., and Liang, P. Fairness without demographics in repeated loss minimization. In International Conference on Machine Learning (ICML), 2018.
Hastie, T., Montanari, A., Rosset, S., and Tibshirani, R. J. Surprises in high-dimensional ridgeless least squares interpolation. arXiv preprint arXiv:1903.08560, 2019.
He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Computer Vision and Pattern Recognition (CVPR), 2016.
Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.
Hendrycks, D., Zhao, K., Basart, S., Steinhardt, J., and Song, D. Natural adversarial examples. arXiv preprint arXiv:1907.07174, 2019.
Hu, W., Niu, G., Sato, I., and Sugiyama, M. Does distributionally robust supervised learning give robust classifiers? In International Conference on Machine Learning (ICML), 2018.
Ioffe, S. and Szegedy, C. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning (ICML), pp. 448–456, 2015.
Japkowicz, N. and Stephen, S. The class imbalance problem: A systematic study. Intelligent Data Analysis, 6(5):429–449, 2002.
Kleinberg, J., Mullainathan, S., and Raghavan, M. Inherent trade-offs in the fair determination of risk scores. In Innovations in Theoretical Computer Science (ITCS), 2017.
Liu, Z., Luo, P., Wang, X., and Tang, X. Deep learning face attributes in the wild. In Proceedings of the IEEE International Conference on Computer Vision (ICCV), pp. 3730–3738, 2015.
McCoy, R. T., Pavlick, E., and Linzen, T. Right for the wrong reasons: Diagnosing syntactic heuristics in natural language inference. In Association for Computational Linguistics (ACL), 2019.
Mei, S. and Montanari, A. The generalization error of random features regression: Precise asymptotics and double descent curve. arXiv preprint arXiv:1908.05355, 2019.
Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.
Namkoong, H. and Duchi, J. Variance regularization with convex objectives. In Advances in Neural Information Processing Systems (NeurIPS), 2017.
Opper, M. Statistical mechanics of learning: Generalization. The Handbook of Brain Theory and Neural Networks, pp. 922–925, 1995.
Oren, Y., Sagawa, S., Hashimoto, T., and Liang, P. Distributionally robust language modeling. In Empirical Methods in Natural Language Processing (EMNLP), 2019.
Rosset, S., Zhu, J., and Hastie, T. J. Margin maximizing loss functions. In Advances in Neural Information Processing Systems (NeurIPS), pp. 1237–1244, 2004.
Sagawa, S., Koh, P. W., Hashimoto, T. B., and Liang, P. Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization. In International Conference on Learning Representations (ICLR), 2020.
Shimodaira, H. Improving predictive inference under covariate shift by weighting the log-likelihood function. Journal of Statistical Planning and Inference, 90:227–244, 2000.
Soudry, D., Hoffer, E., Nacson, M. S., Gunasekar, S., and Srebro, N. The implicit bias of gradient descent on separable data. Journal of Machine Learning Research (JMLR), 19(1):2822–2878, 2018.
Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I., and Salakhutdinov, R. Dropout: A simple way to prevent neural networks from overfitting. Journal of Machine Learning Research (JMLR), 15(1):1929–1958, 2014.
Wah, C., Branson, S., Welinder, P., Perona, P., and Belongie, S. The Caltech-UCSD Birds-200-2011 dataset. Technical report, California Institute of Technology, 2011.
Wen, J., Yu, C., and Greiner, R. Robust learning under uncertain test distributions: Relating covariate shift to model misspecification. In International Conference on Machine Learning (ICML), pp. 631–639, 2014.
Yang, Z., Yu, Y., You, C., Steinhardt, J., and Ma, Y. Rethinking bias-variance trade-off for generalization of neural networks. arXiv preprint arXiv:2002.11328, 2020.
Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. In International Conference on Learning Representations (ICLR), 2017.
Zhou, B., Lapedriza, A., Khosla, A., Oliva, A., and Torralba, A. Places: A 10 million image database for scene recognition. IEEE Transactions on Pattern Analysis and Machine Intelligence, 40(6):1452–1464, 2017.
A.1. ERM models have poor worst-group error regardless of the degree of overparameterization
In the main text, we focused on reweighted models, trained with the reweighted objective on the full data (Sections 3–5), as well as subsampled models, trained with the ERM objective on subsampled data (Section 6). Here, we study the effect of overparameterization on ERM models, trained with the ERM objective on the full data. Consistent with prior work (Sagawa et al., 2020), we observe that ERM models obtain poor worst-group error (near or worse than random), regardless of whether the model is underparameterized or overparameterized. We also confirm that overparameterization helps average test error (see, e.g., Nakkiran et al. (2019); Belkin et al. (2019); Mei & Montanari (2019)).
Empirical results. We first consider the CelebA and Waterbirds datasets, following the experimental set-up of Section 3 but training with the standard ERM objective (Equation (2)) instead of the reweighted objective (Equation (3)).
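For reference, the two objectives can be sketched as follows for a linear model with the logistic loss. We assume here that the reweighted objective averages the per-group average losses, i.e., weights each point by the inverse of its group size; the exact normalization of Equation (3) may differ by a positive constant factor, which would not change its minimizer. The helper names are hypothetical.

```python
import numpy as np


def logistic_losses(w, X, y):
    """Per-example logistic loss log(1 + exp(-y * w.x)) for labels y in {-1, +1}."""
    return np.logaddexp(0.0, -y * (X @ w))


def erm_objective(w, X, y):
    """Average loss over all training points (standard ERM)."""
    return logistic_losses(w, X, y).mean()


def reweighted_objective(w, X, y, groups):
    """Average of per-group average losses: each point weighted by 1 / |its group|."""
    losses = logistic_losses(w, X, y)
    return np.mean([losses[groups == g].mean() for g in np.unique(groups)])


rng = np.random.default_rng(0)
X = rng.standard_normal((8, 3))
y = rng.choice([-1.0, 1.0], size=8)
w = rng.standard_normal(3)
balanced = np.array([0, 0, 1, 1, 2, 2, 3, 3])
# On group-balanced data the two objectives coincide.
print(np.isclose(erm_objective(w, X, y), reweighted_objective(w, X, y, balanced)))
```

On imbalanced data the reweighted objective gives each minority point a larger weight than each majority point, which is what distinguishes it from ERM.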
On these datasets, overparameterization helps the average test error (Figure 8). As model size increases past the point of zero training error, the average test error decreases. The best average test error is obtained by highly overparameterized models with zero training error—4.6% for CelebA at width 96, and 4.2% for Waterbirds at 6,000 random features.
In contrast, the worst-group error is consistently high across model sizes: it is worse than random (>50%) for CelebA and nearly random (44%) for Waterbirds (Figure 8). These worst-group errors are much worse than those obtained by reweighted, underparameterized models (25.6% for CelebA and 26.6% for Waterbirds; see Section 3). Thus, while overparameterization helps ERM models achieve better average test error, these models all fail to yield good worst-group error regardless of the degree of overparameterization.
Figure 8. The effect of overparameterization on the average and worst-group error of an ERM model. Increasing model size helps average test error, but worst-group error remains poor across model sizes.
Simulation results. We also evaluate the effect of overparameterization on ERM models on the synthetic dataset introduced in Section 4. As above, ERM models fail to achieve reasonable worst-group test error across model sizes, but improve in average test error as model size increases (Figure 8). The best average test error is obtained by a highly overparameterized model with zero training error—9.0% error at 9,000 random features—while the worst-group test error is nearly random or worse (> 48%) across model sizes.
A.2. Stronger regularization improves worst-group error in overparameterized reweighted models
In the main text, we studied models with default/weak or no regularization. In this section, we study the role of $\ell_2$ regularization in modulating the effect of overparameterization on worst-group error by changing the hyperparameter that controls the $\ell_2$ regularization strength. Overall, we find that increasing $\ell_2$ regularization (to the point where models no longer achieve zero training error) improves worst-group error but hurts average error in overparameterized reweighted models. In contrast, $\ell_2$ regularization has little effect on either worst-group or average error in the underparameterized regime.
Strong regularization improves worst-group error in overparameterized reweighted models. In the main text, we trained ResNet10 models with default, weak regularization (
) on the CelebA dataset, and unregularized logistic regression on the Waterbirds and synthetic datasets. Here, we consider strongly-regularized models with
for both types of models; unlike before, these models no longer achieve zero training error even when overparameterized. Figure 9 shows the results of varying model size on strongly-regularized ERM, reweighted, and subsampled models on the three datasets.
On all three datasets, with strong regularization, ERM models continue to yield poor worst-group test error across model sizes, with worst-group test error similar to or worse than that of their weakly regularized or unregularized counterparts. In contrast, strongly-regularized subsampled models continue to achieve low worst-group test error across model sizes.
Strong regularization has its largest effect on reweighted models. With reweighting, we find that strong regularization improves worst-group error in overparameterized models: across all three datasets, the worst-group test error in the overparameterized regime is much lower for the strongly-regularized models than for their weakly regularized or unregularized counterparts (Figure 3). These results are consistent with similar observations made in Sagawa et al. (2020). However, even though strongly-regularized overparameterized models outperform weakly-regularized overparameterized models, overparameterization can still hurt the worst-group error in strongly-regularized reweighted models. On the CelebA and synthetic datasets, even with strong regularization, the best worst-group error is still obtained by an underparameterized model, though overparameterization seems to help worst-group error on the Waterbirds dataset, at least in the range of model sizes studied.
Figure 9. Strongly-regularized models have lower worst-group error than their weakly-regularized counterparts in the overparameterized regime (Figure 3). Even under strong regularization, increasing model size can hurt the worst-group error on the CelebA (top) and synthetic (bottom) datasets, although overparameterization seems to improve worst-group error on the Waterbirds dataset (middle) for the range of model sizes studied.
Overparameterized models require strong regularization for worst-group test error but not average test error. Given a fixed overparameterized model size, how does its performance change with the regularization strength? We study this with the logistic regression model on the Waterbirds and synthetic datasets, using a model size of m = 10,000 random features and varying the $\ell_2$ regularization strength from
Results are in Figure 10. As before, ERM models obtain poor worst-group error regardless of the regularization strength, and subsampled models are relatively insensitive to regularization, achieving reasonable worst-group error at most regularization strengths.
For reweighted models, however, having the right level of regularization is critical for obtaining good worst-group test error. On both datasets, the best worst-group test error is obtained by strongly-regularized models that do not achieve zero training error. In contrast, increasing regularization strength hurts average error, with the best average test error attained by models with nearly zero regularization.
Figure 10. The effect of regularization on overparameterized random features logistic regression models (m = 10,000). ERM models (left) do consistently poorly, while subsampled models (right) do consistently well on worst-group error. For reweighted models (middle), the best worst-group error is obtained by a strongly-regularized model that does not achieve zero training error.
$\ell_2$ regularization affects where worst-group test error plateaus as model size increases. In the above experiments, we kept either model size or regularization strength fixed and varied the other. Here, we vary both: we consider several $\ell_2$ regularization strengths and investigate the effect of increasing model size for each. We plot the results for Waterbirds and the synthetic dataset in Figure 11 and Figure 12, respectively.
For reweighted models, the results match what we observed above. Strengthening regularization reduces the detrimental effect of overparameterization on worst-group error. For any fixed model size in the overparameterized regime, the worst-group test error improves as the regularization strength increases, up to a certain value. Worst-group test error seems to plateau at different values as model size increases, depending on the regularization strength, though it is possible that further increasing model size beyond the range we studied might lead models with different regularization strengths to eventually converge. Further empirical studies as well as a theoretical characterization of the interaction between regularization and overparameterization are needed to confirm this phenomenon.
Given sufficiently large (e.g.,
for both Waterbirds and synthetic datasets), overparameterized models seem to outperform underparameterized models, at least for the range of model sizes studied. However, we caution that this trend does not seem to hold on the CelebA dataset (Figure 9).
Finally, in contrast with its effects on overparameterized models, regularization seems to only have a modest effect on worst-group test error in the underparameterized regime.
Figure 11. The effect of overparameterization on models with different regularization strengths
on the Waterbirds dataset. Different regularization strengths are shown in different colors, with training and test errors plotted in light and dark colors, respectively.
Figure 12. The effect of overparameterization on models with different regularization strengths
on the synthetic dataset. The plotting scheme follows that of Figure 11.
A.3. Overparameterization helps average test error on the synthetic data regardless of the majority fraction and the spurious-core ratio
Figure 13 shows how the average test error changes as a function of model size under different settings of the majority fraction and the spurious-core ratio (SCR)
on the synthetic dataset introduced in Section 4. As expected, overparameterization helps the average test error regardless of SCR and the majority fraction.
Figure 13. The effect of overparameterization on average error of a reweighted model on synthetic data. Different values of are plotted in different colors, with training and test errors plotted in light and dark colors, respectively. Across all values of
overparameterization helps the average test error.
A.4. Comparison between implicit and explicit memorization
To motivate the explicit-memorization setting, we ran some brief experiments to show that in the overparameterized regime, linear models in the explicit-memorization setting behave similarly to random projection (RP) models in the implicit-memorization setting, with in the latter scaled up by a factor of d (Figure 14). Recall that in the latter,
is distributed as
. Roughly speaking, all the information about y is contained in the mean
, which is distributed as
. In the explicit-memorization setting, we can view
as equivalent to
in the implicit-memorization setting (and similarly for
), explaining the quantitative fit observed in Figure 14.
However, in the highly underparameterized regime, the RP models do poorly because of model misspecification (owing to the small number of random projections), whereas the linear models can still learn to use the core feature and therefore do well.
Figure 14. The effect of overparameterization on the worst-group test error for linear models in the explicit-memorization setting () and random projection models in the implicit-memorization setting (
The models agree in the overparameterized regime.
A.5. Experimental details
Waterbirds and CelebA datasets. For the CelebA dataset, we use the official train-val-test split from Liu et al. (2015), with the Blond Hair attribute as the target y and the Male attribute as the spurious attribute a.
For the Waterbirds dataset, we follow the setup in Sagawa et al. (2020); for convenience, we reproduce some details of how it was constructed here. This dataset was obtained by combining bird images from the CUB dataset (Wah et al., 2011) with backgrounds from the Places dataset (Zhou et al., 2017). The CUB dataset comes with annotations of bird species. For the Waterbirds dataset, each bird was labeled as a waterbird if it was a seabird or waterfowl in the CUB dataset; otherwise, it was labeled as a landbird. Bird images were cropped using the provided segmentation masks and placed on either a land (bamboo forest or broadleaf forest) or water (ocean or natural lake) background obtained from the Places dataset.
For Waterbirds, we follow the same train-val-test split as in Sagawa et al. (2020). Note that in these validation and test sets, landbirds and waterbirds are uniformly distributed on land and water backgrounds so that accuracy on the rare groups can be more accurately estimated. When calculating average test accuracy, we therefore first compute the average test accuracy over each group and then report a weighted average, with weights corresponding to the relative proportion of each group in the skewed training dataset.
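A minimal sketch of this evaluation, assuming per-group test accuracies and training-set group counts have already been computed. The group keys, function names, and numbers below are made up for illustration and are not results from the paper. Worst-group accuracy is simply the minimum over groups.

```python
def weighted_average_accuracy(group_accuracy, train_group_counts):
    """Average of per-group test accuracies, weighted by training-set group proportions."""
    total = sum(train_group_counts.values())
    return sum(acc * train_group_counts[g] / total for g, acc in group_accuracy.items())


def worst_group_accuracy(group_accuracy):
    """Lowest per-group test accuracy."""
    return min(group_accuracy.values())


# Hypothetical per-group test accuracies and (skewed) training counts.
acc = {"landbird/land": 0.99, "landbird/water": 0.80,
       "waterbird/land": 0.70, "waterbird/water": 0.95}
train_counts = {"landbird/land": 700, "landbird/water": 50,
                "waterbird/land": 50, "waterbird/water": 200}
print(round(weighted_average_accuracy(acc, train_counts), 3))  # 0.958
print(worst_group_accuracy(acc))                               # 0.7
```

Weighting by training proportions makes the reported average accuracy reflect the skewed training distribution even though the test set itself is balanced across groups.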
We post-process Waterbirds by extracting feature representations from the last layer of a ResNet18 model pre-trained on ImageNet, using the PyTorch torchvision implementation of ResNet18. All models on the Waterbirds dataset in our paper are logistic regression models trained on top of this (fixed) feature representation.
ResNet. We used a modified ResNet10 with variable widths, following the approach in Nakkiran et al. (2019) and extending the torchvision implementation. We trained all ResNet10 models with stochastic gradient descent with momentum of 0.9 and a batch size of 128; the regularization parameter was passed to the optimizer as the weight decay parameter. In the experiments in the main text, we used the default setting of
. We used a fixed learning rate instead of a learning rate schedule and selected the largest learning rate for which optimization was stable, following Sagawa et al. (2020). This resulted in learning rates of 0.01 and 0.0001 for
, respectively, across all training procedures. As in the original ResNet paper (He et al., 2016), we used batch normalization (Ioffe & Szegedy, 2015) and no dropout (Srivastava et al., 2014), and for simplicity, we trained all models without data augmentation.
We trained ERM and reweighted models for 50 epochs and subsampled models for 500 epochs (due to the smaller number of examples per epoch). We found that worst-group error can be unstable across epochs due to the small sample size and relatively large learning rate, so we report the error averaged over the last 10 epochs.
Logistic regression. We used the logistic regression implementation from scikit-learn, training with the L-BFGS solver until convergence with tolerance 0.0001, and setting the regularization parameter as . For unregularized models, we set
for numerical stability.
A.6. Subsampling
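Formally, given a set of groups $\mathcal{G}$ and a dataset $D$ comprising $n$ training points with their group identities, $D = \{(x_i, y_i, g_i)\}_{i=1}^{n}$, the subsampling procedure involves two steps. First, we group training points based on group identities:

$$D_g = \{(x_i, y_i, g_i) \in D : g_i = g\} \quad \text{for each } g \in \mathcal{G}.$$

For each group $g$, we select a subset $\tilde{D}_g \subseteq D_g$ uniformly at random such that each subset has the same number of points as the smallest group in the training set. We form a new dataset $\tilde{D}$ by combining these subsets:

$$\tilde{D} = \bigcup_{g \in \mathcal{G}} \tilde{D}_g.$$

Note that $\tilde{D}$ is group-balanced, with $|\tilde{D}_g| = \min_{g' \in \mathcal{G}} |D_{g'}|$ for every $g \in \mathcal{G}$. We then train a model by minimizing the average loss on $\tilde{D}$:

$$\hat{w} = \arg\min_{w} \frac{1}{|\tilde{D}|} \sum_{(x, y, g) \in \tilde{D}} \ell(w; (x, y)).$$

Since $\tilde{D}$ is group-balanced, the reweighted training loss (Equation 3) places the same weight on all training points, and minimizing the reweighted objective on $\tilde{D}$ is equivalent to minimizing the average loss objective above.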
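The two-step procedure can be sketched in a few lines of Python. This is an illustrative implementation rather than the released code: it assumes the dataset is a list of (x, y, g) tuples, uses the hypothetical name `subsample_groups`, and relies on Python's `random.sample` to draw each subset uniformly without replacement.

```python
import random
from collections import defaultdict


def subsample_groups(dataset, seed=0):
    """Return a group-balanced subset of `dataset`, a list of (x, y, g) tuples."""
    # Step 1: partition the training points by group identity g.
    by_group = defaultdict(list)
    for example in dataset:
        by_group[example[2]].append(example)
    # Step 2: draw min_g |D_g| points uniformly at random from each group.
    smallest = min(len(points) for points in by_group.values())
    rng = random.Random(seed)
    balanced = []
    for g in sorted(by_group):  # fixed group order for reproducibility
        balanced.extend(rng.sample(by_group[g], smallest))
    return balanced


# Toy dataset with hypothetical group sizes 90, 10, 30, and 5.
sizes = {0: 90, 1: 10, 2: 30, 3: 5}
data = [(i, g % 2, g) for g, size in sizes.items() for i in range(size)]
subset = subsample_groups(data)
print(len(subset), len(subset) / len(data))  # 20 points, about 14.8% of the data
```

Training then minimizes the plain average loss over `subset`. The retained fraction, |𝒢|·min_g |D_g| / n, is what makes subsampling look wasteful on heavily skewed datasets.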
Here, we detail the proof of Theorem 1 presented in Section 5. We structure the proof by splitting Theorem 1 into two smaller theorems: one for the overparameterized regime (Appendix B.2), and another for the underparameterized regime (Appendix B.3).
B.1. Notation and definitions.
We denote the separate components of the weight vector
Further, by the representer theorem, we decompose
Note that is equivalent to the
referred to in the main text. Recall that we define memorization of each training point
by the weight
as follows.
Definition 2 (-memorization). Consider a separator
on training data
. For some constant
, we say that a model
-memorizes a training point if
The component serves to “memorize”
is sufficiently large, as it affects the prediction on
not on any other training or test points (because noise vectors are nearly orthogonal when N is large). In the proof, we set the constant
appropriately (based on other parameter settings in Theorem 1) to get the required result.
Finally, let denote the indices of training points in the majority and minority groups, respectively.
B.2. Overparameterized regime
In our explicit-memorization set-up, sufficiently overparameterized models provably have high worst-group error under certain settings of as stated in Theorem 1 (restated below as Theorem 2).
Theorem 2. For any ,
and
, there exists
such that for all
(overparameterized regime), with high probability over draws of the data,
where is the max-margin classifier.
In Section 5, we sketched key ideas in the proof by considering special families of separators: because the minimum-norm inductive bias favors less memorization, models can prefer to learn the spurious feature and memorize the minority examples (entailing high worst-group error), instead of learning the core feature and memorizing some fraction of all training points (possibly attaining reasonable worst-group error). We now provide the full proof of Theorem 2, generalizing the above key concepts by considering all separators.
Proof. Recall from Section 5 that we consider the maximum-margin classifier
In other words, is the minimum-norm separator, where a separator is a classifier with zero training error and the required margins, satisfying
. We analyze the worst-group error of the minimum-norm separator
as outlined below:
1. We first upper bound the fraction of majority examples memorized by the minimum-norm separator by showing that there exists a separator that uses the spurious feature and needs to memorize only the minority points (Lemma 1), for the parameter settings in Theorem 2 where
is sufficiently small. Since the norm of a separator roughly scales with the number of points memorized (
), we have an upper bound on the number of training points memorized by
. Since the number of majority points is much larger than the number of minority points, this says that only a small fraction of majority points could be memorized by
2. Next, we observe that since the core feature is noisy under the parameter settings of Theorem 2, a separator that does not use the spurious feature must memorize a constant fraction of the majority points. Conversely, if less than this fraction of majority points is memorized, the separator must use the spurious feature. Since using the spurious feature leads to higher worst-group test error, this reveals a trade-off between the worst-group test error of a separator and the fraction of majority points that it memorizes at training time. Succinctly, memorizing a smaller fraction implies using the spurious feature, which in turn implies higher worst-group test error; conversely, lower worst-group test error requires eliminating the use of the spurious feature, which forces a large fraction of majority points to be memorized for the classifier to be a separator. We formalize this trade-off between the worst-group test error and the fraction of majority examples that must be memorized in Proposition 3.
Combining the two steps, since memorizes only a small fraction of majority points by virtue of being the minimum-norm separator,
suffers high worst-group test error.
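The outline can be explored numerically. The sketch below is an illustration under stated assumptions, not part of the proof. It samples from the Section 5 model as we read it: x_core ~ N(y, σ_core²), x_spu ~ N(a, σ_spu²), x_noise ~ N(0, (σ_noise²/N) I_N), with a = y on a p_maj fraction of points. It then computes the minimum-norm separator by randomized dual coordinate ascent on the hard-margin SVM dual without a bias term, and reports the core and spurious weights. As a stand-in for memorization, it uses the hypothetical criterion λ_i ‖x_noise^(i)‖² > γ, which is the contribution of point i's own noise vector to its margin. This criterion need not coincide with Definition 2. All parameter values and helper names are ours.

```python
import numpy as np

def sample(n, N, p_maj, s_core, s_spu, s_noise, rng):
    """Hypothetical sampler for the Section 5 model (our reading of it)."""
    y = rng.choice([-1.0, 1.0], size=n)
    a = np.where(rng.random(n) < p_maj, y, -y)              # majority points have a = y
    x_core = y + s_core * rng.standard_normal(n)
    x_spu = a + s_spu * rng.standard_normal(n)
    Z = rng.normal(0.0, s_noise / np.sqrt(N), size=(n, N))  # assumed noise scaling
    return np.column_stack([x_core, x_spu, Z]), y, a

def min_norm_separator(X, y, rng, sweeps=1000):
    """argmin ||w|| s.t. y_i <w, x_i> >= 1 (no bias), by randomized dual coordinate ascent.

    Dual: maximize sum(lam) - 0.5 * lam^T Q lam over lam >= 0, with Q_ij = y_i y_j <x_i, x_j>.
    """
    Xy = X * y[:, None]
    Q = Xy @ Xy.T
    lam = np.zeros(len(y))
    margins = np.zeros(len(y))                # invariant: margins == Q @ lam
    for _ in range(sweeps):
        for i in rng.permutation(len(y)):
            new = max(0.0, lam[i] + (1.0 - margins[i]) / Q[i, i])
            margins += (new - lam[i]) * Q[i]  # Q is symmetric, so row i == column i
            lam[i] = new
    w = Xy.T @ lam                            # representer form: w = sum_i lam_i y_i x_i
    return w, lam

rng = np.random.default_rng(0)
X, y, a = sample(n=100, N=2000, p_maj=0.9, s_core=1.0, s_spu=0.1, s_noise=1.0, rng=rng)
w, lam = min_norm_separator(X, y, rng)

# Contribution of point i's own noise vector to its margin: lam_i * ||x_noise^(i)||^2.
self_term = lam * np.einsum("ij,ij->i", X[:, 2:], X[:, 2:])
gamma = 0.5                                   # hypothetical memorization threshold
majority = a == y
frac_maj = float(np.mean(self_term[majority] > gamma))
frac_min = float(np.mean(self_term[~majority] > gamma))
print(f"w_core={w[0]:.3f} w_spu={w[1]:.3f} "
      f"memorized: majority={frac_maj:.2f} minority={frac_min:.2f}")
```

The printed weights and fractions depend on the chosen parameters. Varying σ_noise, σ_spu, and p_maj lets you explore the regimes discussed above. The printed values should not be read as the constants of Theorem 2.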
We now formally prove Theorem 2, invoking propositions that we prove in subsequent sections.
In the first part of the proof, we show that the minimum-norm separator “memorizes” a small fraction of the majority examples. Formally, we study the quantity
defined as follows.
Definition 3. Consider a separator on training data
. Let
be the fraction of training examples that
-memorizes in the majority groups:
We provide an upper bound on (Lemma 4) by first bounding
and then bounding
in terms of
Bounding
Lemma 1. There exists a separator that satisfies
. The norm of this separator gives a bound on
as follows. For the parameter settings under Theorem 2, with high probability, we have
for constants
Proof. In order to get an upper bound on , we compute the norm of a particular separator. Concretely, we consider a separator
of the following form:
First, because we are interested in that does not use the core feature and relies on the spurious feature instead, we let
. We set the value u appropriately so that none of the majority points are memorized (corresponding to
). However, since the spurious correlation is reversed on the minority points and
, the minority points have to be memorized. For simplicity, we set
for all
Now it remains to select appropriate values of the constants u and s such that the margin condition is satisfied for all training examples.
For majority points, this involves setting u large enough that the less noisy spurious feature provides the required margin. Without loss of generality, assume . Formally, for
The first inequality follows from the fact that is small enough under the parameter settings of Theorem 2 to allow a uniform bound on
). The second inequality follows from setting the number of noise features N to be large enough that the noise vectors are nearly orthogonal (Lemma 8). Conversely, we have
Notice that the condition in Equation 23 requires that u be greater than 0. Since the minority points have spurious attribute , we need to set s to be large enough so that
as defined above separates the minority points. Just as before, we set
The steps are similar to the condition for majority points, with the key difference that the contribution from the noise term involves
Conversely, we have
A set of parameters that satisfies both conditions above (Equation 23 and Equation 24) is the following:
We use the fact that (from Lemma 9).
Finally, we have w.h.p,
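A quick numerical check of this construction follows, under assumptions. We sample from the Section 5 model as we read it and pick u and s by hand. These values are hypothetical and are not the settings derived above. Every training point should meet the unit margin, while ‖w‖² ≈ u² + s²σ_noise² n_min. In other words, the norm is driven by the number of minority points rather than by all points.

```python
import numpy as np

rng = np.random.default_rng(2)
n, N, p_maj = 100, 40_000, 0.9
s_core, s_spu, s_noise = 1.0, 0.05, 1.0   # hypothetical parameter values

# Sample from the Section 5 model as we read it.
y = rng.choice([-1.0, 1.0], size=n)
a = np.where(rng.random(n) < p_maj, y, -y)
x_core = y + s_core * rng.standard_normal(n)
x_spu = a + s_spu * rng.standard_normal(n)
Z = rng.normal(0.0, s_noise / np.sqrt(N), size=(n, N))
minority = a != y

# Spurious-feature separator in the spirit of Lemma 1: no core weight, spurious weight u,
# and a noise coefficient s (times the label) on minority points only.
u = 3.0                                  # hypothetical; majority margins are then about u
s = 2.0 * (1.0 + u) / s_noise**2         # hypothetical; must overcome the -u spurious term
coef = np.where(minority, s * y, 0.0)
w_core, w_spu, w_noise = 0.0, u, coef @ Z

margins = y * (w_core * x_core + w_spu * x_spu + Z @ w_noise)
norm_sq = w_core**2 + w_spu**2 + float(w_noise @ w_noise)
predicted = u**2 + s**2 * s_noise**2 * minority.sum()   # ignores the near-zero cross terms
print(margins.min(), norm_sq, predicted)
```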
Bounding in terms of
Lemma 2. For a separator with bounded
for all i = 1, . . . , n, its norm can be bounded with high probability as
under the parameter settings of Theorem 2.
Proof. The result follows from bounded norms (Lemma 9), bounded dot products (Lemma 8), and the definition of
(Definition 3).
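The mechanism behind Lemma 2, as we read it, is near-orthogonality. Because the noise vectors are nearly orthogonal, the squared norm of w_noise = Σ_i α_i x_noise^(i) is close to Σ_i α_i² ‖x_noise^(i)‖². It therefore grows roughly linearly in the number of points that carry a large coefficient. The sketch assumes x_noise ~ N(0, (σ_noise²/N) I_N) and uses arbitrary, hypothetical coefficient values.

```python
import numpy as np

rng = np.random.default_rng(1)
n, N, s_noise = 100, 20_000, 1.0
Z = rng.normal(0.0, s_noise / np.sqrt(N), size=(n, N))   # rows: x_noise^(i) (assumed model)
sq_norms = np.einsum("ij,ij->i", Z, Z)

results = {}
coef_size = 1.0                           # hypothetical size of each "memorizing" coefficient
for k in (10, 50, 100):
    alpha = np.zeros(n)
    alpha[:k] = coef_size                 # k points carry a large coefficient
    w_noise = alpha @ Z                   # representer form: w_noise = sum_i alpha_i x_noise^(i)
    exact = float(w_noise @ w_noise)
    diagonal = float(np.sum(alpha**2 * sq_norms))   # drops the cross terms
    results[k] = (exact, diagonal)
    print(k, round(exact, 3), round(diagonal, 3))
```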
We now apply Lemma 1 and Lemma 2 in order to bound , showing that the fraction of majority points that are memorized is small for an appropriate choice of
To invoke Lemma 2, we first show that the coefficient is bounded above with high probability.
Lemma 3. Under the parameter settings of Theorem 2, with high probability, is bounded above for i = 1, . . . , n as
From the upper bound on
Since
Now, we are ready to show that
Lemma 4. Under the parameter settings of Theorem 2, the following is true with high probability.
Proof. Applying Lemma 2 to by invoking the bounds on
with high probability. Putting this together with Lemma 1, we have
where in the last step we substitute the constants and
Lemma 5. With probability
where a is the spurious attribute.
This follows from standard sub-Gaussian concentration and a union bound over
Lemma 6. For a vector
Lemma 7. For two vectors , by Hoeffding’s inequality, we have
Corollary 1. Combining Lemma 6 and Lemma 7, we get
Lemma 8. For , with probability greater than
This follows from Corollary 1 and a union bound over pairs of training points.
Lemma 9. For , with probability greater than
This follows from Lemma 6 and a union bound over n training points. In particular, we can set for large enough N.
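Lemmas 8 and 9 are the concentration facts used throughout. For large N, each ‖x_noise^(i)‖² is close to σ_noise², and distinct noise vectors are nearly orthogonal. The simulation below assumes x_noise ~ N(0, (σ_noise²/N) I_N), following our reading of Section 5. It shows both deviations shrinking at roughly a 1/√N rate, up to the logarithmic factor from the union bound. The helper name is ours.

```python
import numpy as np

def noise_stats(n, N, s_noise, rng):
    """Max |‖x_noise^(i)‖^2 - s_noise^2| and max |<x_noise^(i), x_noise^(j)>| over i != j."""
    Z = rng.normal(0.0, s_noise / np.sqrt(N), size=(n, N))   # assumed noise scaling
    G = Z @ Z.T                                               # Gram matrix of the noise vectors
    norm_dev = float(np.abs(np.diag(G) - s_noise**2).max())
    off_diag = G[~np.eye(n, dtype=bool)]
    return norm_dev, float(np.abs(off_diag).max())

rng = np.random.default_rng(0)
results = {N: noise_stats(50, N, 1.0, rng) for N in (100, 1_000, 10_000)}
for N, (norm_dev, max_dot) in results.items():
    print(f"N={N:>6}: max |norm^2 - s^2| = {norm_dev:.3f}, max |dot| = {max_dot:.3f}")
```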
In the previous section, we proved that , the fraction of majority training samples that can have a coefficient on the noise vectors greater than
in the max-margin separator
is bounded for a suitable value of
. We showed this using the fact that the norm of
is the smallest among all separators and the observation that the squared norm of a separator roughly scales in proportion to the number of training points that have a large coefficient along the noise vectors.
What does small imply? We now show that the bound on
has an important consequence for the worst-group error Err
; low
would imply high worst-group error Err
. We show that there is a trade-off between the worst-group test error of a separator and the fraction of majority points that it “memorizes” at training time. A model that has low worst-group test error must use the core feature and not the spurious feature, and to obtain zero training error, such a model would need to memorize a potentially large fraction of majority and minority points. In contrast, if the model instead uses only the spurious feature, then the worst-group test error would be high, but it would memorize only a small fraction of majority examples at training time: because we assume that the spurious feature is much less noisy than the core feature (
), far fewer majority examples would need to be memorized. To summarize, a large
would require a smaller fraction of majority points to be memorized
but increase the worst-group test error Err. We formalize the above trade-off between the worst-group error and fraction of majority examples to be memorized in Proposition 3.
Proposition 3. For the minimum norm separator , under the parameter settings of Theorem 2, with high probability,
for some constants the Gaussian CDF.
For any separator that spans the training points and satisfies
under the parameter settings of Theorem 2, with high probability,
for some constants the Gaussian CDF.
We prove Proposition 3 in Section B.2.5.
As mentioned before, we see that the spurious component weight has opposite effects on the two quantities; Err
increases with an increase in
decreases with an increase in
. This dependence can be exploited to relate the two quantities to each other as follows.
In other words, if the is low, then Err
would need to be high.
Recall from part 1 that for an appropriate choice of
, and from part 2 the trade-off between
(Equation (50)). As a final step, we need to bound the quantities on the RHS of Equation (50). All the constants are small, and
) which allows us to write
We have hence proved that the minimum-norm separator incurs high worst-group error with high probability under the specified conditions.
Proposition 3. For the minimum norm separator , under the parameter settings of Theorem 2, with high probability,
for some constants the Gaussian CDF.
For any separator that spans the training points and satisfies
under the parameter settings of Theorem 2, with high probability,
for some constants the Gaussian CDF.
Proof. We derive the two bounds below.
Worst-group test error
We bound the expected worst-group error Err, which is the expected worst-group loss over the data distribution. Below, we lower bound the worst-group error Err
by bounding the error on a particular group: minority positive points which have label y = 1 and spurious attribute
. The test error is the probability that a test example x from this group gets misclassified, i.e.
In the last step, we rewrite for convenience
We use the properties of high-dimensional Gaussian random vectors to bound the quantity . Recall that
can be written as
From Lemma 3, we know that max
probability for some small constants
denote this high-probability event, in which the dot product
. Using the fact that
which follows from simple algebra, we have
From the expression above, we see that Err increases as the spurious component
increases. This is because in the minority group, the spurious feature is negatively correlated with the label.
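The dependence on the spurious weight can be written in closed form under the Gaussian model as we read it. Take a fresh test point with y = 1 and a = -1, and a fixed classifier w = [w_core, w_spu, w_noise]. The score w·x is then Gaussian with mean w_core - w_spu and variance w_core²σ_core² + w_spu²σ_spu² + ‖w_noise‖²σ_noise²/N. Hence this group's error is Φ((w_spu - w_core)/sd). This is our own derivation under the assumed model, offered as a check of the monotonicity claim rather than as the displayed expression in the proof. The function names are illustrative.

```python
from math import erf, sqrt

def gaussian_cdf(t):
    return 0.5 * (1.0 + erf(t / sqrt(2.0)))

def minority_positive_error(w_core, w_spu, w_noise_sq, s_core, s_spu, s_noise, N):
    """P(w . x <= 0) for a fresh test point with y = 1 and a = -1 (assumed Gaussian model).

    The score is Gaussian with mean w_core - w_spu and variance
    w_core^2 s_core^2 + w_spu^2 s_spu^2 + ||w_noise||^2 s_noise^2 / N.
    """
    mean = w_core - w_spu
    std = sqrt((w_core * s_core) ** 2 + (w_spu * s_spu) ** 2 + w_noise_sq * s_noise**2 / N)
    return gaussian_cdf(-mean / std)

# The error grows with the spurious weight (w_core fixed at 1, no noise weight).
for w_spu in (0.0, 0.5, 1.0, 2.0):
    print(w_spu, round(minority_positive_error(1.0, w_spu, 0.0, 1.0, 0.1, 1.0, 1000), 4))
```

For nonnegative weights, the argument of Φ is increasing in w_spu, which matches the statement above.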
Fraction of memorized training examples in majority groups
We now compute a lower bound on , which is the fraction of majority points (where a = y) that are “memorized.” Intuitively, we want to show that the fraction depends on
. The more the core feature is used relative to the spurious feature, the larger the fraction of points that must be memorized, because the core feature is noisier.
First, consider a separator with some core and spurious components
. Recall that
and
by the definition of separators. For a given
, we want to bound the fraction of majority points (a = y) which can have
. We focus only on separators with bounded memorization, i.e. those that satisfy
. Note that from Lemma 3, w.h.p., the minimum-norm separator
satisfies this condition.
We bound the above by bounding a related quantity: the expected fraction of points that are memorized under the training distribution. We then use concentration to relate this expectation to the fraction in the training set.
Formally, we fix the quantities . The training set is generated according to the usual data-generating distribution. As before, we are interested in separators on the training set. For any majority training point, the coefficient
in a separator is a random variable. Since training point i is separated, we have
From Lemma 8, Lemma 6, and the condition on , this implies with high probability that
for some constant . Conditioning on the high probability event just as before (
), we get
for some . Finally, we connect to
which is the finite sample version of the quantity
. By the Dvoretzky-Kiefer-Wolfowitz (DKW) inequality, the empirical CDF converges uniformly to the population CDF. Under the conditions of Theorem 2, which lower bound the number of majority points, we have with high probability,
for constants
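The DKW step can be made concrete. With probability at least 1 - δ, the empirical CDF of n i.i.d. draws is uniformly within ε = √(log(2/δ)/(2n)) of the population CDF (Massart's constant). As we read it, the proof applies this to the distribution of the per-point coefficient across majority points. The fraction of majority training points beyond a threshold is then within ε of its expectation, uniformly over thresholds. The sketch below checks the bound on uniform draws; the helper names are ours.

```python
import numpy as np
from math import log, sqrt

def dkw_epsilon(n, delta):
    """eps such that P(sup_t |F_n(t) - F(t)| > eps) <= delta (DKW with Massart's constant)."""
    return sqrt(log(2.0 / delta) / (2.0 * n))

def sup_cdf_gap(samples, cdf):
    """Kolmogorov-Smirnov distance between the empirical CDF of the samples and cdf."""
    x = np.sort(samples)
    n = len(x)
    F = cdf(x)
    above = np.arange(1, n + 1) / n - F   # empirical CDF just after each sample
    below = F - np.arange(0, n) / n       # empirical CDF just before each sample
    return float(max(above.max(), below.max()))

rng = np.random.default_rng(0)
n, delta = 10_000, 1e-3
gap = sup_cdf_gap(rng.random(n), lambda t: t)   # Uniform(0, 1) has F(t) = t on [0, 1]
print(gap, dkw_epsilon(n, delta))
```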
Proposition 1 (Norm of models using the spurious feature). When satisfy the conditions in Theorem 1, there exists
such that for all
, with high probability, there exists a separator
for some constants
Proof. The proposition follows directly from Lemma 1.
Proposition 2 (Norm of models using the core feature). When satisfy the conditions in Theorem 1 and
there exists
such that for all
, with high probability, all separators
for some constant
Proof. To bound the norm for all , we provide a lower bound on the norm of the minimum-norm separator in the set
We bound the in two steps:
1. We first provide a lower bound for in terms of the fraction of training points memorized
(defined formally below) in Corollary 2.
2. We then provide a lower bound for in Corollary 3.
We first formally define
Definition 4. For a separator on training data
, let
be the fraction of training examples that
-memorizes:
Lemma 10. For a separator with bounded
for all i = 1, . . . , n, its norm can be bounded with high probability as
Proof. Similarly to the proof of Lemma 2, the result follows from bounded norms (Lemma 9), bounded dot products (Lemma 8), and the definition of
Corollary 2. With high probability,
Proof. The result follows from applying Lemma 10 to , invoking the bounds on any individual component
obtained below in Lemma 11.
Below, we bound , where
is the component of training point i to the classifier
via the representer theorem.
Lemma 11. With high probability, can be bounded as follows.
Proof. As a first step, we upper bound the norm of by the norm of another separator
using the fact that
is the minimum-norm separator in
. In particular, we construct a separator
that “memorizes” all training points, of the following form:
This is analogous to the construction of (Lemma 1), and similar calculations can be used to obtain a suitable value
to ensure that
is a separator with high probability. We provide it below for completeness. We show that the following condition is sufficient to satisfy the margin constraints
with high probability:
for . We obtain the above condition by applying Lemma 8 and Lemma 9 to the margin condition.
Thus, we can construct by setting some constant
Now that we have constructed
, we can bound the norm of the minimum norm separator
by the norm of
. The following is true with high probability,
Finally, we bound for all i by bounding max
. As we showed in the proof of Lemma 3, the following is true with high probability:
Combined with the upper bound on (Equation (80)), we have
Bounding
Corollary 3. Under the parameter settings of Theorem 2, with high probability,
for some constants is the Gaussian CDF.
Proof. The result follows from applying Proposition 3 (which bounds the fraction of majority points that are memorized) to
, invoking Lemma 11, and plugging in
. Note that when
,
Finally, the above bound on translates to a bound on the norm
via simple algebra. For
that satisfies
Plugging the above lower bound into the bound on from Corollary 2, we have
for some
B.3. Underparameterized regime
So far, we have studied the overparameterized regime for the data distribution described in Section 5. In the overparameterized setting, where the number of noise features N is very large, logistic regression (both ERM and reweighted) leads to max-margin classifiers. We showed that for some settings of the parameters , the worst-group error of such max-margin classifiers can be > 2/3, worse than random guessing. How does the same reweighted logistic regression perform in the underparameterized regime? We focus on the setting where N = 0. In this setting, the data is two-dimensional, and w.h.p., the training data is not linearly separable unless
. Consequently, the learned model
that minimizes the reweighted training loss is not generally a max-margin separator.
For intuition, consider the following two sets of models, which are analogous to what we considered in Equation 12 in the main text for the overparameterized regime:
The first set comprises models that use the spurious feature but not the core feature, and the second set
comprises models that use the core feature but not the spurious feature. Models in
that exclusively use
will have high training loss on the minority groups, since the minority points cannot be memorized. Because the minority groups are upweighted, these models will have high reweighted training loss. On the other hand, models in
exclusively use the core feature, which is informative of the label y across all groups. Hence, they obtain reasonable loss on all groups and have smaller reweighted training loss than models in
We will show in this section that the population minimizer of the reweighted loss is indeed in , and we bound the asymptotic variance of the reweighted estimator, leading to the final result in Theorem 1. Our approach is to study the asymptotic behavior of the reweighted estimator when the number of data points
Data distribution. We first recap the data generating distribution (described in Section 5).
For fraction of points, we have a = y (majority points) and for
fraction of points, we have
points).
Reweighted logistic loss. Let be the fraction of the majority group points and
be the fraction of minority points. In order to use standard results from the asymptotics of M-estimators, we rewrite the reweighted estimator (defined in Section 2) as the minimizer of the following loss over n training points
We follow the standard steps of asymptotic analysis where we:
1. Compute the population minimizer that satisfies
2. Bound the asymptotic variance
Proposition 4. For the data distribution under study, the population minimizer that satisfies
is the following.
This is the key property in the underparameterized regime: the population minimizer attains the best possible worst-group error by using only the core feature and not the spurious feature.
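Proposition 4 can be checked by simulation. Under our reading of the model, reweighting each group by the inverse of its frequency makes a independent of y in the reweighted population. As a result, x_spu carries no information about y, and the reweighted log-odds are exactly 2x_core/σ_core². This suggests the population minimizer [2/σ_core², 0]. That closed form is our own derivation under the assumed Gaussian model, and we have not checked it against the constants in Proposition 4. The sampler and the Newton solver are illustrative. For contrast, we also fit unweighted ERM.

```python
import numpy as np

def sample_2d(n, p_maj, s_core, s_spu, rng):
    """N = 0 version of the (assumed) Section 5 model: features [x_core, x_spu]."""
    y = rng.choice([-1.0, 1.0], size=n)
    a = np.where(rng.random(n) < p_maj, y, -y)
    X = np.column_stack([y + s_core * rng.standard_normal(n),
                         a + s_spu * rng.standard_normal(n)])
    return X, y, a

def weighted_logreg(X, y, c, iters=25):
    """Newton's method for min_w sum_i c_i log(1 + exp(-y_i <w, x_i>)), no intercept."""
    w = np.zeros(X.shape[1])
    for _ in range(iters):
        m = y * (X @ w)
        q = 0.5 * (1.0 - np.tanh(0.5 * m))          # sigmoid(-m), overflow-free
        grad = -X.T @ (c * y * q)
        H = (X * (c * q * (1.0 - q))[:, None]).T @ X
        w -= np.linalg.solve(H, grad)
    return w

rng = np.random.default_rng(0)
p_maj, s_core = 0.9, 1.0
X, y, a = sample_2d(200_000, p_maj, s_core, 0.1, rng)
majority = a == y
# Upweight each group by the inverse of its empirical frequency (the overall scale is irrelevant).
c = np.where(majority, 1.0 / majority.mean(), 1.0 / (~majority).mean())
w_rw = weighted_logreg(X, y, c)
w_erm = weighted_logreg(X, y, np.ones_like(y))
print("reweighted:", w_rw, " target:", [2.0 / s_core**2, 0.0])
print("ERM:       ", w_erm)
```

If the model is as assumed, the reweighted fit should land near (2/σ_core², 0) up to sampling error. The ERM fit, by contrast, puts substantial weight on x_spu.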
Proposition 5. The asymptotic distribution of the reweighted logistic regression estimator is as follows.
For
for some constants
We see that the asymptotic variance increases as increases. This is expected because the reweighted estimator upweights the minority points by the inverse of the group size, and as these weights increase, so does the variance. However, as noted before, the population minimizer has small worst-group error; hence, for a large enough training set, the estimator also attains small worst-group error, since the asymptotic variance is finite (for fixed
) and the estimator approaches the population minimizer.
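The growth of the variance with the minority weights can be checked with the standard M-estimator sandwich formula: √n(ŵ - w*) → N(0, A⁻¹BA⁻¹). Here A is the expected Hessian and B is the expected outer product of per-sample gradients, both evaluated at w*. The sketch varies the majority fraction (p_maj in the code) and relies on three assumptions. First, the group weights are fixed at c_g = 1/p_g rather than estimated. Second, w* = [2/σ_core², 0] is our derivation under the assumed Gaussian model. Third, the expectations are Monte Carlo estimates. The helper names are ours.

```python
import numpy as np

def group_samples(m, majority, s_core, s_spu, rng):
    """m draws from the majority (a = y) or minority (a = -y) group, N = 0 (assumed model)."""
    y = rng.choice([-1.0, 1.0], size=m)
    a = y if majority else -y
    X = np.column_stack([y + s_core * rng.standard_normal(m),
                         a + s_spu * rng.standard_normal(m)])
    return X, y

def sandwich_covariance(p_maj, s_core=1.0, s_spu=0.1, m=100_000, seed=0):
    """Monte Carlo A^{-1} B A^{-1} for the per-sample loss c_g * log(1 + exp(-y <w, x>)).

    A = E[Hessian] and B = E[grad grad^T], both at w_star; the group weight is c_g = 1 / p_g.
    """
    rng = np.random.default_rng(seed)            # same seed -> same draws for every p_maj
    w_star = np.array([2.0 / s_core**2, 0.0])    # assumed population minimizer (see lead-in)
    A = np.zeros((2, 2))
    B = np.zeros((2, 2))
    for majority, p_g in ((True, p_maj), (False, 1.0 - p_maj)):
        X, y = group_samples(m, majority, s_core, s_spu, rng)
        c = 1.0 / p_g
        q = 0.5 * (1.0 - np.tanh(0.5 * y * (X @ w_star)))   # sigmoid(-margin)
        A += p_g * c * (X * (q * (1.0 - q))[:, None]).T @ X / m
        B += p_g * c**2 * (X * (q**2)[:, None]).T @ X / m
    A_inv = np.linalg.inv(A)
    return A_inv @ B @ A_inv

for p in (0.8, 0.9, 0.99):
    print(p, sandwich_covariance(p)[1, 1])   # asymptotic variance of sqrt(n) * w_spu
```

Reusing the same seed for every p_maj (common random numbers) keeps the comparison across p_maj free of Monte Carlo noise. With c_g = 1/p_g, A does not depend on p_maj. B, however, scales like 1/p_maj on the majority group and like 1/(1 - p_maj) on the minority group, and this is what drives the growth.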
We now prove Theorem 1 for the underparameterized regime, restated as Theorem 3 below.
Theorem 3. In the underparameterized regime with N = 0, for , and
, in the asymptotic regime with
Proof. We now combine Propositions 4 and 5. We have and
for
, i.e., the estimator is very close to the population minimizer. This follows from setting
to their corresponding values and setting
to be large enough. In order to compute the worst-group error, WLOG consider points with label y = 1 (labels are balanced in the population). For a point from the majority group, the probability of misclassification is as follows.
where
Similarly, for the minority group, the probability of misclassification is
Therefore, the worst-group error of can be bounded as follows.
where is the Gaussian CDF. Substituting
gives the required result that Err
. In contrast, in the overparameterized regime where
, even for very large n, the reweighted estimator has high worst-group error, as shown in Theorem 1.
We now provide the proofs for Proposition 4 and Proposition 5 which mostly follow from straightforward algebra.
Proposition 4. For the data distribution under study, the population minimizer that satisfies
is the following.
Proof. For convenience, we compute expectations over the majority and minority groups separately and express the population loss L as the weighted sum of the two terms. Recall that we denote
We use the following expression for computing the population gradient.
Combining the definition of the reweighted loss and population losses (Equation 91 and Equation 102) with the gradient expression above gives the following.
Now we compute . First, we compute with respect to the spurious attribute
. For convenience, let
Replacing
Now we take the weighted combination of , based on the fraction of the majority and minority samples in the population, which makes the two terms cancel out.
Now we compute
Similarly, we get and hence have proved that
Lemma 12. The following is true.
We now compute the asymptotic variance which involves computing
Proof. First, we show that the off-diagonal entries of
Now, we bound the diagonal elements.
Finally,
Lemma 13. The following is true.
Proof. We use the following expression for computing the population gradient.
Recall the definition of the population majority and minority losses (Equation 102).
As before, we first compute the off-diagonal entries.
Now, we bound the diagonal entries. Recall that
Finally, we calculate as follows.
For
for some constants
Proof. By asymptotic normality, we have . Combining Lemma 12 and Lemma 13, we get the expression in Equation 96. Each term is decreasing in
, and hence we get the final result by substituting
to obtain the constants
(and noting that