TRL(Transformer Reinforcement Learning)은 HuggingFace에서 개발한 최첨단 라이브러리로, 고급 기술을 사용하여 기초 모델을 후속 훈련하는 데 특화되어 있습니다. 이 라이브러리는 감독 미세 조정(SFT), 근접 정책 최적화(PPO) 및 직접 선호도 최적화(DPO)와 같은 고급 기술을 사용하여 기초 모델을 후속 훈련하도록 설계되었습니다.
TRL은 다양한 접근하기 쉬운 트레이너를 제공합니다.
코드를 작성하지 않고도 모델 미세 조정을 수행할 수 있는 간단한 CLI 인터페이스를 제공합니다.
from trl import SFTTrainer
from datasets import load_dataset
dataset = load_dataset("trl-lib/Capybara", split="train")
trainer = SFTTrainer(
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)
trainer.train()
GRPO 알고리즘은 PPO보다 메모리를 절약하며, Deepseek AI의 R1 모델 훈련에 사용되었습니다.
from datasets import load_dataset
from trl import GRPOTrainer
dataset = load_dataset("trl-lib/tldr", split="train")
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
DPO는 Llama 3와 같은 모델을 후속 훈련하는 데 사용된 인기 있는 알고리즘입니다.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer
)
trainer.train()
from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
"Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward")
trainer = RewardTrainer(
args=training_args,
model=model,
processing_class=tokenizer,
train_dataset=dataset,
)
trainer.train()
pip install trl
pip install git+https://github.com/huggingface/trl.git
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .[dev]
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/Capybara \
--output_dir Qwen2.5-0.5B-SFT
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--dataset_name argilla/Capybara-Preferences \
--output_dir Qwen2.5-0.5B-DPO
TRL은 강력하고 사용하기 쉬운 라이브러리로, 연구원과 개발자에게 대규모 언어 모델을 훈련하고 최적화하는 데 필요한 완벽한 도구 세트를 제공합니다. 최신 강화 학습 기술과 HuggingFace 생태계의 장점을 결합하여 고품질 모델 훈련을 더욱 접근 가능하고 효율적으로 만듭니다. 학술 연구든 산업 응용이든 TRL은 Transformer 모델 후속 훈련을 위한 이상적인 선택입니다.