How Come Deep Neural Networks Generalize?

This is a short review of the paper Understanding Deep Learning Requires Rethinking Generalization by Zhang et al. (ICLR 2017), which I wrote as part of my application to the Google Research India AI Summer School.

The paper Understanding Deep Learning Requires Rethinking Generalization by Zhang et al. examines the question: why do deep convolutional neural networks generalize well?

More precisely, models like AlexNet and its successors achieve around 70% test accuracy (roughly 90% counting top-5 guesses) on ImageNet, and even higher accuracy on CIFAR10, despite never seeing the test images during training.

In statistical learning theory, bounds on whether a hypothesis class can generalize to unseen data depend on its capacity, for instance on whether it can shatter the training data (informally, represent every possible assignment of labels to the training points). Intuitively, if the model can simply memorize the dataset, it need not learn anything useful for unseen data. When dealing with highly expressive model classes, regularization methods (such as weight decay) are used to penalize expressivity, in the hope of forcing the model to learn some kind of signal in the data rather than memorize it.
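As a concrete illustration of what this looks like in practice, here is a minimal sketch (PyTorch, with a hypothetical linear model and made-up hyperparameters) of adding weight decay, i.e. an l2 penalty on the weights, to a training setup:

```python
import torch

# Weight decay adds an l2 penalty lambda * ||w||^2 to the loss, shrinking the
# weights toward zero; in theory this limits the model's effective capacity
# so that it cannot simply memorize the training set.
model = torch.nn.Linear(3 * 32 * 32, 10)  # hypothetical classifier on CIFAR10-sized inputs
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=5e-4)
```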

Zhang et al. show that this can’t be what is happening in deep conv nets, because these networks easily fit completely random labels to 100% training accuracy, even with regularization! They can indeed memorize random data, so regularization is not doing its (theoretical) job of bounding expressivity. (One exception: AlexNet with both weight decay and data augmentation failed to converge on CIFAR10.) Another point against regularization being the key is that dropping it reduces test accuracy by only about 5%, so it is not the main driver of generalization ability.
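The randomization test at the heart of the paper is easy to reproduce in spirit. Below is a minimal sketch (PyTorch, with a small made-up CNN rather than the paper’s exact architectures and hyperparameters): replace the CIFAR10 labels with random ones and watch training accuracy climb toward 100%, even though the labels carry no signal at all.

```python
import torch
import torchvision
from torch.utils.data import DataLoader

# Load CIFAR10 and replace every training label with a uniformly random one.
transform = torchvision.transforms.ToTensor()
train = torchvision.datasets.CIFAR10(root="data", train=True, download=True,
                                     transform=transform)
train.targets = torch.randint(0, 10, (len(train.targets),)).tolist()
loader = DataLoader(train, batch_size=128, shuffle=True)

# A small CNN (not one of the paper's architectures).
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, 3, padding=1), torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Conv2d(64, 128, 3, padding=1), torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Flatten(),
    torch.nn.Linear(128 * 8 * 8, 10),
)
opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
loss_fn = torch.nn.CrossEntropyLoss()

# Train long enough and training accuracy approaches 100%: the network is
# simply memorizing the (random) label of each training image.
for epoch in range(100):
    correct = 0
    for x, y in loader:
        opt.zero_grad()
        out = model(x)
        loss_fn(out, y).backward()
        opt.step()
        correct += (out.argmax(1) == y).sum().item()
    print(epoch, correct / len(train))
```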

Another hypothesis is that something about the structure of convolutional neural networks makes it easy for them to learn images (perhaps the convolutional filters help pick out edges).

However, the authors show that conv nets just as easily fit images of pure Gaussian noise. Interestingly, training on real images with random labels took longer than training on random noise, which may indicate that the network first has to pull apart the natural clusters in the data before it can memorize the random labels.
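For completeness, here is a minimal sketch of the random-pixels variant (a hypothetical dataset wrapper, not the paper’s code): each image is replaced by Gaussian noise of CIFAR10’s shape, and the same training loop as above would still drive training accuracy to 100%.

```python
import torch
from torch.utils.data import Dataset

class GaussianNoiseCIFAR(Dataset):
    """CIFAR10-shaped dataset whose images are pure Gaussian noise."""

    def __init__(self, n=50_000, num_classes=10):
        self.images = torch.randn(n, 3, 32, 32)             # random pixels, no structure
        self.labels = torch.randint(0, num_classes, (n,))   # arbitrary labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        return self.images[i], self.labels[i]
```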

The authors also find that regularization seems to have benefits for optimization, and they prove a theorem showing that a two-layer ReLU network can memorize n data points in d dimensions using only 2n + d parameters. Note that the models used have about 1.5 million parameters, while ImageNet has 1.3 million training images (and CIFAR10 has 50,000), so memorization is certainly a plausible outcome.
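The construction behind that parameter count is simple enough to check numerically. Here is a minimal NumPy sketch of one such construction (my own illustration, not the paper’s code): project the data onto a single direction w (d parameters), place one ReLU bias between consecutive projections (n parameters), and solve a triangular linear system for the output weights (n more parameters), giving 2n + d in total.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 50, 10
X = rng.normal(size=(n, d))     # n points in d dimensions
y = rng.normal(size=n)          # arbitrary targets to memorize

# Project onto a random direction; the projections are distinct with probability 1.
w = rng.normal(size=d)
order = np.argsort(X @ w)
z = (X @ w)[order]              # sorted projections

# Biases interleaved with the sorted projections, so that b_j < z_j for all j.
b = np.empty(n)
b[0] = z[0] - 1.0
b[1:] = (z[:-1] + z[1:]) / 2

# The network is f(x) = sum_j c_j * relu(<w, x> - b_j).
# A[i, j] = relu(z_i - b_j) is lower triangular with a positive diagonal,
# so we can solve exactly for the output weights c.
A = np.maximum(z[:, None] - b[None, :], 0.0)
c = np.linalg.solve(A, y[order])

# Check: the network reproduces every target (up to floating-point error).
pred = np.maximum((X @ w)[:, None] - b[None, :], 0.0) @ c
print(np.abs(pred - y).max())   # ~0
```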

Thus we are still left with the question: How do we explain the impressive generalization ability of these models?

The authors also analyse stochastic gradient descent in the context of linear models and show that it finds the minimum-norm solution that fits the training data, which is promising: it suggests that SGD performs a kind of implicit l2 regularization. However, preprocessing the data leads to solutions with larger weight norms but better test accuracy, so the norm of the solution cannot be the whole story.
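This implicit-regularization effect is easy to see in a toy setting. Below is a minimal sketch (NumPy, full-batch gradient descent from a zero initialization rather than the paper’s SGD setup, with made-up dimensions): in an overparameterized linear regression, gradient descent converges to the same interpolating solution as the minimum-l2-norm (pseudoinverse) solution.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 20, 100                    # fewer samples than features: infinitely many exact fits
X = rng.normal(size=(n, d))
y = rng.normal(size=n)

# Full-batch gradient descent on the squared loss, starting from zero.
# Every gradient lies in the row space of X, so the iterates never leave it.
w = np.zeros(d)
lr = 0.01
for _ in range(10_000):
    w -= lr * X.T @ (X @ w - y) / n

# The minimum-l2-norm interpolating solution, via the pseudoinverse.
w_min = np.linalg.pinv(X) @ y

print(np.linalg.norm(X @ w - y))   # ~0: gradient descent fits the data exactly
print(np.linalg.norm(w - w_min))   # ~0: and it picked the minimum-norm solution
```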

Clearly, statistical learning theory cannot currently explain why deep (and large) models generalize well. So why do they?

One hypothesis (related to the manifold hypothesis) is that the test data simply isn’t that different from the training data, and these models are interpolating between training points based on the clusters they’ve learned. The models have enough parameters to disentangle the data manifolds of the various classes (see Chris Olah’s post https://colah.github.io/posts/2014-03-NN-Manifolds-Topology). Of course, this still leaves us with the question of how these highly expressive models manage to find the correct manifold representation in the first place.

It’s also possible that there’s some overfitting going on. Madry et al.’s recent paper From ImageNet to Image Classification: Contextualizing Progress on Benchmarks finds that roughly 20% of ImageNet images contain more than one object, making the correct annotation unclear. They note that for a third of all multi-object images, the ImageNet label does not even match what new annotators deem to be the main object in the image. Yet even in these cases, models still successfully predict the ImageNet label (instead of what humans consider the right label). This lends support to the theory that ImageNet images are very similar to one another. And while considering top-5 accuracy helps alleviate the issues with multi-object images, it leads us to overestimate accuracy on single-object images.

To conclude, it’s interesting that this paper was written in 2016, and as far as I can tell we still don’t really know why deep conv nets generalize! I’m pretty optimistic that rapid improvements in visualization methods will help us see what’s going on more clearly.
