critic-model
深度学习中的Critic模型与价值函数优化
元数据:
- 分类:深度学习
- 标签:Critic模型, 价值函数, RLHF, Actor-Critic
- 日期:2025年4月12日
核心观点总结
Critic模型用于预测期望总收益
重点段落
-
Critic模型的初始化与更新
Critic模型可以通过多种方式初始化,例如与Actor共享部分参数或从Reward Model初始化。价值头层是一个简单的线性层,用于将原始输出映射为单一值。Critic模型通过TD-error更新,当前状态的预估收益对齐即时奖励加上下一步状态预估收益。 -
Critic Loss计算
Critic损失函数为:实际代码中使用GAE+value=returns,然后使用critic当前的输出values对齐returns。
-
代码实现中的关键步骤
在代码实现中,ValueLoss类用于计算损失。通过裁剪操作来对values进行处理,确保模型的稳定性。
技术术语转述
- TD-error:一种用于更新价值函数的方法,通过比较预期收益与实际收益来调整。
- GAE:广义优势估计,用于提高策略梯度方法的效率。
- Value Head层:一个简单的线性层,将复杂的输出映射为单一数值。
操作步骤
- ✅ 初始化Critic模型,可以选择与Actor共享参数或从Reward Model初始化。
- ⚠ 使用TD-error更新Critic模型,确保对当前状态的预估收益进行校正。
- ❗ 在代码实现中,使用裁剪操作来稳定values的计算。
常见错误
⚠ 在计算Critic损失时,忘记应用裁剪操作可能导致不稳定的训练过程。
代码示例
class ValueLoss(nn.Module):
def __init__(self, clip_eps: float = None) -> None:
super().__init__()
self.clip_eps = clip_eps
def forward(self, values: torch.Tensor, old_values: torch.Tensor, returns: torch.Tensor, action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.clip_eps is not None:
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - returns) ** 2
surr2 = (values - returns) ** 2
loss = torch.max(surr1, surr2)
else:
loss = (values - returns) ** 2
loss = masked_mean(loss, action_mask, dim=-1).mean()
return 0.5 * loss
💡启发点
在训练过程中,结合GAE和value能够有效提高Critic模型的稳定性和准确性。
行动清单
来源:原文内容基于深度学习与强化学习结合的技术文档。