We provide a new perspective on momentum learning, with implications for how and when it is beneficial and how it interacts with nonstationarity in the task environment.
We provide a normative grounding for multiscale learning in terms of Bayesian inference over 1/f noise. Our starting point is a generative model of 1/f noise as a sum of diffusion processes at different timescales.
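The generative model can be illustrated with a small simulation. The sketch below (our illustrative construction, not the paper's code; timescales, variance scaling, and the fitting band are our choices) sums discrete-time diffusion processes with log-spaced timescales and equal stationary variance, then checks that the averaged power spectrum has a log-log slope near -1:

```python
import numpy as np

# Illustrative sketch: 1/f-like noise as a sum of discrete-time diffusion
# (AR(1)/OU-like) processes with log-spaced timescales. Each process has
# unit stationary variance; the timescales are our choice, not the paper's.
def sample_one_over_f(T, timescales, rng):
    x = np.zeros(T)
    for tau in timescales:
        a = 1.0 - 1.0 / tau              # decay toward zero at rate 1/tau
        sigma = np.sqrt(1.0 - a ** 2)    # gives unit stationary variance
        z = np.zeros(T)
        eps = rng.normal(size=T)
        for t in range(1, T):
            z[t] = a * z[t - 1] + sigma * eps[t]
        x += z
    return x

rng = np.random.default_rng(0)
taus = 2.0 ** np.arange(1, 9)            # 8 timescales: 2, 4, ..., 256

# Average periodograms over independent runs, then fit the log-log slope
# in a band between the corner frequencies, where the power law holds.
T = 2 ** 14
psd = np.zeros(T // 2 + 1)
n_runs = 20
for _ in range(n_runs):
    x = sample_one_over_f(T, taus, rng)
    psd += np.abs(np.fft.rfft(x)) ** 2 / T
psd /= n_runs
freqs = np.fft.rfftfreq(T)
band = (freqs > 0.004) & (freqs < 0.02)
slope = np.polyfit(np.log(freqs[band]), np.log(psd[band]), 1)[0]
```

With these settings the fitted slope comes out close to -1, consistent with a 1/f spectrum over the band spanned by the chosen timescales.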
We test our methods in online prediction and classification tasks with nonstationary distributions. In online learning, nonstationarity often manifests as a gap between generalization performance on future data and on held-out data drawn from within the training interval.
Natural environments have temporal structure at multiple timescales. This property is reflected in biological learning and memory but typically not in machine learning systems. We advance a multiscale learning method in which each weight in a neural network is decomposed as a sum of subweights with different learning and decay rates. Thus knowledge becomes distributed across different timescales, enabling rapid adaptation to task changes while avoiding catastrophic interference.
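A minimal sketch of this decomposition follows. The class name, the ordering of decay versus gradient step, and the hyperparameter values are our illustrative choices, not taken from the paper:

```python
import numpy as np

# Sketch of a multiscale weight: the effective weight is the sum of K
# subweights, each with its own learning rate and decay rate. Fast scales
# adapt quickly but decay toward zero; slow scales persist.
class MultiscaleWeight:
    def __init__(self, shape, learning_rates, decays):
        self.lrs = np.asarray(learning_rates)    # one per timescale
        self.decays = np.asarray(decays)         # one per timescale
        self.sub = np.zeros((self.lrs.size,) + shape)

    @property
    def value(self):
        return self.sub.sum(axis=0)              # effective weight

    def update(self, grad):
        # each subweight decays toward zero and takes its own gradient step
        # (one plausible update ordering; the paper's may differ)
        expand = (-1,) + (1,) * grad.ndim
        self.sub = (self.decays.reshape(expand) * self.sub
                    - self.lrs.reshape(expand) * grad)

w = MultiscaleWeight((3,), learning_rates=[0.1, 0.01], decays=[0.9, 0.999])
w.update(np.ones(3))
```

After one unit-gradient step, each effective weight is the sum of a large fast-scale step and a small slow-scale step; under repeated updates the fast subweight's contribution decays while the slow subweight accumulates.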
First, we prove that previous models that learn at multiple timescales, but with complex coupling between timescales, are equivalent to multiscale learning via a reparameterization that eliminates this coupling. The same analysis yields a new characterization of momentum learning as a fast weight with a negative learning rate.
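The momentum characterization can be checked numerically. The sketch below assumes the heavy-ball form (v ← μv + g, w ← w − ηv); the particular constants in the reparameterization are our own derivation, not quoted from the paper:

```python
import numpy as np

# Numerical check: heavy-ball momentum equals a slow weight plus a fast
# weight with a NEGATIVE learning rate and decay mu.
eta, mu = 0.1, 0.9
rng = np.random.default_rng(1)
grads = rng.normal(size=100)

# standard heavy-ball momentum
w, v = 0.0, 0.0
traj_momentum = []
for g in grads:
    v = mu * v + g
    w -= eta * v
    traj_momentum.append(w)

# equivalent two-subweight form:
#   slow weight: learning rate eta/(1-mu), no decay
#   fast weight: decay mu, learning rate -eta*mu/(1-mu)  (negative)
eta_slow = eta / (1 - mu)
eta_fast = -eta * mu / (1 - mu)
s, f = 0.0, 0.0
traj_subweights = []
for g in grads:
    s -= eta_slow * g
    f = mu * f - eta_fast * g
    traj_subweights.append(s + f)
```

The two trajectories agree exactly at every step: because the fast weight's learning rate is negative, it transiently pushes opposite to the gradient and then decays, which is precisely the overshoot-and-relax behavior of momentum.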
Second, we derive a model of Bayesian inference over 1/f noise, a temporal pattern common to many online learning domains that involves long-range (power-law) autocorrelations. The generative side of the model expresses 1/f noise as a sum of diffusion processes at different timescales, and the inferential side tracks these latent processes with a Kalman filter. We then derive a variational approximation to the Bayesian model and show that it is an extension of the multiscale learner. The result is an optimizer that can be used as a drop-in replacement in an arbitrary neural network architecture.
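The inferential side can be sketched as a standard Kalman filter whose latent state is the vector of diffusion processes and whose scalar observation is their sum plus Gaussian noise. All constants below are illustrative, not taken from the paper:

```python
import numpy as np

# Kalman filter over latent diffusion processes at multiple timescales.
rng = np.random.default_rng(2)
taus = np.array([2.0, 8.0, 32.0])
A = np.diag(1.0 - 1.0 / taus)          # per-timescale decay
Q = np.diag(1.0 - np.diag(A) ** 2)     # process noise -> unit stationary variance
H = np.ones((1, taus.size))            # observation = sum of latent processes
R = np.array([[0.25]])                 # observation noise variance

# simulate latent diffusions and noisy observations of their sum
T = 2000
z = np.zeros(taus.size)
ys = []
for _ in range(T):
    z = A @ z + rng.normal(size=taus.size) * np.sqrt(np.diag(Q))
    ys.append((H @ z)[0] + rng.normal(scale=0.5))

# standard Kalman filter recursion
m = np.zeros(taus.size)
P = np.eye(taus.size)
pred_errs = []
for y in ys:
    # predict step
    m = A @ m
    P = A @ P @ A.T + Q
    pred_errs.append((y - (H @ m)[0]) ** 2)
    # update step
    S = H @ P @ H.T + R
    K = P @ H.T @ np.linalg.inv(S)
    m = m + K.ravel() * (y - (H @ m)[0])
    P = (np.eye(taus.size) - K @ H) @ P
```

Because the slow latent components are highly predictable, the filter's one-step prediction error is well below the raw variance of the observations; the variational approximation replaces the full covariance P with a cheaper factored form.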
Third, we evaluate the ability of these methods to handle nonstationarity by testing them in online prediction tasks characterized by 1/f noise in the latent parameters. We find that the Bayesian model significantly outperforms online stochastic gradient descent and two batch heuristics that rely preferentially or exclusively on more recent data. Moreover, the variational approximation performs nearly as well as the full Bayesian model, with memory requirements that are linear in the size of the network.
Our analytic and simulation results demonstrate how online learning performance in nonstationary environments can be improved by incorporating a model of temporal structure. The Bayesian 1/f model amounts to distributing knowledge across multiple timescales, and the variational EKF (extended Kalman filter) enables approximate implementation in a neural network using subweights with different learning and decay rates. The variational EKF extends the multiscale optimizer, which is closely related to previous models in both neuroscience and machine learning, and in some cases is equivalent to them despite being simpler in having no coupling between timescales.
We have implemented the variational EKF optimizer in JAX in a format compatible with Optax. In the MNIST simulations of Section 5.3, we find our optimizer code (with 8 timescales) is 1.6% faster than Optax’s off-the-shelf SGD, measured as compute time per example. Note also that the multiplexing of subweights is not expensive relative to current optimizers (e.g., Adam; Kingma & Ba, 2015), which also store multiple variables for each weight.
In sum, the multiscale optimizer and variational EKF enjoy a combination of normative, heuristic, and biological justification, good performance, and computational efficiency. Our ongoing work aims to extend the theory in several ways. Chang et al. (2022) compare the present variational method to the fully decoupled EKF of Puskorius & Feldkamp (2003). Another possible variational method is to assume a block-diagonal covariance matrix that maintains covariance information only between subweights (timescales) within each weight, so that computational complexity still scales linearly with network size. Finally, the present method is not limited to 1/f noise but generalizes to other power laws (e.g., 1/f^β) by appropriate choice of the timescales τ_i and noise variances σ_i in the generative model (see Appendix B). If, for example, data or theory is available bearing on the power spectrum of the dynamics in a given domain, the optimizer could be tuned accordingly.
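As a sketch of why the construction generalizes (our own continuum calculation, not necessarily the parameterization used in Appendix B): if each diffusion process has relaxation rate λ = 1/τ and a Lorentzian spectrum, and rates are log-spaced with stationary variance scaled as τ^{β−1}, then substituting u = λ/ω in the superposition gives

```latex
S(\omega) \;\propto\; \int_{\lambda_{\min}}^{\lambda_{\max}}
  \frac{d\lambda}{\lambda}\,
  \underbrace{\lambda^{1-\beta}}_{\text{variance }\propto\,\tau^{\beta-1}}
  \cdot \frac{\lambda}{\lambda^{2}+\omega^{2}}
  \;=\; \omega^{-\beta}
  \int_{\lambda_{\min}/\omega}^{\lambda_{\max}/\omega}
  \frac{u^{1-\beta}}{1+u^{2}}\,du
  \;\approx\; C_{\beta}\,\omega^{-\beta},
  \qquad \lambda_{\min} \ll \omega \ll \lambda_{\max},
```

where the integral C_β converges for 0 < β < 2; β = 1 recovers the equal-stationary-variance construction for 1/f noise.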