TRL (Transformer Reinforcement Learning) est une bibliothèque de pointe développée par HuggingFace, spécialement conçue pour le post-entraînement de modèles de base à l'aide de technologies avancées. Cette bibliothèque est conçue pour le post-entraînement de modèles de base, en utilisant des techniques avancées telles que le fine-tuning supervisé (SFT), l'optimisation de politique proximale (PPO) et l'optimisation directe des préférences (DPO).
TRL propose plusieurs entraîneurs facilement accessibles :
Fournit une interface CLI simple, permettant le fine-tuning des modèles sans écrire de code.
from trl import SFTTrainer
from datasets import load_dataset
dataset = load_dataset("trl-lib/Capybara", split="train")
trainer = SFTTrainer(
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)
trainer.train()
L'algorithme GRPO est plus économe en mémoire que PPO et a été utilisé pour entraîner le modèle R1 de Deepseek AI :
from datasets import load_dataset
from trl import GRPOTrainer
dataset = load_dataset("trl-lib/tldr", split="train")
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
DPO est un algorithme populaire qui a été utilisé pour le post-entraînement de modèles tels que Llama 3 :
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer
)
trainer.train()
from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
"Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward")
trainer = RewardTrainer(
args=training_args,
model=model,
processing_class=tokenizer,
train_dataset=dataset,
)
trainer.train()
pip install trl
pip install git+https://github.com/huggingface/trl.git
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .[dev]
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/Capybara \
--output_dir Qwen2.5-0.5B-SFT
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--dataset_name argilla/Capybara-Preferences \
--output_dir Qwen2.5-0.5B-DPO
TRL est une bibliothèque puissante et facile à utiliser qui fournit aux chercheurs et aux développeurs un ensemble d'outils complet pour entraîner et optimiser de grands modèles de langage. Elle combine les dernières technologies d'apprentissage par renforcement et les avantages de l'écosystème HuggingFace, rendant l'entraînement de modèles de haute qualité plus accessible et efficace. Que ce soit pour la recherche académique ou les applications industrielles, TRL est le choix idéal pour le post-entraînement des modèles Transformer.