Adaptive Computation Time (ACT) in Neural Networks [2/3]
Part 2: ACT in Residual Networks
Part 1 is here.
The next step in ACT development was taken by Michael Figurnov (then an intern at Google and a Ph.D. student at the Higher School of Economics, now a researcher at DeepMind).
“Spatially Adaptive Computation Time for Residual Networks” by
Michael Figurnov, Maxwell D. Collins, Yukun Zhu, Li Zhang, Jonathan Huang, Dmitry Vetrov, Ruslan Salakhutdinov
Paper: https://arxiv.org/abs/1612.02297
Code: https://github.com/mfigurnov/sact
Video: https://www.youtube.com/watch?v=xp5lLiA-hA8
#2. ResNets with ACT: SACT
The idea is clever: not only can we adaptively choose the network depth during inference, we can choose the depth individually for every part of the image, following the intuition that some parts of the image are "harder" and/or more important than others.
This flavor of ACT is called Spatially Adaptive Computation Time (SACT).
The main difference from the original ACT paper is that ACT is now applied to feed-forward neural networks (FFNs) instead of RNNs. More precisely, to a special type of FFN called Residual Networks (ResNets). ResNets are famous for their huge depth (another important but much less well-known type of very deep network is the Highway Network, which appeared even earlier than ResNets).
ResNets are composed of blocks, each consisting of several residual units. A residual unit has the form F(x) = x + f(x), where the first term is called a shortcut connection and the second term is a residual function. A residual function consists of three convolutional layers: a 1×1 layer that reduces the number of channels, a 3×3 layer with an equal number of input and output channels, and a 1×1 layer that restores the number of channels.
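To make the structure concrete, here is a minimal numpy sketch of such a bottleneck unit. ReLU nonlinearities and batch normalization are omitted for brevity, and all function names are mine, not from the paper:

```python
import numpy as np

def conv1x1(x, w):
    # x: (C_in, H, W), w: (C_out, C_in). A 1x1 convolution is a
    # per-pixel channel mixing, i.e. a matrix multiply over channels.
    return np.einsum('oc,chw->ohw', w, x)

def conv3x3(x, w):
    # x: (C_in, H, W), w: (C_out, C_in, 3, 3); 'same' padding, stride 1.
    C, H, W = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1)))
    out = np.zeros((w.shape[0], H, W))
    for i in range(3):
        for j in range(3):
            out += np.einsum('oc,chw->ohw', w[:, :, i, j], xp[:, i:i+H, j:j+W])
    return out

def residual_unit(x, w_reduce, w_mid, w_restore):
    """Bottleneck unit F(x) = x + f(x): 1x1 reduce -> 3x3 -> 1x1 restore.
    A structural sketch only, not the exact ResNet unit."""
    h = conv1x1(x, w_reduce)       # shrink the number of channels
    h = conv3x3(h, w_mid)          # 3x3 conv at the reduced width
    f = conv1x1(h, w_restore)      # restore the number of channels
    return x + f                   # shortcut connection + residual
```

Note that with f ≡ 0 the unit is exactly the identity, which is what makes skipping units (as ACT and the lesion studies below do) a well-defined operation.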
A branch is added to the output of each residual unit to predict a halting score, a scalar value in the range [0, 1]. As long as the cumulative halting score does not reach one (actually 1.0 minus a small epsilon), computation continues; when it does, the remaining units of the block are skipped. Other blocks are not skipped; each runs its own similar computation.
The authors define the halting distribution to be the evaluated halting scores with the last value replaced by a remainder, which ensures that the distribution sums to one. The output of the block is then redefined as a weighted sum of the outputs of the residual units, where the weight of each unit is the corresponding probability value.
Finally, a ponder cost is introduced: the number of evaluated residual units plus the remainder value. Minimizing the ponder cost increases the halting scores of the non-last residual units, making it more likely that computation stops earlier. The ponder cost is multiplied by a constant τ (so you still need to tune τ, a time-penalty hyperparameter) and added to the original loss function. ACT is applied to each block of the ResNet independently, with the ponder costs summed.
ACT has several important advantages. First, it adds very few parameters and computations to the base model. Second, it allows calculating the output of the block “on the fly” without storing all the intermediate residual unit outputs and halting scores in memory (this would not be possible if the halting distribution were a softmax of halting scores, as done in soft attention). Third, we can recover a block with any constant number of units l ≤ L by setting h_1 = · · · = h_{l−1} = 0, h_l = 1. Therefore, ACT is a strict generalization of standard ResNet.
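The halting mechanism above can be sketched as follows. This is a simplified numpy version for a single sample; in a real implementation the unit outputs would be computed lazily, and `act_block` is a hypothetical name:

```python
import numpy as np

def act_block(unit_outputs, halting_scores, eps=0.01):
    """ACT over one block. unit_outputs: list of L residual-unit outputs;
    halting_scores: L values in [0, 1]. Returns the weighted output, the
    number of units evaluated, and the ponder cost."""
    cum = 0.0
    out = np.zeros_like(unit_outputs[0])
    L = len(unit_outputs)
    for n in range(L):
        h = halting_scores[n]
        # Halt when the cumulative score would reach 1 - eps,
        # or when the last unit of the block is reached.
        if n == L - 1 or cum + h >= 1.0 - eps:
            remainder = 1.0 - cum            # the last weight is the remainder
            out += remainder * unit_outputs[n]
            ponder_cost = (n + 1) + remainder  # units evaluated + remainder
            return out, n + 1, ponder_cost
        out += h * unit_outputs[n]
        cum += h
```

Setting the scores to h_1 = · · · = h_{l−1} = 0, h_l = 1 indeed makes the block evaluate exactly l units and take the l-th output with weight one, matching the strict-generalization claim.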
In SACT, this ACT mechanism is applied to each spatial position of the block independently. Active positions are the spatial locations where the cumulative halting score is still less than one (and computation continues); inactive positions are those where computation has halted. The evaluation of a block can be stopped completely as soon as all positions become inactive.
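A rough numpy sketch of the per-position bookkeeping. In the actual SACT implementation, inactive positions simply propagate their features via the shortcut; this sketch uses the weighted-sum formulation throughout, and the names are mine:

```python
import numpy as np

def sact_block(unit_outputs, halting_maps, eps=0.01):
    """SACT over one block. unit_outputs: list of (C, H, W) arrays;
    halting_maps: list of (H, W) halting-score maps. Every spatial
    position accumulates its own halting scores and halts independently."""
    C, H, W = unit_outputs[0].shape
    cum = np.zeros((H, W))
    out = np.zeros((C, H, W))
    active = np.ones((H, W), dtype=bool)
    units_used = np.zeros((H, W))            # per-position depth
    L = len(unit_outputs)
    for n in range(L):
        h = np.where(active, halting_maps[n], 0.0)
        halt_now = active & ((cum + h >= 1.0 - eps) | (n == L - 1))
        weight = np.where(halt_now, 1.0 - cum, h)  # remainder where halting
        out += weight[None] * unit_outputs[n]
        units_used += active
        cum += np.where(halt_now, 0.0, h)    # freeze cum at halted positions
        active &= ~halt_now
        if not active.any():                 # every position inactive:
            break                            # skip the remaining units
    ponder_map = units_used + (1.0 - cum)    # units evaluated + remainder
    return out, ponder_map
```

The returned `ponder_map` is the per-position ponder cost, the quantity visualized in the paper's ponder-cost maps.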
Note that SACT is a more general model than ACT, and, consequently, than standard ResNet.
The ACT and SACT models are general and can be applied to any ResNet architecture (classification, detection, segmentation, etc).
SACT looks like an attention mechanism incorporated into ResNets: we stop processing a part of the image when its features become "good enough". The SACT ponder-cost maps can be used for interpretation and for gaining insight into how the network works.
The ACT and SACT models were applied to image classification on the ImageNet dataset; SACT achieves a better FLOPs-accuracy trade-off than ACT by directing computation to the regions of interest. Additionally, SACT improves accuracy on high-resolution images compared to the base ResNet model.
The SACT model was then applied as a feature extractor in the Faster R-CNN object detection pipeline on the COCO dataset, obtaining a significantly improved FLOPs-mAP trade-off compared to basic ResNet models.
The authors also compared SACT ponder-cost maps with the positions of human eye fixations by evaluating them as a visual saliency model on the CAT2000 dataset.
#2b. ResNets with Stochastic depth
There are some other interesting lines of research related to ResNets and ACT.
The first is the stochastic depth (“Deep Networks with Stochastic Depth”, https://arxiv.org/abs/1603.09382).
The idea of stochastic depth resembles dropout: to reduce the effective depth of the network during training, layers are randomly skipped entirely. Skip connections are introduced in the same fashion as in ResNets, but the connection pattern is randomly altered for each mini-batch: a random subset of layers is selected and their transformation functions are removed, keeping only the identity skip connections.
Stochastic depth aims to shrink the depth of a network during training while keeping it unchanged at test time. This differs from ACT/SACT, which reduce computation during inference.
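A minimal sketch of the mechanism, assuming per-layer survival probabilities p_l (the paper also suggests linearly decaying p_l with depth, not shown here). At training time each residual function is dropped with probability 1 − p_l; at test time all functions are kept, scaled by their survival probability:

```python
import numpy as np

rng = np.random.default_rng(0)

def stochastic_depth_forward(x, residual_fns, survival_probs, train=True):
    """Stochastic-depth forward pass (sketch). Training:
    x -> x + b_l * f_l(x) with b_l ~ Bernoulli(p_l), i.e. a dropped
    layer contributes only its identity shortcut. Test: expected value."""
    for f, p in zip(residual_fns, survival_probs):
        if train:
            if rng.random() < p:      # the layer survives this mini-batch
                x = x + f(x)
            # else: skip the residual function entirely, identity only
        else:
            x = x + p * f(x)          # test time: expected contribution
    return x
```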
It reduces training time substantially and significantly improves the test error on almost all datasets the authors used for evaluation.
Similar to dropout, stochastic depth can be interpreted as training an ensemble of networks, but with different depths, possibly achieving higher diversity among ensemble members than an ensemble of networks with the same depth.
#2c. ResNets lesion study
Another interesting line of research treats deep ResNets as ensembles of relatively shallow networks (“Residual Networks Behave Like Ensembles of Relatively Shallow Networks”, https://arxiv.org/abs/1605.06431).
The authors propose a novel interpretation of residual networks, showing that they can be seen as a collection of many paths of differing lengths. Moreover, residual networks seem to enable very deep networks by leveraging only the short paths during training.
Most relevant to ACT is a lesion study, which shows ensemble-like behavior: removing paths from residual networks by deleting layers, or corrupting paths by reordering layers, has only a modest and smooth impact on performance.
The authors showed that removing downsampling blocks does have a modest impact on performance (the peaks in the figure below correspond to downsampling building blocks), but no other single block removal led to a noticeable change. This result shows that, to some extent, the structure of a residual network can be changed at runtime without affecting performance.
Moreover, deleting increasing numbers of residual modules or corruption by reordering blocks increases error smoothly. This result is surprising because it suggests that residual networks can be reconfigured to some extent at runtime.
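Such a lesion can be sketched in a few lines: delete a random subset of residual functions at test time while keeping their identity shortcuts (`lesioned_forward` is my name, not the paper's):

```python
import numpy as np

rng = np.random.default_rng(0)

def lesioned_forward(x, residual_fns, n_delete=0):
    """Lesion-study sketch: delete n_delete randomly chosen residual
    functions at test time, keeping their identity shortcuts. In the
    paper, error grows smoothly as n_delete increases."""
    keep = np.ones(len(residual_fns), dtype=bool)
    keep[rng.choice(len(residual_fns), size=n_delete, replace=False)] = False
    for f, kept in zip(residual_fns, keep):
        if kept:
            x = x + f(x)  # surviving unit: shortcut + residual
        # deleted unit: identity only, x passes through unchanged
    return x
```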
So the situation resembles the case of Repeat-RNN (see the previous post). Maybe we do not need a complicated adaptive mechanism (well, not really complicated, but still not as simple as the “repeat n times” or “remove n blocks” baselines) to get the benefits of reduced computation or better performance on the task.
Anyway, a simple baseline for SACT, say a LesionResNet, would be a good option.
… to be continued …
Part 3 is here.