This method is designed to significantly speed up the previously proposed Forgetting Transformer (FoX) without any performance degradation. FoX adds a forget gate to the Transformer, and the resulting attention mechanism can also be seen as a data-dependent and learnable version of ALiBi, as follows:
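In a slightly simplified form (scaling factors omitted, and writing $f_l \in (0, 1)$ for the data-dependent forget gate at position $l$):

$$
o_i = \frac{\sum_{j=1}^{i} \exp\left(q_i^\top k_j + D_{ij}\right) v_j}{\sum_{j=1}^{i} \exp\left(q_i^\top k_j + D_{ij}\right)},
\qquad
D_{ij} = \sum_{l=j+1}^{i} \log f_l .
$$

Since every $\log f_l \le 0$, the bias $D_{ij}$ becomes more negative the further $j$ lies in the past, and this is exactly what ACP exploits below.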
The core idea of Adaptive Computation Pruning (ACP) is simple: we don’t need to waste compute on things that we forget. Concretely, if $D_{ij}$ is far below zero (e.g., $-1000$), then the term $\exp(q_i^\top k_j + D_{ij})$ is effectively zero after normalization, so any computation involving this term can be pruned. Thanks to the special structure of the matrix $D$, this can be done by identifying a pruning boundary across the FlashAttention computation grid and only performing the computation to the right of that boundary.
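To make the block-skipping idea concrete, here is a minimal PyTorch sketch of how such a pruning boundary could be computed from the cumulative log forget gates. The function name `pruning_boundary`, the arguments `log_fgate`, `block_q`, `block_k`, and the threshold `tau` are all illustrative; the actual implementation fuses this logic into the FlashAttention (Triton) kernel rather than looping in Python.

```python
import torch


def pruning_boundary(log_fgate: torch.Tensor, block_q: int, block_k: int,
                     tau: float) -> torch.Tensor:
    """For each query block, return the index of the first key block that must
    still be computed; all key blocks to its left can be skipped.

    log_fgate: shape (L,), per-position log forget gates (all <= 0).
    tau: pruning threshold on D_ij, a strongly negative value chosen so that
         the pruned attention weights are negligible.
    Illustrative sketch only, not the fused kernel used in practice.
    """
    L = log_fgate.numel()
    # c[i] = sum_{l <= i} log f_l, so D_ij = c[i] - c[j] for j <= i.
    c = torch.cumsum(log_fgate, dim=0)
    num_q_blocks = (L + block_q - 1) // block_q
    num_k_blocks = (L + block_k - 1) // block_k
    boundary = torch.zeros(num_q_blocks, dtype=torch.long)
    for qb in range(num_q_blocks):
        i_first = qb * block_q  # D_ij is largest at the smallest query row
        first_needed = 0
        for kb in range(num_k_blocks):
            j_last = min((kb + 1) * block_k - 1, L - 1)  # rightmost key in tile
            if j_last >= i_first:
                break  # tile touches the causal diagonal: always keep it
            # Largest D_ij in this tile; if even that is below tau, prune it.
            d_max = (c[i_first] - c[j_last]).item()
            if d_max < tau:
                first_needed = kb + 1
            else:
                break
        boundary[qb] = first_needed
    return boundary
```

Because $\log f_l \le 0$, the cumulative sum $c$ is non-increasing, so the returned boundary can only move rightwards (toward the diagonal) as the query block index grows; this monotone structure is what makes a single pruning boundary across the grid well defined and cheap to find.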
Summary of results:
- In this work we focus on pretraining, though in principle ACP could also be used during inference (i.e., prefilling and decoding).
- ACP consistently prunes around 70% of the attention FLOPs and yields a roughly 10%-35% improvement in training throughput, depending on the model size and the context length. Note that training throughput depends on the latency of the entire network (including the MLPs), while ACP only speeds up the attention part. We have not yet measured the speedup of the attention part alone for technical reasons, but it is reasonable to expect it to match the FLOP savings (e.g., if 70% of the attention FLOPs are pruned, the attention part is probably around three times faster).
- All the speed improvements are achieved without any performance degradation. This is because we dynamically set the threshold on $D_{ij}$ so that the total pruned attention weight is bounded by a small number (in practice, bounded by $e^{-10} < 0.00005$); a sketch of one way to obtain such a bound follows below.
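One way to see that such a bound is attainable (a sketch, assuming the attention logits are bounded as $|q_i^\top k_j| \le B$, e.g., when QK normalization is used, and writing $L$ for the context length): since $D_{ii} = 0$, each pruned attention weight satisfies

$$
\frac{\exp\left(q_i^\top k_j + D_{ij}\right)}{\sum_{l \le i} \exp\left(q_i^\top k_l + D_{il}\right)}
\le \frac{\exp\left(B + D_{ij}\right)}{\exp\left(q_i^\top k_i\right)}
\le \exp\left(2B + D_{ij}\right),
$$

so pruning only entries with $D_{ij} < \tau$, where $\tau = -10 - \log L - 2B$, keeps the total pruned weight per query below $L \exp(2B + \tau) = e^{-10}$. The threshold used in practice may differ in its details; this sketch is only meant to show why a guarantee of this form is cheap to enforce.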
Code: https://github.com/zhixuan-lin/arctic-fox. We have more results coming in the future. Stay tuned!