Fine-tuning Llama and Mistral models with Hugging Face TRL is, perhaps surprisingly, less about "teaching" the model something new and more about "guiding" the knowledge it already has.
Let’s see TRL in action. Imagine you have a base Llama 2 7B model and you want it to be better at summarizing technical documentation.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

# Load base model and tokenizer
model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Add a padding token if the tokenizer lacks one, then resize the embedding
# matrix so the new token id has a row
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model.resize_token_embeddings(len(tokenizer))

# Load a dataset (replace with your actual summarization dataset)
dataset = load_dataset("xsum", split="train")
dataset = dataset.select(range(1000))  # Use a subset for demonstration

# Define training arguments
training_args = TrainingArguments(
    output_dir="./llama-2-7b-summarization",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    num_train_epochs=1,
    logging_steps=10,
    save_steps=100,
    fp16=True,
)

# Initialize SFTTrainer
# Note: depending on your TRL version, dataset_text_field, max_seq_length,
# packing, and tokenizer may need to be set elsewhere (for example on SFTConfig)
# or under a different name. Check the documentation for your installed release.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    # "document" is xsum's article column. To train on summaries, point this at
    # a column that contains the prompt and the target summary together.
    dataset_text_field="document",
    max_seq_length=512,
    tokenizer=tokenizer,
    args=training_args,
    packing=False,  # Set to True to pack short sequences together
)

# Start training
trainer.train()

# Save the fine-tuned model
trainer.save_model("./fine-tuned-llama-2-7b-summarization")
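To get a feel for what these settings mean in practice, here is a rough back-of-the-envelope sketch of how many optimizer updates the run above performs. It assumes a single GPU, packing=False, and that a final partial batch is kept; the helper name optimizer_steps is made up for illustration and is not part of TRL or transformers.

```python
import math


def optimizer_steps(num_examples, per_device_batch, grad_accum,
                    num_devices=1, epochs=1):
    """Estimate the number of weight updates for a training run.

    Hypothetical helper: approximates the step count, it does not query
    the Trainer itself.
    """
    # Forward/backward passes per epoch across all devices
    micro_batches = math.ceil(num_examples / (per_device_batch * num_devices))
    # Gradients are accumulated over grad_accum passes before each update
    steps_per_epoch = math.ceil(micro_batches / grad_accum)
    return steps_per_epoch * epochs


# Values taken from the script above: 1000 examples, batch 4, accumulation 2
effective_batch = 4 * 2 * 1  # examples contributing to each weight update
steps = optimizer_steps(1000, per_device_batch=4, grad_accum=2)
print(effective_batch, steps)  # 8 125
```

With save_steps=100, a run of 125 steps writes one intermediate checkpoint at step 100, and logging_steps=10 gives roughly a dozen loss readings along the way.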
SFTTrainer, from TRL (Transformer Reinforcement Learning), is the core of this script. It takes your base LLM and a dataset of examples, then adjusts the model’s weights with ordinary supervised learning. The "reinforcement learning" in the library’s name only comes into play later, if you use PPO or another RL-based method for alignment. The dataset_text_field is crucial: it tells TRL which column in your dataset contains the text to be trained on, so for summarization that column needs to include the summary and not just the source document. max_seq_length caps the number of tokens per training example; longer examples are truncated. packing=True can significantly speed up training by concatenating multiple short examples into fixed-length sequences, which wastes fewer tokens on padding, but unrelated examples then share a sequence, so attention across example boundaries needs careful handling.
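The idea behind packing can be illustrated with a few lines of plain Python. This is a simplified sketch of the concept, not TRL’s actual implementation: the function name pack_sequences, the toy token ids, and the choice to drop the trailing remainder are all assumptions made for the example.

```python
def pack_sequences(tokenized_examples, eos_id, block_size):
    """Concatenate token-id lists, separated by EOS, into fixed-size blocks.

    Illustrative only: a real trainer may handle the remainder and the
    attention mask differently.
    """
    stream = []
    for ids in tokenized_examples:
        stream.extend(ids)
        stream.append(eos_id)  # marks where one example ends
    # Keep only whole blocks; the trailing remainder is dropped
    n_blocks = len(stream) // block_size
    return [stream[i * block_size:(i + 1) * block_size] for i in range(n_blocks)]


examples = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]  # toy token ids
blocks = pack_sequences(examples, eos_id=0, block_size=4)
print(blocks)  # [[5, 6, 7, 0], [8, 9, 0, 10], [11, 12, 13, 0]]
```

Notice that the second block contains the end of one example and the start of another. Without extra masking, tokens from the later example can attend to the earlier one, which is the attention-mask caveat mentioned above.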
The problem TRL solves is making LLMs perform specific tasks better without needing massive, task-agnostic pre-training. Instead of retraining a model from scratch on billions of tokens for summarization, you take a pre-trained model (like Llama 2 or Mistral) and show it a few thousand examples of good summaries. The model already understands language; you’re just teaching it the style and content focus of your target task. The SFTTrainer automates the process of formatting data, feeding it to the model, and updating weights based on the difference between the model’s output and the desired output (your ground truth labels).
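For a summarization dataset such as xsum, which has separate document and summary columns, one way to "show" the model good summaries is to merge each pair into a single training string. The sketch below is one possible approach under stated assumptions: the prompt wording, the function names build_training_text and add_text_column, and the "text" column name are all hypothetical, and "</s>" is used as the end-of-sequence marker because that is Llama 2’s EOS token.

```python
def build_training_text(document, summary, eos_token="</s>"):
    """Join a source document and its summary into one training string.

    The template is an illustrative choice, not something TRL prescribes.
    """
    prompt = (
        "Summarize the following document.\n\n"
        + document.strip()
        + "\n\nSummary: "
    )
    # Ending with EOS teaches the model where a summary should stop
    return prompt + summary.strip() + eos_token


def add_text_column(example):
    """Row-level mapper that returns a new 'text' field for one example."""
    return {"text": build_training_text(example["document"], example["summary"])}


row = {"document": " A long technical report. ", "summary": "Short version."}
print(add_text_column(row)["text"])
```

Applied with dataset.map(add_text_column), this would give every row a "text" column that dataset_text_field could point to in place of "document".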
Internally, SFTTrainer handles tokenization, batching, and the forward and backward passes. At every position in an example, the model predicts the next token from the preceding context, and the loss (cross-entropy) measures how much probability it assigned to the token that actually comes next in your training data. Think of it as a game in which the model guesses the next word of your summary and is penalized less the more confident it was in the right word. Over many examples, it learns to generate summaries that look like yours. The TrainingArguments control the hyperparameters: learning_rate sets how large each weight update is, num_train_epochs is how many passes the model makes over the entire dataset, per_device_train_batch_size is how many examples each GPU processes per forward pass, and gradient_accumulation_steps is how many such batches are accumulated before each weight update.
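The next-token objective described here can be written out in plain Python for a toy vocabulary. This is a minimal sketch of mean cross-entropy over positions, with a made-up function name (next_token_loss) and hand-picked logits; a real trainer computes the same quantity on tensors over the full vocabulary.

```python
import math


def next_token_loss(logits_per_position, target_ids):
    """Mean cross-entropy between predictions and the true next tokens."""
    total = 0.0
    for logits, target in zip(logits_per_position, target_ids):
        m = max(logits)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_z - logits[target]  # equals -log softmax(logits)[target]
    return total / len(target_ids)


# Toy setup: vocabulary of 3 tokens, 2 positions, true next tokens 0 then 1
confident = next_token_loss([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0]], [0, 1])
uniform = next_token_loss([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [0, 1])
print(confident < uniform)  # True
```

A model that spreads its probability evenly over the three tokens gets a loss of log(3), about 1.0986, while one that strongly favors the correct token gets a loss close to zero. Training pushes the weights toward the second case.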
The most surprising thing is how little data is often needed for significant task improvement. With a large, capable base model, you can achieve remarkable results on specialized tasks with just a few thousand high-quality examples. This drastically reduces the cost and complexity compared to full pre-training. It’s about leveraging the immense general knowledge already encoded in the base model and nudging it towards your specific domain or style.
The next step after supervised fine-tuning is often aligning the model’s outputs with human preferences using techniques like Reinforcement Learning from Human Feedback (RLHF), which TRL also supports through its PPOTrainer.