Improving Inference Efficiency through Pruning of Parallel Reasoning Chains

A comprehensive analysis of adaptive computation allocation strategies for inference optimization

Vijay Kumaravel, David Bai, Balaji Kumaravel
May 2024

Abstract

Inference-time scaling has proven extremely effective at improving LLM performance in domains with a generator-verifier gap, where generating candidate solutions is much harder than verifying their correctness. Several popular methodologies for scaling inference compute have been explored; widely used approaches include Reinforcement Learning to elicit long chains of thought for self-correction, and generating multiple candidate solutions and selecting the best one (best-of-n). Combining these methodologies has proven highly effective, boosting key benchmark results in competitive coding (IOI for o3) and mathematics (FrontierMath, AIME).

This paper explores a more inference-efficient approach to scaling best-of-n for reasoning models through parallel reasoning: pruning reasoning chains early when they do not contribute to candidate-solution diversity. Our experiments on the AIME competition math benchmark demonstrate that our method matches full pass@50 performance while pruning 40 of the 50 reasoning chains after only 300 tokens and decoding just 10 chains to completion.

Model

For our experiments, we selected DeepSeek-R1-Distill-Llama-70B for two key reasons:

  1. It is a distillation of the full DeepSeek-R1-671B, offering similar performance on reasoning benchmarks while being much more efficient to serve, making it a de facto cost-effective reasoning model in industry.
  2. It allowed for more extensive data collection under the constraints of our 24-hour hackathon with an 8xH100 node.

Metrics

The primary metric in this work is pass@k, which measures the fraction of benchmark problems for which the model produces at least one correct answer within k attempts.
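
For reference, below is a minimal Python sketch of the standard unbiased pass@k estimator (Chen et al., 2021). In our evaluation, where every sampled attempt is scored, this reduces to checking whether at least one of the k attempts is correct.

```python
import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k: probability that at least one of k attempts,
    drawn without replacement from n samples of which c are correct,
    is correct (Chen et al., 2021)."""
    if n - c < k:
        return 1.0  # too few incorrect samples to fill k draws without a correct one
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

# Example: 50 samples with 3 correct -> chance that a random subset of 10 contains a correct one
print(pass_at_k(n=50, c=3, k=10))  # ~0.50
```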

Benchmark

Our experiments focused on the AIME competition math benchmark for several reasons:

  1. Competition mathematics problems elicit long reasoning chains, which was crucial for determining the optimal pruning point and demonstrating maximum compute savings.
  2. Competition math is not completely saturated by our chosen model, providing clearer signals about whether our method preserves performance.
  3. AIME problems have single numerical answers, simplifying extraction and verification for our evaluation pipelines.

Infrastructure

We conducted our experiments on an 8xH100 node, generously provided by CoreWeave and North Flank for 24 hours. We used vLLM as the inference engine to run DeepSeek-R1-Distill-Llama-70B with a temperature of 0.7, top-p of 0.9, and 50 samples per prompt. We collected as much data as possible within the time constraints to increase confidence in our results.
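
A minimal sketch of this generation setup using vLLM's offline API; the Hugging Face model id and prompt handling are assumptions, while the sampling values match those above.

```python
# Sketch of batched generation with vLLM (model id and prompts are placeholders).
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-R1-Distill-Llama-70B",  # assumed HF repo id
    tensor_parallel_size=8,                              # one shard per H100
)

sampling = SamplingParams(
    n=50,               # 50 candidate reasoning chains per prompt
    temperature=0.7,
    top_p=0.9,
    max_tokens=23_000,  # cap chosen from the ~23k median chain length
)

prompts = ["<AIME problem statement>"]  # placeholder
outputs = llm.generate(prompts, sampling)
chains = [completion.text for completion in outputs[0].outputs]
```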

Pipeline

Our experimental pipeline consisted of the following steps:

  1. Generate 50 reasoning chains per AIME problem using DeepSeek-R1-Distill-Llama-70B.
  2. Based on an initial experimental pass, we found the median response length for AIME problems was roughly 23k tokens, so we set the maximum number of output tokens to 23k for efficient batch inference.
  3. We collected data for all problems in the 2023 AIME part 2 and the 2024 AIME part 1, along with a validation set from 2023 AIME part 1.
  4. We used GPT-4o to extract the final answer from each response and determined the baseline pass@50 performance (see the extraction sketch after this list). With 50 attempts, the model had at least one correct answer for 26/29 problems (89%).
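
The extraction step can be sketched roughly as follows; the prompt wording and helper name are illustrative rather than the exact pipeline code.

```python
# Sketch of GPT-4o answer extraction (prompt and helper name are illustrative).
from openai import OpenAI

client = OpenAI()

def extract_answer(response_text: str) -> str:
    """Ask GPT-4o to pull the final numeric AIME answer out of a reasoning chain."""
    completion = client.chat.completions.create(
        model="gpt-4o",
        temperature=0,
        messages=[
            {"role": "system",
             "content": "Extract the final AIME answer (an integer from 0 to 999) "
                        "from the solution below. Reply with the number only, "
                        "or NONE if no final answer is given."},
            {"role": "user", "content": response_text},
        ],
    )
    return completion.choices[0].message.content.strip()
```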

Figure: Distribution of token lengths across model responses, showing a median of 22,043 tokens.

After collecting this data, we proceeded with the following additional steps:

  1. We chunked the reasoning chains into sections of 300 tokens using the cl100k_base tokenizer.
  2. These chunks were fed into OpenAI's text-embedding-3-large model to generate embeddings for clustering.
  3. We conducted a hyperparameter sweep over the pruning point and the number of chains to retain, testing 7 retention values (10, 15, 20, 25, 30, 35, 40) and evaluating pruning at every chunk up to the halfway point of the reasoning chain. A sketch of this chunk-embed-cluster-prune step follows this list.
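
Below is a hedged sketch of the chunk-embed-cluster-prune step. It simplifies in two ways: each chain's prefix is embedded as a single input rather than chunk by chunk, and the selection rule (keep the chain nearest each k-means centroid) is an assumption consistent with the k-means clustering described under Limitations and Future Work.

```python
# Sketch of diversity-preserving pruning over early reasoning-chain prefixes.
# Simplifications and names here are illustrative, not the exact pipeline code.
import numpy as np
import tiktoken
from openai import OpenAI
from sklearn.cluster import KMeans

client = OpenAI()
enc = tiktoken.get_encoding("cl100k_base")

def prefix_embedding(chain: str, num_chunks: int, chunk_tokens: int = 300) -> np.ndarray:
    """Embed the first num_chunks * 300 tokens of a reasoning chain."""
    prefix = enc.decode(enc.encode(chain)[: num_chunks * chunk_tokens])
    resp = client.embeddings.create(model="text-embedding-3-large", input=prefix)
    return np.array(resp.data[0].embedding)

def prune_chains(chains: list[str], keep: int, num_chunks: int) -> list[int]:
    """Cluster chain prefixes into `keep` groups and retain the chain closest
    to each centroid, so the surviving chains stay diverse."""
    X = np.stack([prefix_embedding(c, num_chunks) for c in chains])
    km = KMeans(n_clusters=keep, n_init=10, random_state=0).fit(X)
    kept = {int(np.argmin(np.linalg.norm(X - centroid, axis=1)))
            for centroid in km.cluster_centers_}
    return sorted(kept)

# Example sweep cell: prune 50 chains down to 10 using only the first chunk (index 0)
# survivors = prune_chains(chains, keep=10, num_chunks=1)
```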

Results

The results of our hyperparameter sweep are presented in the table below. Each cell reports the percentage of problems for which at least one retained chain reached a correct answer, as a function of the chunk index at which pruning occurs (rows) and the number of chains retained (columns):

| Chunk index | 10 chains | 15 chains | 20 chains | 25 chains | 30 chains | 35 chains | 40 chains |
|-------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|
| 0 | 86% | 86% | 89% | 89% | 86% | 86% | 86% |
| 1 | 86% | 86% | 86% | 86% | 89% | 89% | 89% |
| 2 | 86% | 89% | 89% | 89% | 89% | 89% | 89% |
| 3 | 86% | 86% | 86% | 86% | 89% | 89% | 89% |
| 4 | 75% | 79% | 82% | 86% | 86% | 86% | 86% |
| 5 | 82% | 82% | 82% | 82% | 82% | 82% | 86% |
| 6 | 82% | 82% | 86% | 86% | 86% | 86% | 86% |
| 7 | 79% | 82% | 82% | 82% | 82% | 82% | 82% |
| 8 | 82% | 82% | 82% | 86% | 86% | 86% | 82% |
| 9 | 75% | 75% | 82% | 82% | 75% | 79% | 79% |

We found that using embeddings from just the first 1-3 chunks of the reasoning chains for clustering and pruning resulted in performance equal to allowing all 50 reasoning chains to decode to completion. While these results are subject to limitations in our benchmark scope, they are promising and warrant further exploration.

Comparison to Baseline

The appropriate baseline for comparison is the model's pass@10 performance, since its compute requirements are essentially equivalent to our pruned pass@50 approach: with a median chain length of roughly 22k tokens, decoding 10 chains to completion costs about 220k tokens, while our method adds only the 40 pruned prefixes of 300 tokens each (12k tokens, roughly 5% overhead) on top of the 10 chains it decodes to completion.

Figure: Baseline pass@k performance vs. the pruning strategy on AIME 2023 I and 2024 I, showing optimal compute efficiency at k=10.

Figure: Performance comparison on AIME 2023 II, showing optimal compute efficiency at k=10.

Limitations and Future Work

We plan to test this methodology on additional reasoning benchmarks with efficient verifiers, such as competitive coding and chess puzzles. We would also like to implement the method end to end and measure its performance on hardware against the baseline via FLOP utilization. Currently, reasoning chains are clustered with k-means; other ways to preserve diverse reasoning chains could include pruning on pairwise similarity scores such as BLEU, ROUGE, or BLEURT (sketched below).
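
As one illustration of that similarity-score direction, here is a hedged sketch of greedy diversity selection over 300-token chain prefixes using sacrebleu's sentence-level BLEU; the greedy heuristic and function name are assumptions, not an implemented baseline.

```python
# Sketch of an alternative pruning rule: greedily keep prefixes that are least
# similar (lowest max pairwise BLEU) to the prefixes already kept.
import sacrebleu

def select_diverse_prefixes(prefixes: list[str], keep: int) -> list[int]:
    """Return indices of `keep` prefixes chosen to minimize pairwise BLEU overlap."""
    kept = [0]  # seed with the first chain's prefix
    while len(kept) < min(keep, len(prefixes)):
        best_idx, best_sim = None, float("inf")
        for i in range(len(prefixes)):
            if i in kept:
                continue
            # Similarity to the closest already-kept prefix
            sim = max(sacrebleu.sentence_bleu(prefixes[i], [prefixes[j]]).score
                      for j in kept)
            if sim < best_sim:
                best_idx, best_sim = i, sim
        kept.append(best_idx)
    return sorted(kept)
```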