The Low-Rank Simplicity Bias in Deep Networks


Minyoung Huh1
Hossein Mobahi2
Richard Zhang3
Brian Cheung1 4
Pulkit Agrawal1
Phillip Isola1

1MIT CSAIL
2Google Research
3Adobe Research
4MIT BCS



Abstract

Modern deep neural networks are highly over-parameterized compared to the data on which they are trained, yet they often generalize remarkably well. A flurry of recent work has asked: why do deep networks not overfit to their training data? In this work, we make a series of empirical observations that investigate and extend the hypothesis that deeper networks are inductively biased to find solutions with lower effective-rank embeddings. We conjecture that this bias exists because the volume of functions that map to low effective-rank embeddings increases with depth. We show empirically that our claim holds for finite-width linear and non-linear models under practical learning paradigms, and that on natural data these are often the solutions that generalize well. We then show that the simplicity bias exists both at initialization and after training, and is robust to the choice of hyper-parameters and learning methods. We further demonstrate how linear over-parameterization of deep non-linear models can be used to induce a low-rank bias, improving generalization performance on CIFAR and ImageNet without changing the modeling capacity.



Results


Observation 1
Randomly initialized deep networks are biased towards embeddings whose Gram matrices have low effective rank.
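
A minimal sketch (our own illustration, not the authors' code) of how this can be checked: embed random data with randomly initialized ReLU MLPs of increasing depth and measure the effective rank of the resulting Gram matrix, using the same effective-rank definition as the PyTorch snippet further down this page.

      import torch
      import torch.nn as nn

      def effective_rank(w):
          # entropy of the normalized singular value distribution
          s = w.detach().cpu().svd(compute_uv=False)[1]
          s_hat = s / s.sum()
          return -(s_hat * (s_hat + 1e-12).log()).sum()

      def random_mlp(dim, depth):
          # randomly initialized MLP with `depth` hidden ReLU layers
          layers = []
          for _ in range(depth):
              layers += [nn.Linear(dim, dim), nn.ReLU()]
          return nn.Sequential(*layers)

      torch.manual_seed(0)
      x = torch.randn(256, 64)  # random inputs
      for depth in [1, 2, 4, 8, 16]:
          with torch.no_grad():
              z = random_mlp(64, depth)(x)  # embeddings at initialization
          gram = z @ z.t()                  # Gram matrix of the embeddings
          print(depth, effective_rank(gram).item())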



Observation 2
Deep neural networks trained with gradient descent also learn to map data to simple embeddings with low effective rank.



Observation 3
Deep neural networks trained with common and natural choices of optimizers also exhibit the low-rank embedding bias.



Observation 4
Deep neural networks are biased towards learning low effective-rank embeddings, and this bias is insensitive to initialization.
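
To complement Observations 2-4, here is a small self-contained sketch (ours, with synthetic data) that tracks the effective rank of a network's penultimate embeddings while training with plain SGD; the optimizer (e.g. Adam) or the initialization can be swapped to probe the other observations. It only illustrates how the quantity is measured, not the exact experimental setup of the paper.

      import torch
      import torch.nn as nn

      def effective_rank(w):
          # entropy of the normalized singular value distribution
          s = w.detach().cpu().svd(compute_uv=False)[1]
          s_hat = s / s.sum()
          return -(s_hat * (s_hat + 1e-12).log()).sum()

      torch.manual_seed(0)
      x = torch.randn(512, 32)                   # synthetic inputs
      y = torch.randint(0, 10, (512,))           # synthetic labels

      backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU(),
                               nn.Linear(64, 64), nn.ReLU())
      head = nn.Linear(64, 10)
      params = list(backbone.parameters()) + list(head.parameters())
      opt = torch.optim.SGD(params, lr=0.1)      # swap for Adam, etc.

      for step in range(1001):
          z = backbone(x)                        # embeddings
          loss = nn.functional.cross_entropy(head(z), y)
          opt.zero_grad()
          loss.backward()
          opt.step()
          if step % 200 == 0:
              print(step, loss.item(), effective_rank(z).item())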



Try our PyTorch code

Install our code [github]

      $ git clone https://github.com/minyoungg/overparam
      $ cd overparam
      $ pip install .
      
Integrate into your existing PyTorch code base

      from overparam import OverparamLinear, OverparamConv2d

      # over-parameterized nn.Linear layer
      layer = OverparamLinear(32, 32, depth=4)

      # over-parameterized nn.Conv2d layer (3 layers with 3x3, 3x3, 1x1 kernels)
      layer = OverparamConv2d(32, 64, kernel_sizes=(3, 3, 1), stride=1, padding=1)
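
A quick way to sanity-check the layers above (this usage example is ours; it assumes the Overparam layers are drop-in replacements that accept the same inputs as nn.Linear and nn.Conv2d):

      import torch
      from overparam import OverparamLinear, OverparamConv2d

      fc = OverparamLinear(32, 32, depth=4)
      conv = OverparamConv2d(32, 64, kernel_sizes=(3, 3, 1), stride=1, padding=1)

      # assumption: both layers expose the usual nn.Module forward interface
      print(fc(torch.randn(8, 32)).shape)            # expect feature dim 32
      print(conv(torch.randn(8, 32, 16, 16)).shape)  # expect 64 output channels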
      
Automatic linear over-parameterization of existing models

      import torchvision.models as models
      from overparam.utils import overparameterize

      model = models.alexnet()
      model = overparameterize(model, depth=2)
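
As a follow-up (our own sanity check, assuming the expanded model keeps AlexNet's usual (N, 3, 224, 224) to (N, 1000) interface), you can verify that the over-parameterized model still runs and inspect its parameter count:

      import torch
      import torchvision.models as models
      from overparam.utils import overparameterize

      model = overparameterize(models.alexnet(), depth=2)

      x = torch.randn(1, 3, 224, 224)
      print(model(x).shape)                              # expect torch.Size([1, 1000])
      print(sum(p.numel() for p in model.parameters()))  # parameter count of the expanded model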
      
To compute the effective rank of a 2D matrix using PyTorch

      def effective_rank(w):
          # Shannon entropy of the normalized singular value distribution
          s = w.detach().cpu().svd(compute_uv=False)[1]
          s_hat = s / s.sum()
          # small epsilon avoids log(0) for rank-deficient matrices
          return -(s_hat * (s_hat + 1e-12).log()).sum()
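
For example (our own illustration, using the function above), the low-rank bias of deep linear products can be seen by comparing a single random matrix to a product of several random matrices:

      import torch

      torch.manual_seed(0)
      w = torch.randn(64, 64)
      print(effective_rank(w))        # single random matrix: relatively high effective rank

      # collapsing a deep linear net (a product of random matrices) into one
      # matrix typically yields a noticeably lower effective rank
      deep_w = torch.linalg.multi_dot([torch.randn(64, 64) for _ in range(8)])
      print(effective_rank(deep_w))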
      


Acknowledgements

We would like to thank Anurag Ajay, Lucy Chai, Tongzhou Wang, and Yen-Chen Lin for reading over the manuscript and Jeffrey Pennington and Alexei A. Efros for fruitful discussions. Minyoung Huh is funded by DARPA Machine Common Sense and MIT STL. Brian Cheung is funded by an MIT BCS Fellowship.

This research was also partly sponsored by the United States Air Force Research Laboratory and the United States Air Force Artificial Intelligence Accelerator and was accomplished under Cooperative Agreement Number FA8750-19-2-1000. The views and conclusions contained in this document are those of the authors and should not be interpreted as representing the official policies, either expressed or implied, of the United States Air Force or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for Government purposes, notwithstanding any copyright notation herein.

Website template edited from Colorful Colorization.