TRL(Transformer Reinforcement Learning)は、HuggingFaceが開発した最先端のライブラリであり、高度な技術を用いて基盤モデルを事後学習するために特化しています。このライブラリは、教師あり微調整(SFT)、近接方策最適化(PPO)、直接選好最適化(DPO)などの先進技術を使用して、基盤モデルを事後学習するように設計されています。
TRLは、アクセスしやすい様々なトレーナーを提供します。
コードを書かずにモデルを微調整できるシンプルなCLIインターフェースを提供します。
from trl import SFTTrainer
from datasets import load_dataset
dataset = load_dataset("trl-lib/Capybara", split="train")
trainer = SFTTrainer(
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)
trainer.train()
GRPOアルゴリズムはPPOよりもメモリを節約でき、Deepseek AIのR1モデルの訓練に使用されました。
from datasets import load_dataset
from trl import GRPOTrainer
dataset = load_dataset("trl-lib/tldr", split="train")
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
DPOは、Llama 3などのモデルの事後学習に使用された人気のあるアルゴリズムです。
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer
)
trainer.train()
from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
"Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward")
trainer = RewardTrainer(
args=training_args,
model=model,
processing_class=tokenizer,
train_dataset=dataset,
)
trainer.train()
pip install trl
pip install git+https://github.com/huggingface/trl.git
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .[dev]
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/Capybara \
--output_dir Qwen2.5-0.5B-SFT
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--dataset_name argilla/Capybara-Preferences \
--output_dir Qwen2.5-0.5B-DPO
TRLは、研究者と開発者に大規模言語モデルを訓練および最適化するための完全なツールセットを提供する、強力で使いやすいライブラリです。最新の強化学習技術とHuggingFaceエコシステムの利点を組み合わせることで、高品質なモデル訓練がよりアクセスしやすく、効率的になります。学術研究であろうと産業応用であろうと、TRLはTransformerモデルの事後学習を行うための理想的な選択肢です。