arXiv:2101.07592v1 [cs.NE] 19 Jan 2021


Synaptic metaplasticity in binarized neural networks

Axel Laborieux 1, Maxence Ernoult 2, Tifenn Hirtzlin 1, Damien Querlioz 1

Summary

Unlike the brain, artificial neural networks, including state-of-the-art deep neural networks for computer vision, are subject to "catastrophic forgetting" [1]: they rapidly forget the previous task when trained on a new one. Neuroscience suggests that biological synapses avoid this issue through the process of synaptic consolidation and metaplasticity: the plasticity itself changes upon repeated synaptic events [2, 3]. In this work, we show that this concept of metaplasticity can be transferred to a particular type of deep neural network, binarized neural networks (BNNs) [4], to reduce catastrophic forgetting. BNNs were initially developed to allow low-energy-consumption implementations of neural networks. In these networks, synaptic weights and activations are constrained to {−1, +1}, and training is performed using hidden real-valued weights, which are discarded at test time. Our first contribution is to draw a parallel between the metaplastic states of [2] and the hidden weights inherent to BNNs. Based on this insight, we propose a simple synaptic consolidation strategy for the hidden weights. We justify it using a tractable binary optimization problem, and we show that our strategy performs almost as well as mainstream machine learning approaches to mitigating catastrophic forgetting, which minimize task-specific loss functions [5], on the task of learning pixel-permuted versions of the MNIST digit dataset sequentially. Moreover, unlike these techniques, our approach does not require task boundaries, thereby allowing us to explore a new setting where the network learns from a stream of data. When trained on data streams from Fashion MNIST or CIFAR-10, our metaplastic BNN outperforms a standard BNN and closely matches the accuracy of a network trained on the whole dataset. These results suggest that BNNs are more than a low-precision version of full-precision networks and highlight the benefits of the synergy between neuroscience and deep learning [6].

1 Centre de Nanosciences et de Nanotechnologies, Université Paris-Saclay

2 Mila, Université de Montréal

Hidden weights as metaplastic states

The problem of forgetting in artificial neural networks results from a dilemma: synapses need to be updated in order to learn new tasks, but also to be protected against further changes in order to preserve knowledge. In a foundational neuroscience work, Fusi et al. show that in small Hopfield networks, catastrophic forgetting can be addressed by introducing a hidden metaplastic state that controls the plasticity of the synapse [2]. Synapses can assume only +1 or −1 weights, with the metaplastic state modulating the difficulty for the synapse to switch. Therefore, in this scheme, repeated potentiation of a positive-weight synapse will only affect its metaplastic state and not its actual weight. Here, we remark that the way BNNs are trained is remarkably similar to this situation. In BNNs, synapses can also only assume +1 or −1 weights, and they feature a hidden real weight W^h, which is updated by backpropagation. The synaptic weight changes between +1 and −1 only when W^h changes sign, suggesting that W^h can be seen as a metaplastic state modulating the difficulty for the actual weight to change sign. However, standard BNNs are as prone to catastrophic forgetting as conventional neural networks. In [2], Fusi et al. showed that the metaplastic state should affect subsequent plasticity exponentially to mitigate forgetting, whereas W^h affects weight changes only linearly in BNNs. Therefore, in this work, we propose to adapt the learning process of BNNs so that the larger the magnitude of a hidden weight W^h, the more difficult it is to switch its associated binarized weight W^b = sign(W^h). Denoting U_W the update provided by the learning algorithm, we implement:

W^h ← W^h − η · U_W · f_meta(m, W^h)   if U_W · W^h > 0,
W^h ← W^h − η · U_W                    otherwise.

As in the metaplasticity model of [2], where synaptic plasticity decreases exponentially with the metaplastic state, we choose f_meta(m, W^h) = tanh′(m · W^h) = 1 − tanh²(m · W^h) to produce an exponential decay for large metaplastic states W^h, where m is a hyperparameter that controls the strength of the consolidation.
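The consolidation rule above can be sketched in a few lines of NumPy. The function below is our own illustrative implementation, not the authors' released code; the names metaplastic_update, u_w, lr, and binarize are ours, and we use the identity tanh′(x) = 1 − tanh²(x).

```python
import numpy as np

def metaplastic_update(w_h, u_w, lr, m):
    """One metaplastic step on the hidden weights W^h (illustrative sketch).

    u_w is the raw update U_W provided by the learning algorithm. When
    u_w * w_h > 0, the step -lr * u_w pushes w_h toward zero, i.e. toward a
    sign flip of the binarized weight, so it is damped by
    f_meta = tanh'(m * w_h) = 1 - tanh(m * w_h)**2, which is exponentially
    small for large |w_h|. Updates moving w_h away from zero pass unchanged.
    """
    f_meta = 1.0 - np.tanh(m * w_h) ** 2
    damp = np.where(u_w * w_h > 0, f_meta, 1.0)
    return w_h - lr * u_w * damp

def binarize(w_h):
    """W^b = sign(W^h): the weights actually used at inference time."""
    return np.where(w_h >= 0, 1.0, -1.0)
```

A strongly consolidated hidden weight (large |w_h|) is then almost insensitive to updates that would flip its sign, while a freshly switched weight (small |w_h|) remains fully plastic.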

Toy problem study

To validate the interpretation of hidden weights as metaplastic states, we first focus on a highly simplified binary optimization task that we solve in a way analogous to the BNN training process. We want the binarized weights W^b to minimize a quadratic loss L, as depicted by the color map in Fig. 1(a) in two dimensions, with W* as the global optimum. We assume that W^h is updated by loss gradients computed with binarized weights, similarly to BNNs:

W^h_{t+1} = W^h_t − η · ∂L/∂W (W^b_t).

We can show that if the infinity norm of W* is less than one, some hidden weights diverge as t → ∞. This is because W^h is updated by loss gradients computed at the corners of the square, in contrast with conventional optimization. More importantly, if we define the importance of a binarized weight as the increase of the loss ΔL when the weight is switched to the opposite value, we can prove that the speed of divergence of the hidden weight is directly linked to the importance of the binarized weight. For instance, in Fig. 1(a), W^b_x is more important than W^b_y for optimization. Finally, we plot ΔL versus |W^h| in Fig. 1(b), (c) for higher dimensions and for a BNN trained on MNIST, and observe that the correspondence between important weights and hidden weight divergence still holds, justifying that consolidating synapses with diverging hidden weights, as our proposal does, is a promising route for mitigating catastrophic forgetting.
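The toy dynamics and the importance measure ΔL can be made concrete with a small NumPy sketch. The quadratic below is our own illustrative instance, not the exact problem of Fig. 1: we pick a diagonal curvature h and an optimum w_star with infinity norm below one, and the large curvature on the x coordinate makes W^b_x the "important" weight.

```python
import numpy as np

# Illustrative 2-D quadratic: L(w) = 0.5 * sum(h * (w - w_star)**2),
# with ||w_star||_inf < 1 so the optimum lies inside the unit square.
h = np.array([10.0, 1.0])        # large curvature on x => W^b_x is "important"
w_star = np.array([0.5, 0.2])

def loss(w):
    return 0.5 * np.sum(h * (w - w_star) ** 2)

def grad(w):
    return h * (w - w_star)

def importance(w_b):
    """Delta-L: the loss increase when each binarized weight is switched."""
    base = loss(w_b)
    deltas = np.empty_like(w_b)
    for i in range(w_b.size):
        flipped = w_b.copy()
        flipped[i] = -flipped[i]
        deltas[i] = loss(flipped) - base
    return deltas

# BNN-style dynamics: hidden weights are updated by gradients evaluated
# at the binarized weights (the corners of the square), not at w_h itself.
w_h = np.array([0.13, 0.07])
eta = 0.01
for _ in range(1000):
    w_h = w_h - eta * grad(np.sign(w_h))
```

At the best corner W^b = (+1, +1), flipping the x weight costs ΔL = 10.0 while flipping the y weight costs only ΔL = 0.4, matching the intuition that important weights are expensive to switch. This hand-checkable diagonal case mainly illustrates the update rule and the ΔL measure; the divergence analysis in the text concerns the general quadratic problem.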

Experimental results

Continual learning benchmark. We now apply our consolidation strategy to the permuted MNIST benchmark, on perceptrons with two hidden layers of varying numbers of neurons. We show in Fig. 1(d), (e) the average test accuracy as a function of the number of tasks learned so far. We observe that our technique indeed allows sequential task learning and performs almost as well as Elastic Weight Consolidation (EWC) [5] adapted to BNNs (the importance factor is computed with the binarized weights) over a wide range of hidden layer sizes when learning up to 20 tasks. We choose m = 1.35 and λ_EWC = 5 · 10⁻³ for EWC.
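The permuted MNIST protocol itself is simple to reproduce: each task applies one fixed random pixel permutation to every flattened image, shared between training and test sets. A minimal sketch (function and argument names are ours):

```python
import numpy as np

def make_permuted_tasks(images, n_tasks, seed=0):
    """Generate permuted-MNIST-style tasks from flattened images (N, D).

    Task 0 keeps the original pixel order; each further task applies its
    own fixed random permutation of the D pixel positions.
    """
    rng = np.random.default_rng(seed)
    n_pixels = images.shape[1]
    perms = [np.arange(n_pixels)] + [rng.permutation(n_pixels)
                                     for _ in range(n_tasks - 1)]
    return [images[:, p] for p in perms]
```

Because only the input statistics change between tasks while labels stay the same, the benchmark isolates forgetting of the input-to-hidden mapping.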

Learning from a stream of data. By construction, our approach does not require updating an importance factor between two consecutive tasks. Building on this asset, we explore a new setting, which we call stream learning, where a task is learned by training on sub-parts of the full dataset sequentially, with all classes evenly distributed in each subset. We choose Fashion MNIST (FMNIST) and CIFAR-10 for our experiments. The architectures used are a perceptron with two hidden layers of 1,024 units for FMNIST and a VGG-16 convolutional architecture for CIFAR-10. We plot in Fig. 1(f), (g) the test accuracy reached by those networks when metaplasticity is used (red) or not (blue). We see that our approach comes closer to the accuracy reached when the full dataset is learned at once (straight lines) than the non-metaplastic counterpart. Overall, these results highlight the benefit of metaplasticity models from neuroscience when applied to machine learning.
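The stream learning data split — sequential sub-parts of the dataset, with all classes evenly distributed in each subset — can be sketched as follows. This is a hypothetical helper written for illustration, not the authors' code:

```python
import numpy as np

def make_stream(labels, n_subsets, seed=0):
    """Split dataset indices into n_subsets, each holding an (almost) equal
    share of every class, to be presented sequentially as a data stream."""
    rng = np.random.default_rng(seed)
    subsets = [[] for _ in range(n_subsets)]
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))
        # np.array_split distributes the class indices as evenly as possible
        for k, chunk in enumerate(np.array_split(idx, n_subsets)):
            subsets[k].extend(chunk.tolist())
    return [np.array(sorted(s)) for s in subsets]
```

Each subset is then learned to completion before moving to the next, with no importance-factor update at the boundaries.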

Acknowledgement

This work was supported by European Research Council Starting Grant NANOINFER (reference: 715872).

References

[1] French, R. M., Trends Cogn. Sci. (1999).
[2] Fusi et al., Neuron (2005).
[3] Abraham, W. C., Nat. Rev. Neurosci. (2008).
[4] Courbariaux et al., arXiv:1602.02830 (2016).
[5] Kirkpatrick et al., PNAS (2017).
[6] Richards et al., Nat. Neurosci. (2019).
[7] Zenke et al., PMLR (2017).



Figure 1: (a) Quadratic binarized optimization in two dimensions. (b-c) Average loss increase when switching W^b versus normalized hidden weight, for the binary quadratic problem (b) and for a BNN on MNIST (c). (d-e) Permuted MNIST benchmark with our method (d) and EWC (e), where the x axis labels the number of learned tasks, with one color per network size. Fashion MNIST (f) and CIFAR-10 (g) test accuracy in the stream learning setting allowed by our approach (red) compared to a standard BNN (blue). Horizontal rules denote the full-dataset training baseline.
