Large language models (LLMs) are increasingly vital in natural language processing and understanding, thanks to their effectiveness and versatility. However, their deployment is resource-intensive. NVIDIA researchers have demonstrated that combining structured weight pruning with knowledge distillation can efficiently create smaller, cost-effective language models, according to the NVIDIA Technical Blog.
Pruning and Distillation
Pruning reduces the model size by dropping layers (depth pruning) or neurons, attention heads, and embedding channels (width pruning). This process is often followed by retraining to recover accuracy. Model distillation transfers knowledge from a larger, complex model (teacher) to a smaller, simpler model (student), aiming to retain much of the original model’s predictive power while being faster and less resource-intensive.
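To make width pruning concrete, here is a minimal sketch of dropping the least-important neurons from a single layer's weight matrix. The L2-norm scoring used here is illustrative, not NVIDIA's actual criterion:

```python
import numpy as np

def width_prune(weight: np.ndarray, keep: int) -> np.ndarray:
    """Keep the `keep` output neurons (rows) with the largest L2 norm.

    Toy example only: real importance scoring is activation-based.
    """
    scores = np.linalg.norm(weight, axis=1)    # one score per output neuron
    top = np.sort(np.argsort(scores)[-keep:])  # indices to retain, in original order
    return weight[top]

w = np.arange(12, dtype=float).reshape(4, 3)   # 4 neurons, 3 inputs each
pruned = width_prune(w, keep=2)
print(pruned.shape)  # (2, 3)
```

Pruning a row here shrinks the layer's output dimension; in a real network, the following layer's input dimension must be trimmed to match.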
Classical Knowledge Distillation vs. SDG Finetuning
Distillation can be categorized into two main styles:
- SDG (synthetic data generation) finetuning: Uses synthetic data generated by a larger teacher model to fine-tune a smaller, pretrained student model, which mimics only the teacher's final token predictions.
- Classical Knowledge Distillation: The student mimics the logits and other intermediate states of the teacher on the training dataset, providing richer feedback and improving training accuracy and efficiency.
These methods are complementary, and NVIDIA’s approach focuses on classical knowledge distillation.
Pruning and Distillation Procedure
NVIDIA’s procedure involves:
- Starting with a 15B model, estimating the importance of each component, and trimming it to an 8B model.
- Applying light retraining using model distillation, with the original model as the teacher and the pruned model as the student.
- Further trimming and distilling the small 8B model to a 4B model.
This iterative approach ensures that the output model of one stage serves as the input model for the next, optimizing resource efficiency.
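The staged pipeline above can be sketched as a loop in which each stage's distilled output becomes the next stage's input and teacher. The models are stand-in dicts and `prune_to`/`distill` are hypothetical placeholders, not NVIDIA's actual API:

```python
def prune_to(model, billions):
    # Trim the current model down to the target parameter count.
    return {"params_b": billions, "teacher": model["params_b"]}

def distill(teacher, student):
    # Light retraining of the pruned student against its teacher.
    student["distilled_from"] = teacher["params_b"]
    return student

model = {"params_b": 15}
for target in (8, 4):                 # 15B -> 8B -> 4B, as in the procedure above
    student = prune_to(model, target)
    model = distill(model, student)
print(model)  # {'params_b': 4, 'teacher': 8, 'distilled_from': 8}
```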
Importance Analysis
Understanding which parts of the model are crucial is essential for pruning. NVIDIA proposes an activation-based importance estimation strategy, which is cost-effective and straightforward compared to gradient-based strategies.
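As a rough illustration of activation-based scoring, one can rank hidden units by their mean absolute activation over a small calibration batch; the exact statistic NVIDIA uses may differ, and all shapes here are made up:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 16))            # calibration batch: 32 samples, 16 dims
W = rng.normal(size=(16, 8))             # layer with 8 hidden units
acts = np.maximum(X @ W, 0.0)            # ReLU activations, shape (32, 8)
importance = np.abs(acts).mean(axis=0)   # one score per hidden unit
order = np.argsort(importance)[::-1]     # most important first
print(order[:4])                         # indices of the top-4 units to keep
```

Because this only needs forward passes over a calibration set, it avoids the backward passes that gradient-based importance estimation requires.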
Retraining with Classical Knowledge Distillation
Retraining involves minimizing a combination of embedding output loss, logit loss, and transformer encoder-specific losses, ensuring the smaller model retains much of the original model’s accuracy.
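A minimal sketch of such a combined objective: temperature-scaled KL divergence on the logits plus mean-squared error on one intermediate hidden state. The weighting `alpha` and temperature `T` are illustrative hyperparameters, not values from NVIDIA's recipe:

```python
import numpy as np

def softmax(z, T):
    z = z / T
    e = np.exp(z - z.max(axis=-1, keepdims=True))  # numerically stable
    return e / e.sum(axis=-1, keepdims=True)

def distill_loss(t_logits, s_logits, t_hidden, s_hidden, T=2.0, alpha=0.5):
    """Logit distillation (KL) combined with intermediate-state distillation (MSE)."""
    p, q = softmax(t_logits, T), softmax(s_logits, T)
    kl = (p * (np.log(p) - np.log(q))).sum(axis=-1).mean() * T * T
    mse = ((t_hidden - s_hidden) ** 2).mean()
    return alpha * kl + (1 - alpha) * mse
```

When the student matches the teacher exactly, both terms vanish; in practice the student is trained to drive this loss down over the retraining corpus.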
Best Practices for Pruning and Distillation
NVIDIA’s extensive studies have identified several best practices:
- Sizing: Train the largest model first, then prune and distill iteratively.
- Pruning: Prefer width pruning over depth pruning for models ≤ 15B.
- Retraining: Use distillation loss exclusively, combining logit and intermediate state distillation when necessary.
Llama-3.1-Minitron: Applying Best Practices
NVIDIA applied these practices to the Llama 3.1 8B model, resulting in the efficient Llama-3.1-Minitron 4B model. This model performs favorably against state-of-the-art open-source models of similar size, such as Phi-2 2.7B and Gemma2 2.6B.
Teacher Fine-Tuning
Fine-tuning the unpruned 8B model on a specific dataset corrects for distribution shifts, ensuring optimal guidance during distillation.
Depth-Only and Width-Only Pruning
For depth-only pruning, NVIDIA pruned 16 layers from the 8B model, focusing on layers whose removal least affected downstream task performance. For width-only pruning, they reduced the MLP intermediate dimension and hidden size while retaining the attention head count and number of layers, then retrained the pruned model.
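Depth pruning amounts to deleting whole transformer layers and stitching the rest together. A toy sketch, where the choice of which 16 layers to cut is hypothetical (NVIDIA selects them empirically by measuring downstream impact):

```python
layers = list(range(32))           # Llama 3.1 8B has 32 transformer layers
drop = set(range(8, 24))           # hypothetical contiguous 16-layer cut
kept = [l for l in layers if l not in drop]
print(len(kept))  # 16
```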
Accuracy and Performance Benchmarks
Table 1 compares the performance of Llama-3.1-Minitron 4B variants to other models, showing significant improvements in accuracy and resource efficiency. Performance benchmarks indicate that the Llama-3.1-Minitron 4B model achieves roughly 2.7x the average throughput of the original 8B model.
Conclusion
Combining pruning and classical knowledge distillation offers a cost-effective method to create smaller LLMs with superior accuracy compared to training from scratch. NVIDIA’s Llama-3.1-Minitron 4B model exemplifies this approach, providing a robust solution for efficient language model deployment.