We found that the layers of a pretrained large language model (LLM) can be manipulated as separate modules to build a better, and even shallower, model customized for each test sample. In particular, each layer of a pretrained LLM can be skipped or repeated multiple times, as in recurrent neural networks (RNNs), and stacked with other layers in an arbitrary order, yielding a chain-of-layers (CoLa) per sample. This compositional space significantly expands the scope of existing work on looped or recurrently pretrained modules, layer pruning, and early-exit networks.
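To make this concrete, a CoLa can be represented simply as a sequence of layer indices: omitting an index skips that layer, and repeating an index applies the same (weight-shared) layer recurrently. The sketch below is illustrative only; the function `apply_cola` and the toy `blocks` are placeholders for a pretrained model's transformer blocks, not our actual implementation.

```python
import torch
import torch.nn as nn

def apply_cola(blocks: nn.ModuleList, chain: list[int], hidden: torch.Tensor) -> torch.Tensor:
    """Run hidden states through pretrained blocks in the order given by `chain`.

    `chain` is a sequence of layer indices: omitting an index skips that layer,
    and repeating an index applies the same (shared-weight) layer recurrently.
    """
    for idx in chain:
        hidden = blocks[idx](hidden)
    return hidden

# Toy stand-in for pretrained transformer blocks (illustrative only).
blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(6)])
x = torch.randn(2, 16)

full_depth = apply_cola(blocks, [0, 1, 2, 3, 4, 5], x)        # original fixed-depth stack
fast_path  = apply_cola(blocks, [0, 2, 5], x)                 # skip layers 1, 3, 4 (fast thinking)
slow_path  = apply_cola(blocks, [0, 1, 2, 2, 2, 3, 4, 5], x)  # recur on layer 2 (slow thinking)
```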
We develop a Monte Carlo Tree Search (MCTS) protocol to explore and identify the optimal CoLa for each sample from math and commonsense reasoning benchmarks. Compared to a static model of fixed depth, CoLa allows shortcut paths (fast thinking), recurrence of the same layer(s) (slow thinking), and combinations of both, offering more flexible and dynamic architectures for different inputs. Specifically:
We introduce a new dimension of generalization that turns a static pretrained LLM into dynamic architectures of adaptive depth without training any parameters: for different test samples/tasks, the pretrained layers can be skipped, repeated, and assembled to create better (more accurate and/or shallower) CoLa models without further training.
We develop an MCTS protocol for efficient architecture search of CoLa with adaptive depth for each sample (a simplified, illustrative search sketch is given below). An in-depth analysis of the patterns in the resulting CoLa models provides critical insights into the importance and redundancy of layers at different depths of pretrained/finetuned models of different sizes, which also vary across tasks of different difficulty levels.
We conduct an extensive analysis of the MCTS-optimized CoLa, which leads to two key findings:
(1) For >75% of the samples that the original LLM already predicts correctly, we can find a shorter CoLa, suggesting a large space for improving inference efficiency;
(2) For >60% of the samples that the original LLM predicts incorrectly, we can identify a CoLa that yields a correct prediction, suggesting a large space for performance enhancement.
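To make the search concrete, the following is a minimal, self-contained UCT-style sketch of searching over chains of layer indices. The action space (append a layer index or stop), the random rollout policy, the exploration constant, and the reward callback `evaluate(chain)` (e.g., answer correctness minus a small depth penalty) are illustrative placeholders rather than the exact protocol used in our experiments.

```python
import math
import random

STOP = -1  # special action that terminates a chain

class Node:
    """One node in the search tree; its state is a partial chain of layer indices."""

    def __init__(self, chain, parent=None, terminal=False, num_layers=6, max_len=8):
        self.chain = chain
        self.parent = parent
        self.terminal = terminal or len(chain) >= max_len
        self.children = {}                  # action -> Node
        self.visits = 0
        self.value = 0.0                    # sum of rollout rewards
        actions = [] if self.terminal else list(range(num_layers)) + ([STOP] if chain else [])
        random.shuffle(actions)
        self.untried = actions

    def uct_child(self, c=1.4):
        # Pick the child maximizing the UCT upper-confidence score.
        return max(
            self.children.values(),
            key=lambda n: n.value / n.visits + c * math.sqrt(math.log(self.visits) / n.visits),
        )

def mcts_search(evaluate, num_layers=6, max_len=8, iterations=500):
    """Search for a high-reward chain of layer indices with a basic UCT loop.

    `evaluate(chain)` is a user-supplied scorer (an assumption here), e.g. 1.0 if
    the CoLa built from `chain` answers the sample correctly, minus a depth penalty.
    """
    root = Node((), num_layers=num_layers, max_len=max_len)
    best_chain, best_score = (), float("-inf")

    for _ in range(iterations):
        node = root
        # 1. Selection: descend through fully expanded nodes by UCT.
        while not node.terminal and not node.untried and node.children:
            node = node.uct_child()
        # 2. Expansion: try one untried action (append a layer index, or stop).
        if not node.terminal and node.untried:
            action = node.untried.pop()
            if action == STOP:
                child = Node(node.chain, parent=node, terminal=True,
                             num_layers=num_layers, max_len=max_len)
            else:
                child = Node(node.chain + (action,), parent=node,
                             num_layers=num_layers, max_len=max_len)
            node.children[action] = child
            node = child
        # 3. Rollout: randomly extend the chain, then score the resulting CoLa.
        chain = list(node.chain)
        while not node.terminal and len(chain) < max_len and (not chain or random.random() > 0.3):
            chain.append(random.randrange(num_layers))
        reward = evaluate(tuple(chain))
        if reward > best_score:
            best_chain, best_score = tuple(chain), reward
        # 4. Backpropagation: update visit counts and values along the path.
        while node is not None:
            node.visits += 1
            node.value += reward
            node = node.parent

    return best_chain, best_score

# Toy usage with a placeholder reward: prefer short chains that end at the last layer.
if __name__ == "__main__":
    chain, score = mcts_search(
        lambda c: float(bool(c) and c[-1] == 5) - 0.02 * len(c), iterations=300
    )
    print(chain, score)
```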
Our results highlight the shortcoming of using a fixed architecture of pretrained LLMs for inference on different samples and pave the way toward unlocking the generalization power of test-time depth adaptation.