Switchable Decision: Dynamic Neural Generation Networks
Abstract

Auto-regressive generation models achieve competitive performance across many different NLP tasks such as summarization, question answering, and classification. However, they are also known for being slow in inference, which makes them challenging to deploy in real-time applications. We propose a switchable decision to accelerate inference by dynamically assigning computation resources for each data instance. Automatically making decisions on where to skip and how to balance quality and computation cost with constrained optimization, our dynamic neural generation networks enforce the efficient inference path and determine the optimized trade-off. Experiments across question answering, summarization, and classification benchmarks show that our method benefits from less computation cost during inference while keeping the same accuracy. Extensive experiments and ablation studies demonstrate that our method can be general, effective, and beneficial for many NLP tasks.

1. Introduction

Large-scale pre-trained language models such as BART (Lewis et al., 2019) have demonstrated a significant performance gain to the natural language processing (NLP) community but generally come with the cost of a heavy computational burden. Besides pre-training and fine-tuning, inference of such a large model also comes with a heavy computational cost. On IoT (Internet of Things) devices and in real-world applications, the low tolerance for computation cost and the restricted computation resources during inference impede the deployment of these models.

Recent efforts of efficient inference mainly focus on pruning or compressing the model parameters, e.g., pruning unimportant parts of the neural model weights (Han et al., 2015b; Fan et al., 2019; Gordon et al., 2020), quantizing the number of bits needed (Lin et al., 2016; Shen et al., 2020), and distilling from large teacher models to small student models (Hinton et al., 2015; Jiao et al., 2019). These methods produce only one small model with a predetermined target size. Another direction is to switch the model parameters for different data instances, e.g., the mixture of experts (Shazeer et al., 2017) and the switch transformer (Fedus et al., 2021). Early exiting, which adaptively produces a series of small models for different data instances, is one of the most common practices. Most previous work makes exit decisions based on either the confidence of output probability distributions or a trained agent. In this work, we propose a carefully designed candidate space for encoder-decoder auto-regressive models and enhance the optimization strategies when training the agent.

In this spirit, we explore the problem of dynamically allocating computation across a generation model. In particular, we consider a standard encoder-decoder transformer auto-regressive generation model. It comprises a stacked structure with multiple layers, each having a multi-head attention layer followed by a feed-forward network (FFN) layer (Zhang et al., 2021b;a; Dai et al., 2022; Tanwisuth et al., 2023). To this end, we introduce a dynamic neural network for the auto-regressive generation models, which includes the attention, feed-forward, and input sequence as the candidate space for switchable decisions. Our method generates an input-dependent inference strategy for each data instance. For each input sequence, the reinforcement learning agent outputs all the decisions for skipping or keeping each candidate. With the first-layer hidden representations as the input, the policy network is trained to maximize a reward that incentivizes the use of as few blocks or tokens as possible while preserving the prediction accuracy.

We propose learning optimal switchable strategies that simultaneously preserve prediction accuracy and minimize computation usage based on input-specific decisions. Constrained optimization is utilized as a more principled approach for trading off these two targets (quality vs. efficiency). We target keeping the predicted quality while achieving better efficiency as far as possible. A gradient-based constrained optimization algorithm is implemented under our framework.

1 The University of Texas at Austin. Correspondence to: Shujian Zhang <[email protected]>.

Proceedings of the 41st International Conference on Machine Learning, Vienna, Austria. PMLR 235, 2024. Copyright 2024 by the author(s).

arXiv:2405.04513v1 [cs.CL] 7 May 2024
We run extensive experiments across summarization, e.g., XSum (Narayan et al., 2018) and CNN/DM (Hermann et al., 2015), question answering, e.g., SQuAD 1.1 (Rajpurkar et al., 2016) and SQuAD 2.0 (Rajpurkar et al., 2018), and GLUE (Wang et al., 2018a) classification tasks. ❶ Our method not only shows comparable performance across different tasks and datasets but also accelerates model inference by up to 40% with negligible model quality degradation. ❷ Furthermore, we provide extensive ablation studies on different design choices for the proposed method, including the encoder-only or decoder-only switchable schemes. ❸ Our analysis shows the switchable decision contributes to the efficiency improvement and accuracy consistency, helping the generation model to choose the inference path and candidates dynamically. ❹ To the best of our knowledge, we present the first switchable decision in the language generation model setting by dynamically making the inference decisions in summarization, question answering, and classification. Our contributions are summarized as follows:

graph. Dynamic jumping (Yu et al., 2018; Fu & Ma, 2018) strategically skips some tokens without reading them, and directly jumps to an arbitrary location. Early exiting for pretrained models has been explored by previous literature. RTJ (Schwartz et al., 2020), DeeBERT (Xin et al., 2020), and FastBERT (Liu et al., 2020a) make early exiting decisions based on confidence (or its variants) of the predicted probability distribution and are therefore limited to classification tasks. PABEE (Zhou et al., 2020) and BERxiT (Xin et al., 2021) propose patience-based early exiting by exploiting the layer information. Runtime Neural Pruning (Lin et al., 2017), SkipNet (Wang et al., 2018b), and BlockDrop (Wu et al., 2018) use reinforcement learning (RL) to decide whether to execute a network module. Inspired by them, we incorporate lightweight reinforcement learning to make input-dependent decisions and build a diversified switchable candidate space. With the constrained optimization approach, our method saves computational costs without loss of accuracy.
Figure 1. Overview of the dynamic network. Some notations are labeled along with corresponding components. ‘Layers’ refers to layers within the auto-regressive generation model. ‘ATT’ refers to the attention candidate, ‘FFN’ refers to the feed-forward candidate, ‘Text Input’ refers to the token candidate, and ‘Decisions’ refers to the skipping decisions from the reinforcement learning agent. Green represents not skipping. Unfilled boxes in the text input and unfilled boxes with a dashed outline represent skipping.
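To make the decision flow of Figure 1 concrete, the following is a minimal sketch and not the authors' implementation. It assumes the first-layer hidden states have already been pooled into one vector x, reads the "one-layer MLP with layer normalization and ReLU" as LayerNorm, then ReLU, then a single linear layer, and uses hypothetical names (PolicySketch, decide, skip_tokens). The "keep all tokens" option and the every-k-th-token reading of uniform skipping are also assumptions made for illustration.

```python
import math
import random

# Token strategies named in the text (last 10/20/30%, uniform 25/33/50%),
# plus an ASSUMED "keep everything" option so that there are seven choices.
TOKEN_STRATEGIES = [("none", 0), ("last", 10), ("last", 20), ("last", 30),
                    ("uniform", 25), ("uniform", 33), ("uniform", 50)]

def layer_norm(x, eps=1e-5):
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softmax(zs):
    m = max(zs)
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]

class PolicySketch:
    """Hypothetical policy: 2L keep-probabilities plus a token-strategy distribution."""

    def __init__(self, hidden_size, num_layers, num_token_strategies, seed=0):
        rng = random.Random(seed)
        self.num_layers = num_layers
        num_outputs = 2 * num_layers + num_token_strategies
        self.weights = [[rng.gauss(0.0, 0.02) for _ in range(hidden_size)]
                        for _ in range(num_outputs)]
        self.bias = [0.0] * num_outputs

    def probabilities(self, x):
        feats = [max(v, 0.0) for v in layer_norm(x)]           # LayerNorm + ReLU
        logits = [sum(w * f for w, f in zip(row, feats)) + b
                  for row, b in zip(self.weights, self.bias)]  # one linear layer
        split = 2 * self.num_layers
        keep = [sigmoid(z) for z in logits[:split]]            # Bernoulli parameters
        token = softmax(logits[split:])                        # categorical parameters
        return keep, token

    def decide(self, x, rng=None):
        keep, token = self.probabilities(x)
        if rng is None:   # inference: most likely decision for each candidate
            s = [1 if p >= 0.5 else 0 for p in keep]
            a = max(range(len(token)), key=token.__getitem__)
        else:             # training: sample from the distributions
            s = [1 if rng.random() < p else 0 for p in keep]
            a = rng.choices(range(len(token)), weights=token, k=1)[0]
        L = self.num_layers
        return s[:L], s[L:], a   # attention flags, feed-forward flags, token strategy

def skip_tokens(tokens, strategy_index):
    kind, p = TOKEN_STRATEGIES[strategy_index]
    if kind == "last":      # drop the final p% of the sequence
        return tokens[:len(tokens) - (len(tokens) * p) // 100]
    if kind == "uniform":   # ASSUMPTION: drop every k-th token, k = round(100 / p)
        k = round(100 / p)
        return [t for i, t in enumerate(tokens) if (i + 1) % k != 0]
    return list(tokens)
```

A flag of 1 keeps the corresponding attention or feed-forward block and 0 bypasses it; calling `decide(x)` without a random generator mirrors the inference-time rule of taking the most likely decisions rather than sampling.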
(2019) discover that some layers are redundant. To decide whether to skip a certain layer, we model these decisions as a sequence of i.i.d. Bernoulli random variables parameterized by a policy network g. Let $b_l$ denote the switchable decision of the $l$-th layer, defined as

$$ b_l = \begin{cases} 1 & \text{with probability } g(x)_l \\ 0 & \text{with probability } 1 - g(x)_l \end{cases}, \tag{1} $$

where $x \in \mathbb{R}^e$ denotes the input of the decision unit, and we apply the first encoder layer output as x. The policy network, g, learns instance-specific probabilities of keeping the hidden representations of each layer. To perform skipping, we sample from this distribution and broadcast the indicators, $b_l^{att}$, to the input representations of attention layers.

Feed-Forward Candidate. In the same spirit, the feed-forward layers may contain redundant information. Thus, we consider skipping these layers using the same approach as that done in the attention. We decide whether to skip or not based on the indicator $b_l^{ffn}$. The design of the policy network is the same as that of the attention layer.

Token Candidate. In addition to skipping the layers, skipping the tokens can also be an alternative way to save computation costs. We create two token skipping strategies: ➀ skipping the last p% tokens and ➁ uniformly skipping p% tokens. For the former, we set p to 10, 20, and 30. For the latter, p is equal to 25, 33, and 50. To decide which strategy to use, we optimize a categorical random variable parameterized by a function h(·). The input of h(·) is the same as g(·), and the output of h(·) is a distribution over all six candidate decisions.

Encoder and Decoder Structure. The architecture we consider contains encoders and decoders. For the encoders, we apply attention skipping and feed-forward skipping together with token skipping. For the decoders, since every token is meaningful for the final outputs, we only apply attention skipping and feed-forward skipping. When making decisions, we sample from the outputs of our policy network, and broadcast the decisions to the hidden representations of each layer.

3.2. Reinforcement Learning Agent

Policy Network Architecture. Since we aim to speed up the inference process, a simple design for the policy network is adopted. We utilize a one-layer MLP with layer normalization and ReLU activation function. To output a Bernoulli distribution over decisions, we apply the sigmoid activation to the outputs of the network for attention and feed-forward candidates. We use the softmax function to output the distribution over the choices for token candidates.

Parameterization. During the training process, we sample from the decision distributions, which are parameterized by the policy network. The distribution of the switchable decisions for the layers can be represented as a 2L-dimensional Bernoulli distribution, which can be written as:

$$ \pi(s \mid x) = \prod_{l=1}^{2L} g_l(x)^{s_l} \left(1 - g_l(x)\right)^{1 - s_l}, \tag{2} $$
where $s = \{b_l^{att}\}_{l=1}^{L} \cup \{b_l^{ffn}\}_{l=1}^{L}$. Similarly, the distribution of the token skipping decisions can be represented as a categorical distribution, which can be formalized as:

$$ \eta(a \mid x) = \prod_{j=1}^{J} h_j(x)^{\mathbb{1}(a = j)}, \tag{3} $$

where a denotes the choice of the skipping strategy, and J indicates the total number of strategies. We apply seven candidates in practice.

Reward. We define the reward function (Yang et al., 2022b;a; Feng et al., 2023) as a trade-off between quality and computational cost. Given an inference path and a data instance, the reward can be computed from the computation (estimated FLOPs). Intuitively, skipping layers yields a high reward. We further define quality in terms of accuracy and loss in the following way:

$$ R(s, a) = \text{quality} + \lambda \, \text{computation}, \tag{4} $$

where quality is −loss, computation is the estimated FLOPs (floating point operations), and λ is a coefficient. The overall loss function is defined as the expected value of the reward:

$$ J = \mathbb{E}_{s \sim \pi,\, a \sim \eta}\left[R(s, a)\right], \tag{5} $$

where π and η are defined in (2) and (3), respectively.

Optimization. To optimize our policy network, we apply policy gradient to compute the gradient of J, and update the parameters of the policy network. We use a self-critical baseline to reduce the variance of the gradients. The constrained-optimization strategy is further applied on the quality and computation. Details are in the next section.

During Inference. Unlike the training process, we do not sample the skipping decisions during inference. Instead, we choose the decisions which maximize the likelihood function.

3.3. Constrained Optimization

Trade-off is a Problem. In the joint training of the main network and the policy network, a trade-off between quality and computation is important. The linear combination of multiple objectives is the most widely used approach. However, the coefficient of the combination requires manual tuning, and it is theoretically unsuitable for non-convex functions. In this work, we consider constrained optimization on trading off two objectives, with a special emphasis on lexicographic (lexico) optimization.

Algorithm 1 Switchable Decision (SD)
1: Input: Text o. Auto-regressive generation model M parameter w with learning rate α_t, policy network parameter θ with learning rate γ_t, number of iterations T.
2: for t = 0 to T do
3:   w ← w − α_t ∇(w),
4:   θ is updated via Eqn (7),
5: end for

Our Equation. To optimize the trade-off between quality and computation in Eqn (4), we propose to use lexicographic optimization, in which the parameters are iteratively updated as

$$ \theta_{t+1} \leftarrow \theta_t - \gamma_t e_t, \tag{6} $$

where $\gamma_t \ge 0$ is an adaptive step size and $e_t \in \mathbb{R}^d$ is an update direction to be chosen to balance the minimization of f and constraint satisfaction on q. One of the objectives (say f, which is computation in our case) is of secondary importance w.r.t. the other one (say q, which is quality). The design criterion for the constrained optimization is that when the constraint is not satisfied (i.e., $q(\theta_t) \ge c$), the focus becomes decreasing q to satisfy the constraint as soon as possible; in the meantime, f acts as a secondary objective, indicating that f should be minimized to the degree that it does not hurt the descent of q. Therefore, we apply the following update rule to obtain such a goal:

$$ \theta_{t+1} \leftarrow \theta_t - \gamma_t \left( \nabla \text{quality}(\theta_t) + \lambda \nabla \text{computation}(\theta_t) \right), \tag{7} $$

where ∇computation and ∇quality are estimated by the score function, and λ can be computed as

$$ \lambda = \max\left( \frac{\phi(\theta_t) - \nabla \text{quality}(\theta_t)^{\top} \nabla \text{computation}(\theta_t)}{\left\| \nabla \text{computation}(\theta_t) \right\|^2},\ 0 \right), $$

where $\phi(\theta_t)$ equals $q(\theta_t) - c$ and c represents the minimal loss.

The Proposed Algorithm. Our switchable decision (SD) with efficient candidate space and constrained optimization is shown in Algorithm 1. We iteratively update the auto-regressive model and the policy network in a single-loop manner. The policy network parameter θ is updated by Eqn (6) in a direction to balance the optimization of quality and constraint satisfaction on computation.

4. Experimental Settings

Table 1 shows the experimental data configuration.

4.1. Task and Evaluation Metrics

Summarization. We use CNN/DailyMail (Hermann et al., 2015) and XSum (Narayan et al., 2018) to evaluate our
method. CNN/DailyMail consists of 287,226 documents for training, 13,368 documents for validation, and 11,490 documents for testing. XSum has 226,711 news articles, each accompanied by a one-sentence summary answering the question “What is this article about?”. Following the splits of Narayan et al. (2018), it contains 204,045 train, 11,332 dev, and 11,334 test examples. Following prior work (Lewis et al., 2019), we use ROUGE (Lin & Hovy, 2003) as our primary metric. We report the unigram ROUGE-1 (R-1) and bigram ROUGE-2 (R-2) overlap to assess the informativeness, and the longest common subsequence ROUGE-L (R-L) score to assess the fluency.

Question Answering. The Stanford Question Answering Datasets (SQuAD) v1.1 and v2.0 (Rajpurkar et al., 2016; 2018; Fan et al., 2020) are popular machine reading comprehension benchmarks. The SQuAD v2.0 dataset contains examples where the answer to the question cannot be derived from the provided context. Similar to previous settings (Devlin et al., 2018; Lewis et al., 2019), we use concatenated question and context as input to the encoder of BART, and additionally pass them to the decoder. We report Exact Match (EM) and F1 score for evaluation (Lewis et al., 2019).

Classification. The General Language Understanding Evaluation (GLUE) benchmark is a collection of natural language understanding (NLU) tasks. As shown in Table 1, we include Multi-Genre NLI (MNLI; Williams et al., 2017b; Zhang et al., 2021d), Recognizing Textual Entailment (RTE; Dagan et al., 2005), and Stanford Sentiment Treebank (SST; Socher et al., 2013). The diversity of the tasks makes GLUE very suitable for evaluating the generalization and robustness of our proposed method (Liu et al., 2020b). Accuracy is adopted as our evaluation metric.

Task                 Dataset         Train    Val     Test
Summarization        CNN/DailyMail   287.2K   13.4K   11.5K
Summarization        XSum            204K     11.3K   11.3K
Question Answering   SQuAD 1.1       87.6K    10.5K   9.5K
Question Answering   SQuAD 2.0       130.3K   11.9K   8.9K
Classification       RTE             2.5K     276     3K
Classification       MNLI            393K     20K     20K
Classification       SST             67K      872     1.8K

Table 1. Dataset Configuration. The top block is for summarization, the middle block is for question answering, and the bottom block is for the classification tasks.

4.2. Implementation Details

Following Lewis et al. (2019), we take the pre-trained BART model as the backbone and utilize the provided checkpoint for finetuning on the downstream datasets. BART is a pre-trained sequence-to-sequence model based on the masked source input and auto-regressive target output, which contains 12 layers of transformer encoder and 12 layers of transformer decoder. Its embedding size is 1,024 and feed-forward size is 4,096. We follow the hyper-parameters used in Lewis et al. (2019). Specifically, in summarization, we set the training steps as 50k and the number of warm-up steps as 500. The max number of tokens and the update frequency are set to be 2,048 and 4, respectively. The learning rate is set to $3 \times 10^{-5}$. For question answering (SQuAD 1.1/2.0), we set the total number of updates and warm-up updates as 5,430 and 326, respectively. The max number of sentences is 3 per device with an update frequency of 2. The learning rate is $1.5 \times 10^{-5}$. We refer the readers to Appendix A for classification hyper-parameter configurations, and more details about the settings.

5. Experiments

We evaluate the performance of our switchable dynamic network. In each table, we bold the best result within each column block and the results of our method are obtained with three trials to determine the variance. See Appendix A for full results with error bars.

5.1. Summarization

Table 2 reports our results on two summarization datasets. ➀ The top block displays the performance of baselines on CNN/DailyMail and XSum datasets, and the bottom block shows the results of incorporating the switchable dynamic networks. We report the results upon the BART large setting in Lewis et al. (2019). ➁ Summaries in CNN/DailyMail tend to resemble source sentences and summaries in XSum are highly abstractive. Baseline models such as BART (Lewis et al., 2019), UniLM (Dong et al., 2019), and BERTSUM (Liu & Lapata, 2019) do well enough, and even the baseline of the first-three source sentences is highly competitive for CNN/DailyMail. Our method can reduce the computation cost while having little or no drop on ROUGE. For example, we even have a 0.2 increase on R1 for CNN/DailyMail and a 0.1 increase on R1 for XSum, while reducing 39% and 18% computation costs, respectively. For the quality of the sentence generations, our method outperforms almost all the baselines. Especially, for CNN/DailyMail, we achieve better ROUGE with less than two-thirds FLOPs cost, compared to the original BART-large model (e.g., R1: 44.16 → 44.31, RL: 40.90 → 41.01 on CNN/DailyMail). ➂ These results further confirm that SD can work as an effective module to be incorporated into the auto-regressive generation models. SD on improving the inference can also be seen as a complementary module to works focusing on improving pre-training components (Hou et al., 2022; Ge et al., 2022).
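The ROUGE-1, ROUGE-2, and ROUGE-L scores used for summarization are overlap statistics between a generated summary and a reference summary. As a rough illustration only, the sketch below computes F1 variants on pre-tokenized lists. It assumes whitespace tokens, no stemming, and a sentence-level longest common subsequence, so it will not reproduce the official ROUGE toolkit numbers reported in the paper; the function names are hypothetical.

```python
from collections import Counter

def ngrams(tokens, n):
    # Multiset of n-grams in a token list.
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def f1(precision, recall):
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

def rouge_n(candidate, reference, n):
    cand, ref = ngrams(candidate, n), ngrams(reference, n)
    overlap = sum((cand & ref).values())            # clipped n-gram matches
    precision = overlap / max(sum(cand.values()), 1)
    recall = overlap / max(sum(ref.values()), 1)
    return f1(precision, recall)

def lcs_length(a, b):
    # Dynamic programming over one rolling row.
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rouge_l(candidate, reference):
    lcs = lcs_length(candidate, reference)
    precision = lcs / max(len(candidate), 1)
    recall = lcs / max(len(reference), 1)
    return f1(precision, recall)
```

For instance, `rouge_n("the cat is on the mat".split(), "the cat sat on the mat".split(), 2)` counts three shared bigrams out of five on each side.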
Model     SQuAD 1.1                   SQuAD 2.0
          EM/F1 ↑     FLOPs (%) ↓     EM/F1 ↑     FLOPs (%) ↓
BERT      84.1/90.9   -               79.0/81.8   -
UniLM     -/-         -               80.5/83.4   -
RoBERTa   88.9/94.6   -               86.5/89.4   -
BART      88.8/94.6   100             86.1/89.2   100
Ours      88.7/94.5   80.5            86.0/89.3   83.3

Table 5. Results across different strategies on SQuAD v1.1 and v2.0. Answers are text spans extracted from a given document context.

cussed in Section 3, our proposed method targets the auto-regressive generation model. Thus, can our method be adapted to other auto-regressive generation models? We select the GPT-2 (Radford et al., 2019) base and T5 (Raffel et al., 2020) base to study the performance after adapting our proposed switchable decisions. The results are presented in Table 6. It indicates our method is insensitive to different generation models. This confirms our discussion in Section 3 that SD can serve as an efficient alternative dynamic network for versatile generation models. We also analyze the impact of making decisions based on different hidden representations. More details about LLaMA (Touvron et al., 2023) models are included in Appendix A.

Data     ROUGE               FLOPs (%)
BART     44.16/21.28/40.90   100
+ Ours   44.31/21.18/41.01   61.1
GPT-2    37.55/15.53/25.81   100
+ Ours   37.76/15.68/25.93   74.5
T5       42.05/20.34/39.40   100
+ Ours   41.98/20.38/39.61   74.5

Table 6. The proposed method for different generation models on CNN/DailyMail.

What are the differences between encoder-only, decoder-only, and token-only architecture search space? We test if our results are sensitive to the choice of architectures: encoder-only, decoder-only, and encoder-decoder. We create the following scenarios: ① For encoder-only, we incorporate the attention and feed-forward as the skipping candidates. ②

percentage of each candidate. These results confirm our analysis and motivation for the switchable decision that using a combination of all these architectural search spaces yields the best efficiency and accuracy trade-off.

Architecture   ATT   FFN   Token   FLOPs (%)   ROUGE
BART           -     -     -       100         44.16/21.28/40.90
Encoder-Only   ✓     ✓     -       91.9        44.21/21.32/40.95
Decoder-Only   ✓     ✓     -       90.3        44.13/21.08/40.86
Token-Only     -     -     ✓       71.5        44.09/21.26/40.92
Ours           ✓     ✓     ✓       61.1        44.31/21.18/41.01

Table 7. Results of skipping strategies on different architecture spaces for CNN/DailyMail. BART (Izacard & Grave, 2021) large model is presented.

Ablation studies on the components in SD. We conduct the ablation study to examine the role of constrained optimization. For ablation, instead of automatically searching the trade-off between the quality and computation, we manually set the λ in Eqn (7) as 0.2, 0.5, 0.8. We also include the random selection strategy. The random selection strategy does not learn switchable decisions and would not dynamically assign computation for each data instance. ❶ Table 8 shows that the constrained optimization of our method brings clear benefits. ❷ We find that without CO, ‘− CO’ with different manually tuned λ values shows an unstable trade-off between the ROUGE and FLOPs across all λ values, indicating that a manually tuned λ value cannot bring both optimized quality and computation together. ❸ Empirically, we randomly select a policy from our decision space candidates and use the same other parameters. This results in a degradation in performance and lower FLOPs reduction. It demonstrates the necessity and effectiveness of the constrained optimization for the switchable candidate set in the SD structure.

Data            ROUGE               FLOPs (%)
BART            44.16/21.28/40.90   100
Random          41.77/19.02/38.72   75.3
Ours            44.31/21.18/41.01   61.1
- CO, λ = 0.2   44.12/21.30/40.88   77.8
- CO, λ = 1.0   42.89/21.02/40.57   68.5
- CO, λ = 1.5   41.35/19.87/38.39   49.4
training time of SD are slightly higher (2.7% for memory and 1.6% for running time) than BART. SD gives the best inference FLOPs, outperforming BART while keeping a comparable ROUGE score and running time. ❷ For the inference time, we evaluate our method and BART large on CNN/DailyMail following the same setting and device with batch size 1. Each iteration takes 5.1 seconds (Ours) vs. 10.3 seconds (BART). Our dynamic network demonstrates the strong capability of making skipping decisions. ❸ With the constrained optimization and the reinforcement learning agent, our switchable decision is still computationally productive, as the design of our optimization and agent (e.g., applying a one-layer MLP for the policy network) has almost negligible finetuning computational cost.

Model   ROUGE ↑             Params ↓   GPU memory ↓   s/step ↓   IT ↓
BART    44.16/21.28/40.90   406M       16.8G          1.20       10.3
Ours    44.31/21.18/41.01   423M       17.6G          1.48       5.1

Table 9. Results of parameter size, GPU memory per device, and step time for BART and ours finetuning on CNN/DailyMail. ‘s/step’ represents training step time (seconds per step). ‘IT’ represents inference time (seconds) for each iteration.

6.1. Contributions of Search Space Candidates.

To further identify the contributions of our search space candidates for efficiency improvements and inference acceleration, we present the detailed skipping percentage of each candidate for CNN/DailyMail, SQuAD 1.1, and SST in Table 10. For CNN/DailyMail, we observe around 8% attention skipping of total attention, 11% feed-forward skipping of total feed-forward, and 29% token skipping of total tokens. Similar skipping percentages hold for question answering. However, we have seen an obvious contrast in the token skipping percentage in classification tasks. The key observation is that the skipping percentages for tokens are high for both CNN/DailyMail and SQuAD 1.1. In addition, our method generally takes around 5K iterations for the reinforcement learning algorithm to converge on CNN/DailyMail. This confirms our conjecture in Section 5.1. For summarization and question answering tasks, the first few parts of inputs are more representative. Thus, it serves well as the candidate for our switchable network to make the skipping decisions.

The impact of making decisions based on different hidden representations. In Section 3.1, we consider three skipping candidates’ hidden representations (attention, feed-forward, and query) after the first layer as the input for our reinforcement learning agent to make switchable decisions. Here, we demonstrate that using hidden representations from different layers leads to similar results, and therefore we pick the easiest one. We set up a baseline here, in which whether to skip the following layer is dependent on the nearby previous layer outputs. We experiment on Ours (based on the output from the first layer) and Ours Layer Wise (layer-wise decisions based on the output from the nearby previous layers). The difference between these two cases is small in Table 11. The layer-wise design requires more computation as it needs to make decisions at each layer. Therefore, it further demonstrates that our design is capable of making skipping decisions while imposing less computational cost.

Data              ROUGE               FLOPs (%)
Ours              44.31/21.18/41.01   61.1
Ours Layer Wise   44.38/21.22/40.97   61.8

Table 11. Comparison of different layer-wise decisions of SD on CNN/DailyMail. ‘Ours’ represents the decision based on the hidden representation after the first layer. ‘Ours Layer Wise’ represents the decision based on the hidden representation from the nearby previous layer.

7. Conclusion

Our work demonstrates the benefits of introducing a switchable decision of the dynamic network. The proposed method can dramatically increase the inference efficiency and still preserve the model performance. Noticeable FLOPs saving and consistent performance are observed across summarization, question answering, and classification benchmarks. We further conduct a detailed study with the proposed switchable strategy in different settings, e.g., comparing with different architecture search spaces, providing more evidence for making decisions based on hidden representations, and verifying the impact of components. To summarize, the proposed SD is effective and general, with the potential to be incorporated into existing generation models for various NLP tasks.
References

Bae, S., Ko, J., Song, H., and Yun, S.-Y. Fast and robust early-exiting framework for autoregressive language models with synchronized parallel decoding. arXiv preprint arXiv:2310.05424, 2023.

Campos, V., Jou, B., Giró-i Nieto, X., Torres, J., and Chang, S.-F. Skip rnn: Learning to skip state updates in recurrent neural networks. arXiv preprint arXiv:1708.06834, 2017.

Chen, G., Choi, W., Yu, X., Han, T., and Chandraker, M. Learning efficient object detection models with knowledge distillation. Advances in neural information processing systems, 30, 2017.

Dagan, I., Glickman, O., and Magnini, B. The pascal recognising textual entailment challenge. In Machine Learning Challenges Workshop, pp. 177–190. Springer, 2005.

Dai, Y., Tang, D., Liu, L., Tan, M., Zhou, C., Wang, J., Feng, Z., Zhang, F., Hu, X., and Shi, S. One model, multiple modalities: A sparsely activated approach for text, sound, image, video and code. arXiv preprint arXiv:2205.06126, 2022.

Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. Bert: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.

Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. In NAACL-HLT, 2019.

Dong, L., Yang, N., Wang, W., Wei, F., Liu, X., Wang, Y., Gao, J., Zhou, M., and Hon, H.-W. Unified language model pre-training for natural language understanding and generation. Advances in Neural Information Processing Systems, 32, 2019.

Fan, A., Grave, E., and Joulin, A. Reducing transformer depth on demand with structured dropout. arXiv preprint arXiv:1909.11556, 2019.

Fan, X., Zhang, S., Chen, B., and Zhou, M. Bayesian attention modules. arXiv preprint arXiv:2010.10604, 2020.

Fedus, W., Zoph, B., and Shazeer, N. Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity, 2021.

Feng, Y., Yang, S., Zhang, S., Zhang, J., Xiong, C., Zhou, M., and Wang, H. Fantastic rewards and how to tame them: A case study on reward learning for task-oriented dialogue systems. arXiv preprint arXiv:2302.10342, 2023.

Fu, T.-J. and Ma, W.-Y. Speed reading: Learning to read forbackward via shuttle. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing, pp. 4439–4448, 2018.

Ge, T., Xia, H., Sun, X., Chen, S.-Q., and Wei, F. Lossless acceleration for seq2seq generation with aggressive decoding. arXiv preprint arXiv:2205.10350, 2022.

Gordon, M. A., Duh, K., and Andrews, N. Compressing bert: Studying the effects of weight pruning on transfer learning. arXiv preprint arXiv:2002.08307, 2020.

Han, S., Mao, H., and Dally, W. J. Deep compression: Compressing deep neural networks with pruning, trained quantization and huffman coding. arXiv preprint arXiv:1510.00149, 2015a.

Han, S., Pool, J., Tran, J., and Dally, W. Learning both weights and connections for efficient neural network. Advances in neural information processing systems, 28, 2015b.

Hansen, C., Hansen, C., Alstrup, S., Simonsen, J. G., and Lioma, C. Neural speed reading with structural-jump-lstm. arXiv preprint arXiv:1904.00761, 2019.

Hermann, K. M., Kocisky, T., Grefenstette, E., Espeholt, L., Kay, W., Suleyman, M., and Blunsom, P. Teaching machines to read and comprehend. Advances in neural information processing systems, 28, 2015.

Hinton, G., Vinyals, O., Dean, J., et al. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2(7), 2015.

Hou, L., Pang, R. Y., Zhou, T., Wu, Y., Song, X., Song, X., and Zhou, D. Token dropping for efficient bert pretraining. arXiv preprint arXiv:2203.13240, 2022.

Izacard, G. and Grave, E. Distilling knowledge from reader to retriever for question answering. In ICLR 2021, 9th International Conference on Learning Representations, 2021.

Jiao, X., Yin, Y., Shang, L., Jiang, X., Chen, X., Li, L., Wang, F., and Liu, Q. Tinybert: Distilling bert for natural language understanding. arXiv preprint arXiv:1909.10351, 2019.

Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.

Lan, Z., Chen, M., Goodman, S., Gimpel, K., Sharma, P., and Soricut, R. Albert: A lite bert for self-supervised learning of language representations. ArXiv, abs/1909.11942, 2020.
Lewis, M., Liu, Y., Goyal, N., Ghazvininejad, M., Mohamed, A., Levy, O., Stoyanov, V., and Zettlemoyer, L. Bart: Denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension. arXiv preprint arXiv:1910.13461, 2019.
Li, Z., Wang, Z., Tan, M., Nallapati, R., Bhatia, P., Arnold, A., Xiang, B., and Roth, D. Dq-bart: Efficient sequence-to-sequence model via joint distillation and quantization. arXiv preprint arXiv:2203.11239, 2022.
Lin, C.-Y. and Hovy, E. Automatic evaluation of summaries using n-gram co-occurrence statistics. In Proceedings of the 2003 human language technology conference of the North American chapter of the association for computational linguistics, pp. 150–157, 2003.
Lin, D., Talathi, S., and Annapureddy, S. Fixed point quantization of deep convolutional networks. In International conference on machine learning, pp. 2849–2858. PMLR, 2016.
Lin, J., Rao, Y., Lu, J., and Zhou, J. Runtime neural pruning. Advances in neural information processing systems, 30, 2017.
Liu, W., Zhou, P., Zhao, Z., Wang, Z., Deng, H., and Ju, Q. Fastbert: a self-distilling bert with adaptive inference time. arXiv preprint arXiv:2004.02178, 2020a.
Liu, X., Wang, Y., Ji, J., Cheng, H., Zhu, X., Awa, E., He, P., Chen, W., Poon, H., Cao, G., et al. The microsoft toolkit of multi-task deep neural networks for natural language understanding. arXiv preprint arXiv:2002.07972, 2020b.
Liu, X., Gong, C., Wu, L., Zhang, S., Su, H., and Liu, Q. Fusedream: Training-free text-to-image generation with improved clip+ gan space optimization. arXiv preprint arXiv:2112.01573, 2021.
Liu, Y. and Lapata, M. Text summarization with pretrained encoders. arXiv preprint arXiv:1908.08345, 2019.
Liu, Y., Ott, M., Goyal, N., Du, J., Joshi, M., Chen, D., Levy, O., Lewis, M., Zettlemoyer, L., and Stoyanov, V. Roberta: A robustly optimized bert pretraining approach. ArXiv, abs/1907.11692, 2019a.
Liu, Y., Ott, M., Goyal, N., Du, J., Joshi, M., Chen, D., Levy, O., Lewis, M., Zettlemoyer, L., and Stoyanov, V. Roberta: A robustly optimized bert pretraining approach. arXiv e-prints, pp. arXiv–1907, 2019b.
Manning, C. D., Surdeanu, M., Bauer, J., Finkel, J. R., Bethard, S., and McClosky, D. The stanford corenlp natural language processing toolkit. In Proceedings of 52nd annual meeting of the association for computational linguistics: system demonstrations, pp. 55–60, 2014.
Narayan, S., Cohen, S. B., and Lapata, M. Don’t give me the details, just the summary! topic-aware convolutional neural networks for extreme summarization. arXiv preprint arXiv:1808.08745, 2018.
Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., Sutskever, I., et al. Language models are unsupervised multitask learners. OpenAI blog, 1(8):9, 2019.
Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W., and Liu, P. J. Exploring the limits of transfer learning with a unified text-to-text transformer. J. Mach. Learn. Res., 21:140:1–140:67, 2020.
Rajpurkar, P., Zhang, J., Lopyrev, K., and Liang, P. Squad: 100,000+ questions for machine comprehension of text. Empirical Methods in Natural Language Processing (EMNLP), 2016.
Rajpurkar, P., Jia, R., and Liang, P. Know what you don’t know: Unanswerable questions for squad. Annual Meetings of the Association for Computational Linguistics (ACL), 2018.
Schuster, T., Fisch, A., Gupta, J., Dehghani, M., Bahri, D., Tran, V., Tay, Y., and Metzler, D. Confident adaptive language modeling. Advances in Neural Information Processing Systems, 35:17456–17472, 2022.
Schwartz, R., Stanovsky, G., Swayamdipta, S., Dodge, J., and Smith, N. A. The right tool for the job: Matching model and instance complexities. arXiv preprint arXiv:2004.07453, 2020.
Shazeer, N., Mirhoseini, A., Maziarz, K., Davis, A., Le, Q., Hinton, G., and Dean, J. Outrageously large neural networks: The sparsely-gated mixture-of-experts layer. arXiv preprint arXiv:1701.06538, 2017.
Shen, S., Dong, Z., Ye, J., Ma, L., Yao, Z., Gholami, A., Mahoney, M. W., and Keutzer, K. Q-bert: Hessian based ultra low precision quantization of bert. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pp. 8815–8821, 2020.
Shleifer, S. and Rush, A. M. Pre-trained summarization distillation. arXiv preprint arXiv:2010.13002, 2020.
Socher, R., Perelygin, A., Wu, J., Chuang, J., Manning, C. D., Ng, A. Y., and Potts, C. Recursive deep models for semantic compositionality over a sentiment treebank. In Proceedings of the 2013 conference on empirical methods in natural language processing, pp. 1631–1642, 2013.
Sun, Z., Yu, H., Song, X., Liu, R., Yang, Y., and Zhou, D. Mobilebert: a compact task-agnostic bert for resource-limited devices. arXiv preprint arXiv:2004.02984, 2020.
Tanwisuth, K., Zhang, S., Zheng, H., He, P., and Zhou, M. Pouf: Prompt-oriented unsupervised fine-tuning for large pre-trained models. In International Conference on Machine Learning, pp. 33816–33832. PMLR, 2023.
Touvron, H., Lavril, T., Izacard, G., Martinet, X., Lachaux, M.-A., Lacroix, T., Rozière, B., Goyal, N., Hambro, E., Azhar, F., et al. Llama: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971, 2023.
Wang, A., Singh, A., Michael, J., Hill, F., Levy, O., and Bowman, S. R. Glue: A multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461, 2018a.
Wang, X., Yu, F., Dou, Z.-Y., Darrell, T., and Gonzalez, J. E. Skipnet: Learning dynamic routing in convolutional networks. In Proceedings of the European Conference on Computer Vision (ECCV), pp. 409–424, 2018b.
Williams, A., Nangia, N., and Bowman, S. R. A broad-coverage challenge corpus for sentence understanding through inference. In NAACL-HLT, 2017a.
Williams, A., Nangia, N., and Bowman, S. R. A broad-coverage challenge corpus for sentence understanding through inference. arXiv preprint arXiv:1704.05426, 2017b.
Wu, Z., Nagarajan, T., Kumar, A., Rennie, S., Davis, L. S., Grauman, K., and Feris, R. Blockdrop: Dynamic inference paths in residual networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 8817–8826, 2018.
Xin, J., Tang, R., Lee, J., Yu, Y., and Lin, J. Deebert: Dynamic early exiting for accelerating bert inference. arXiv preprint arXiv:2004.12993, 2020.
Xin, J., Tang, R., Yu, Y., and Lin, J. Berxit: Early exiting for bert with better fine-tuning and extension to regression. In Proceedings of the 16th conference of the European chapter of the association for computational linguistics: Main Volume, pp. 91–104, 2021.
Yang, S., Feng, Y., Zhang, S., and Zhou, M. Regularizing a model-based policy stationary distribution to stabilize offline reinforcement learning. In International Conference on Machine Learning, pp. 24980–25006. PMLR, 2022a.
Yang, S., Zhang, S., Feng, Y., and Zhou, M. A unified framework for alternating offline model training and policy learning. Advances in Neural Information Processing Systems, 35:17216–17232, 2022b.
Yang, S., Zhang, S., Xia, C., Feng, Y., Xiong, C., and Zhou, M. Preference-grounded token-level guidance for language model fine-tuning. Advances in Neural Information Processing Systems, 36, 2024.
Yu, K., Liu, Y., Schwing, A. G., and Peng, J. Fast and accurate text classification: Skimming, rereading and early stopping. 2018.
Zhang, C., Bengio, S., and Singer, Y. Are all layers created equal? 2019.
Zhang, S., Fan, X., Chen, B., and Zhou, M. Bayesian attention belief networks. In International Conference on Machine Learning, pp. 12413–12426. PMLR, 2021a.
Zhang, S., Fan, X., Zheng, H., Tanwisuth, K., and Zhou, M. Alignment attention by matching key and query distributions. Advances in Neural Information Processing Systems, 34:13444–13457, 2021b.
Zhang, S., Gong, C., and Choi, E. Knowing more about questions can help: Improving calibration in question answering. arXiv preprint arXiv:2106.01494, 2021c.
Zhang, S., Gong, C., and Choi, E. Learning with different amounts of annotation: From zero to many labels. arXiv preprint arXiv:2109.04408, 2021d.
Zhang, S., Gong, C., and Liu, X. Passage-mask: A learnable regularization strategy for retriever-reader models. arXiv preprint arXiv:2211.00915, 2022a.
Zhang, S., Gong, C., Liu, X., He, P., Chen, W., and Zhou, M. Allsh: Active learning guided by local sensitivity and hardness. arXiv preprint arXiv:2205.04980, 2022b.
Zhang, S., Wu, L., Gong, C., and Liu, X. Language rectified flow: Advancing diffusion language generation with probabilistic flows. arXiv preprint arXiv:2403.16995, 2024.
Zhou, W., Xu, C., Ge, T., McAuley, J., Xu, K., and Wei, F. Bert loses patience: Fast and robust inference with early exit. Advances in Neural Information Processing Systems, 33:18330–18341, 2020.
A. Experimental details

A.1. Full Results With Error Bar
We report the full results of our method with error bars for summarization and question answering in Tables 12 and 14, respectively. The full results for classification are shown in Table 13.

                          CNN/DailyMail                                     XSum
Model        R1 ↑        R2 ↑        RL ↑        FLOPs (%) ↓    R1 ↑        R2 ↑        RL ↑        FLOPs (%) ↓
Lead-3       40.42       17.62       36.67       -              16.30       1.60        11.95       -
UniLM        43.33       20.21       40.51       -              -           -           -           -
BERTSUM      42.13       19.60       39.18       -              38.81       16.50       31.27       -
BART         44.16       21.28       40.90       100            45.14       22.27       37.25       100
Ours large   44.31±0.1   21.18±0.2   41.01±0.2   61.1           45.20±0.1   22.16±0.2   37.30±0.2   81.9

Table 12. Full results on CNN/DailyMail and XSum. ROUGE is reported for each model. ‘BART’ represents the BART large model.

             MNLI                               RTE                        SST
Model        m/mm ↑              FLOPs (%) ↓    Acc ↑      FLOPs (%) ↓    Acc ↑      FLOPs (%) ↓
BERT         86.6/-              -              70.4       -              93.2       -
UniLM        87.0/85.9           -              70.9       -              94.5       -
RoBERTa      90.2/90.2           -              86.6       -              96.4       -
BART         89.9/90.1           100            87.0       100            96.6       100
Ours         89.7±0.2/90.0±0.3   82.4           87.2±0.1   83.6           96.6±0.2   80.7

Table 13. Full performance on GLUE. We report the accuracy on each dataset. All language models here are large size. ‘m/mm’ denotes accuracy on the matched/mismatched versions of MNLI, and ‘Acc’ denotes accuracy.

             SQuAD 1.1                          SQuAD 2.0
Model        EM/F1 ↑             FLOPs (%) ↓    EM/F1 ↑             FLOPs (%) ↓
BERT         84.1/90.9           -              79.0/81.8           -
UniLM        -/-                 -              80.5/83.4           -
RoBERTa      88.9/94.6           -              86.5/89.4           -
BART         88.8/94.6           100            86.1/89.2           100
Ours         88.7±0.3/94.5±0.4   80.5           86.0±0.3/89.3±0.3   83.3

Table 14. Full results across different strategies on SQuAD v1.1 and v2.0. Answers are text spans extracted from a given document context.

A.2. Experimental Datasets
Summarization. CNN/DailyMail contains news articles and associated highlights as summaries. Following the standard splits from Hermann et al. (2015) for training, validation, and testing, we have 90,266/1,220/1,093 CNN documents and 196,961/12,148/10,397 DailyMail documents, respectively. Sentences are split using the Stanford CoreNLP toolkit (Manning et al., 2014). For XSum (Narayan et al., 2018), summaries are professionally written by the authors of the documents. We also use the preprocessing and data splits from (Narayan et al., 2018; Yang et al., 2024).

Question Answering. Stanford Question Answering Dataset (SQuAD) (Rajpurkar et al., 2016; 2018; Zhang et al., 2021c; 2022a) is an extractive question answering task, consisting of questions posed by crowdworkers on a set of Wikipedia articles. The answers, given the questions, are text spans from the given reading passage. SQuAD 1.1 contains around 100,000 question-answer pairs on about 500 articles. The SQuAD v2.0 dataset includes unanswerable questions about the same paragraphs.

Classification. GLUE (Wang et al., 2018a; Zhang et al., 2022b) comprises a collection of text classification tasks meant to test general language understanding abilities. We adopt three datasets for our experiments: natural language inference (MNLI (Williams et al., 2017a) and RTE (Dagan et al., 2005)) and sentiment analysis (SST-2 (Socher et al., 2013)).

A.3. Experimental Settings
For summarization, we follow the setting of Lewis et al. (2019) and initialize our models with the pretrained BART large checkpoint. The checkpoint is from the Fairseq library.³ T5 (Raffel et al., 2020) is also used in Section 6. We adopt T5 base from the HuggingFace Transformers library.⁴ Following Lewis et al. (2019), the Adam optimizer (Kingma & Ba, 2014; Liu et al., 2021; Zhang et al., 2024) is used to optimize the model parameters with a learning rate of 3 × 10⁻⁵. We train for 50k steps with 500 warmup steps. Both dropout and attention dropout are set to 0.1. For classification, the detailed training settings are presented in Table 15.

Setting   MNLI        RTE         SST-2
NC        3           2           2
LR        5 × 10⁻⁶    1 × 10⁻⁵    5 × 10⁻⁶
BSZ       128         32          128
TS        30,968      1,018       5,233
WS        1,858       61          314

Table 15. Experiment setting for MNLI, RTE, and SST-2 (LR: learning rate, BSZ: batch size, NC: number of classes, TS: total number of training steps, WS: warm-up steps).

Model     ROUGE (R1/R2/RL)    FLOPs (%)
BART      44.16/21.28/40.90   100
+ Ours    44.31/21.18/41.01   61.1
GPT-2     37.55/15.53/25.81   100
+ Ours    37.76/15.68/25.93   74.5
T5        42.05/20.34/39.40   100
+ Ours    41.98/20.38/39.61   74.5
LLaMA     -/-/46.68           100
+ Ours    -/-/46.73           77.6

Table 16. The proposed method for different generation models on CNN/DailyMail.

³ https://github.com/facebookresearch/fairseq/tree/main/examples/bart
⁴ https://github.com/huggingface/transformers
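
The training settings reported in A.3 and Table 15 can be gathered into a small configuration sketch. This is an illustration rather than the authors' training code: the names `TrainSetting`, `warmup_fraction`, and `lr_at` are hypothetical, and the linear decay after warm-up is an assumption, since the text states only the peak learning rate, the total number of steps, and the number of warm-up steps.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainSetting:
    """Hyperparameters as reported in A.3 and Table 15."""
    lr: float          # peak learning rate (LR)
    total_steps: int   # total training steps (TS)
    warmup_steps: int  # warm-up steps (WS)


SETTINGS = {
    "summarization": TrainSetting(lr=3e-5, total_steps=50_000, warmup_steps=500),
    "MNLI": TrainSetting(lr=5e-6, total_steps=30_968, warmup_steps=1_858),
    "RTE": TrainSetting(lr=1e-5, total_steps=1_018, warmup_steps=61),
    "SST-2": TrainSetting(lr=5e-6, total_steps=5_233, warmup_steps=314),
}


def warmup_fraction(s: TrainSetting) -> float:
    """Share of training spent in warm-up."""
    return s.warmup_steps / s.total_steps


def lr_at(step: int, s: TrainSetting) -> float:
    """Linear warm-up to the peak rate, then linear decay to zero.

    The decay shape is an assumption; the source does not specify it.
    """
    if step < s.warmup_steps:
        return s.lr * step / s.warmup_steps
    remaining = max(s.total_steps - step, 0)
    return s.lr * remaining / (s.total_steps - s.warmup_steps)


for name, s in SETTINGS.items():
    print(f"{name}: warm-up = {warmup_fraction(s):.1%} of training")
```

Under the reported numbers, the three GLUE tasks warm up for roughly 6% of training, while the summarization setting warms up for 1%.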