Introduction (work in progress)
Comparison with PPO
| Aspect | PPO | GRPO |
|---|---|---|
| Value network | Relies on a value network (critic) roughly the same size as the policy model to estimate the advantage function. The critic must evaluate the state at every time step, which is computationally expensive and memory-intensive. | Drops the value network entirely and estimates advantages from relative rewards within a group of sampled responses. |
| Reward computation | Uses Generalized Advantage Estimation (GAE), which requires estimating each action's immediate reward plus the discounted sum of future rewards. | Samples a group of actions, scores them, and normalizes the scores within the group to obtain relative advantages. This is more direct and reduces reliance on a complex reward model. |
| Policy update mechanism | Constrains the size of each update by clipping the probability ratio (the clip operation), keeping changes to the policy distribution within a controlled range. | Adds a KL-divergence penalty directly to the loss function, giving finer-grained control over how far the policy moves per update. |
| Computational efficiency | Maintaining and updating the critic makes training costly; for large language models the process can become very slow. | Avoiding the critic markedly improves efficiency and lowers memory usage, making GRPO better suited to fine-tuning large language models. |
| Strengths | Ratio clipping effectively prevents overly aggressive policy updates, keeping training stable. PPO performs well across many reinforcement learning tasks and environment types. | Removing the critic sharply reduces compute and memory requirements and speeds up training. Group-relative rewards reduce the variance of policy updates, and the explicit KL constraint keeps the policy distribution stable. |
| Limitations | For large language models, PPO must maintain a critic as large as the policy model, incurring substantial memory and compute costs. Updates depend on single-action reward estimates, which can have high variance and hurt training stability. | GRPO must sample a group of responses per prompt, which can increase sampling cost. On some tasks it may be less stable than PPO, particularly when reward signals are sparse. |
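For reference, the objective introduced with GRPO in the DeepSeekMath paper can be written, in a simplified sequence-level form (the paper's version averages over tokens), as

$$
\mathcal{J}_{\text{GRPO}}(\theta)=\mathbb{E}\!\left[\frac{1}{G}\sum_{i=1}^{G}\min\!\left(\frac{\pi_\theta(o_i\mid q)}{\pi_{\theta_{\text{old}}}(o_i\mid q)}\hat{A}_i,\ \operatorname{clip}\!\left(\frac{\pi_\theta(o_i\mid q)}{\pi_{\theta_{\text{old}}}(o_i\mid q)},\,1-\varepsilon,\,1+\varepsilon\right)\hat{A}_i\right)\right]-\beta\,\mathbb{D}_{\mathrm{KL}}\!\left(\pi_\theta\,\big\|\,\pi_{\text{ref}}\right)
$$

where the group-relative advantage for response $o_i$ with reward $r_i$ is

$$
\hat{A}_i=\frac{r_i-\operatorname{mean}(r_1,\dots,r_G)}{\operatorname{std}(r_1,\dots,r_G)}.
$$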
A concrete example of GRPO in action:
Query: “What is 2 + 3?”
Step 1: LLM generates three answers.
1. “5”
2. “6”
3. “2 + 3 = 5”
Step 2: Each answer is scored.
1. “5” → 1 point (correct, no reasoning)
2. “6” → 0 points (incorrect)
3. “2 + 3 = 5” → 2 points (correct, w/ reasoning)
Step 3: Compute avg score for entire group.
Avg score = (1 + 0 + 2) / 3 = 1
Step 4: Compare each answer score to avg.
1. “5” → 0 (same as avg)
2. “6” → -1 (below avg)
3. “2 + 3 = 5” → 1 (above avg)
Step 5: Reinforce LLM to favor higher scores.
1. Favor responses like #3 (positive)
2. Maintain responses like #1 (neutral)
3. Avoid responses like #2 (negative)
This process is repeated, allowing the model to learn and improve over time.
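A minimal sketch of the group-relative advantage computation behind Steps 3–4 (the helper name group_relative_advantages is hypothetical; as published, GRPO also divides by the group's standard deviation, which the flag below enables):

```python
from statistics import mean, stdev

def group_relative_advantages(rewards, normalize_std=False):
    """Center each reward on the group mean (Steps 3-4 above).

    With normalize_std=True, also divide by the group's standard
    deviation, as in the original GRPO formulation.
    """
    mu = mean(rewards)
    advantages = [r - mu for r in rewards]
    if normalize_std:
        sigma = stdev(rewards)
        advantages = [a / sigma for a in advantages]
    return advantages

# Scores from the worked example: "5", "6", "2 + 3 = 5"
print(group_relative_advantages([1.0, 0.0, 2.0]))  # [0.0, -1.0, 1.0]
```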
Coding GRPO from Scratch: A Guide to Distributed Implementation with Qwen2.5-1.5B-Instruct
```python
def correctness_reward(prompts, completions, answer, **kwargs):
    """
    Assigns a reward based on the correctness of the model's answer.

    Explanation:
        1. Extracts the content from each completion.
        2. Extracts the answer portion from each response using
           extract_answer_from_model_output.
        3. Assigns rewards based on matching criteria:
           - 2.0 points for an exact match
           - 1.5 points for numeric equivalence (values match but format differs)
           - 0.0 points for incorrect answers
        4. Tracks completion lengths for analysis.
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted = [extract_answer_from_model_output(r) for r in responses]
    rewards = []
    for r, a in zip(extracted, answer):
        if r == a:  # Exact match case
            rewards.append(2.0)
        else:
            # Try numeric equivalence
            r_num = extract_single_number(str(r))
            a_num = extract_single_number(str(a))
            if r_num is not None and a_num is not None and r_num == a_num:
                rewards.append(1.5)
            else:
                rewards.append(0.0)
    # Log completion lengths for analysis
    completion_lengths = [len(response.split()) for response in responses]
    return rewards
```
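The snippet assumes two helpers defined elsewhere in the guide, extract_answer_from_model_output and extract_single_number. They are not shown here; minimal hypothetical sketches of what they might look like:

```python
import re

# Hypothetical reconstruction, not the guide's exact code.
def extract_answer_from_model_output(text):
    """Return the text between <answer> and </answer> tags, or None."""
    match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    return match.group(1).strip() if match else None

# Hypothetical reconstruction, not the guide's exact code.
def extract_single_number(text):
    """Return the sole number in text as a float, or None if there is
    not exactly one number present."""
    numbers = re.findall(r"-?\d+\.?\d*", text)
    return float(numbers[0]) if len(numbers) == 1 else None
```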
```python
def format_reward(completions, **kwargs):
    """
    Assigns a reward for adhering to the desired XML format.

    Explanation:
        1. Extracts the content from each completion.
        2. Evaluates format compliance by checking for required XML tags:
           - 0.2 points for each tag present (<reasoning>, </reasoning>,
             <answer>, </answer>)
           - Maximum score of 0.8 for perfect format compliance
        3. Stores and returns the format compliance scores.
    """
    responses = [completion[0]['content'] for completion in completions]
    rewards = []
    format_scores = []
    for response in responses:
        score = 0.0
        if "<reasoning>" in response: score += 0.2
        if "</reasoning>" in response: score += 0.2
        if "<answer>" in response: score += 0.2
        if "</answer>" in response: score += 0.2
        rewards.append(score)
        format_scores.append(score)
    return rewards
```
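A quick sanity check of both reward functions on a toy completion (this relies on the hypothetical helper sketches above; summing the two scores is one natural way to combine them into a single per-completion reward):

```python
completions = [[{'content': "<reasoning>2 + 3 = 5</reasoning><answer>5</answer>"}]]

correct = correctness_reward(prompts=None, completions=completions, answer=["5"])  # exact match -> 2.0
fmt = format_reward(completions=completions)  # all four tags present -> 0.8
combined = [c + f for c, f in zip(correct, fmt)]  # 2.8 total for this completion
print(combined)
```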
Hugging Face's GRPOTrainer inherits from the Trainer class, which already encapsulates much of the training-loop logic:
```python
from datasets import load_dataset
from trl import GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs="weqweasdas/RM-Gemma-2B",
    train_dataset=dataset,
)
trainer.train()
```
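GRPOTrainer also accepts plain Python callables as reward functions (extra dataset columns are forwarded to them as keyword arguments), so a hand-written reward can replace or accompany a reward model. A minimal sketch with a made-up brevity reward; note that with a plain-text dataset like tldr, each completion arrives as a string rather than a list of chat messages:

```python
# Hypothetical reward function, not part of TRL.
def brevity_reward(completions, **kwargs):
    """Toy reward that favors completions under 50 words."""
    return [1.0 if len(completion.split()) < 50 else 0.0 for completion in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[brevity_reward],  # callables work alongside or instead of reward models
    train_dataset=dataset,
)
trainer.train()
```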