Theory of gating in recurrent neural networks
Kamesh Krishnamurthy
Princeton University
Mon, Nov. 22nd 2021, 16:00-17:00
Salle Claude Itzykson, Bât. 774, Orme des Merisiers
Recurrent neural networks (RNNs) are powerful dynamical models, widely used in machine learning (ML) for processing sequential data and in neuroscience for understanding the emergent properties of networks of real neurons. Prior theoretical work on RNNs has focused on networks with additive interactions. However, gating (i.e., multiplicative) interactions are ubiquitous in real neurons, and gating is also the central feature of the best-performing RNNs in ML. Here, we study the consequences of gating for the dynamical behavior of RNNs. We show that gating offers flexible control of two salient features of the collective dynamics: (i) timescales and (ii) dimensionality. The gate controlling timescales leads to a novel, marginally stable state in which the network functions as a flexible integrator. Unlike previous approaches, gating permits this important function without parameter fine-tuning or special symmetries. Gates also provide a flexible, context-dependent mechanism for resetting the memory trace, thus complementing the memory function. The gate modulating dimensionality can induce a novel, discontinuous chaotic transition, in which inputs push a stable system into strong chaotic activity, in contrast to the typically stabilizing effect of inputs. At this transition, unlike in additive RNNs, the proliferation of critical points (topological complexity) is decoupled from the appearance of chaotic dynamics (dynamical complexity). The rich dynamics are summarized in phase diagrams, which provide ML practitioners with a map for principled choices of parameter initialization. Finally, we develop a field theory for the gradients that arise in training by combining the adjoint formalism from control theory with dynamical mean-field theory. This paves the way for using powerful field-theoretic techniques to study training and gradients in large RNNs.
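As a rough illustration of the difference between additive and gating (multiplicative) interactions, the following is a minimal NumPy sketch of a randomly connected RNN in which an update gate multiplies each unit's rate of change. The equations, the names `simulate_gated_rnn`, `alpha_z` and `bias_z`, and all parameter values are assumptions chosen for illustration, not the model analyzed in the talk. Because the gate rescales the whole right-hand side of dh/dt, it acts as a per-unit time constant: driving the gate toward zero slows the dynamics so much that the network essentially holds its state, which hints at how a timescale gate can support memory.

```python
import numpy as np


def sigmoid(x):
    """Logistic function, used here as a soft gate with values in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))


def simulate_gated_rnn(n=200, steps=2000, dt=0.05, g=2.0,
                       alpha_z=1.0, bias_z=0.0, seed=0):
    """Forward-Euler integration of a hypothetical gated RNN.

    Assumed dynamics (illustrative only):
        dh/dt = sigmoid(alpha_z * z + bias_z) * (-h + J_h @ tanh(h))
        dz/dt = -z + J_z @ tanh(h)
    The update gate multiplies the entire right-hand side for h,
    so it rescales each unit's effective time constant.
    """
    rng = np.random.default_rng(seed)
    # Gaussian couplings with variance g^2 / n (standard random-RNN scaling).
    J_h = rng.normal(0.0, g / np.sqrt(n), size=(n, n))
    J_z = rng.normal(0.0, 1.0 / np.sqrt(n), size=(n, n))
    h = rng.normal(0.0, 1.0, size=n)
    z = np.zeros(n)
    traj = np.empty((steps + 1, n))
    traj[0] = h
    for t in range(steps):
        phi = np.tanh(h)                        # firing rates from the old state
        gate = sigmoid(alpha_z * z + bias_z)    # multiplicative (gating) term
        h = h + dt * gate * (-h + J_h @ phi)    # gated update of the main units
        z = z + dt * (-z + J_z @ phi)           # gate variables driven by the network
        traj[t + 1] = h
    return traj


# Same seed, hence same couplings and initial state; only the gate bias differs.
open_gate = simulate_gated_rnn(bias_z=0.0)
closed_gate = simulate_gated_rnn(bias_z=-20.0)
drift_open = np.abs(open_gate[-1] - open_gate[0]).max()
drift_closed = np.abs(closed_gate[-1] - closed_gate[0]).max()
print(f"max drift, open gate:   {drift_open:.3g}")
print(f"max drift, closed gate: {drift_closed:.3g}")
```

With the gate mostly open, the state wanders far from its initial condition over the simulated window; with the gate nearly closed, it barely moves. The choice of a sigmoid gate and a tanh nonlinearity mirrors common practice in gated RNNs, but other smooth, bounded choices would illustrate the same point.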