PPO算法

PPO算法详解与应用 | 强化学习策略优化

元数据

核心观点

Proximal Policy Optimization (PPO) 是一种用于强化学习的策略优化算法,通过限制策略更新的变化幅度,提升了算法的稳定性和效率。PPO主要有两个版本:PPO-penalty和PPO-clip,其中PPO-clip因其简化的优化目标而广泛应用。它通过引入重要性采样和策略截断,解决了样本利用效率低的问题。

重点段落

PPO算法的基本原理

PPO基于TRPO的优化目标进行了简化,直接对策略更新进行clip操作,限制更新幅度在一个安全范围内。其优化目标如下:

maxθEsνβ,aπθk(|s)[min(πθk(a|s)πθ(a|s)Aπθk(s,a),clip(πθk(a|s)πθ(a|s),1ϵ,1+ϵ)Aπθk(s,a))]

其中,πθ为Actor的策略,Aπθk为Critic提供的优势函数估计。

重要性采样技术

PPO利用重要性采样提高样本利用效率。假设有一个提议分布q(x),可以将关于目标分布p(x)的期望重写为:

Ep[h(x)]=h(x)p(x)q(x)q(x)dx

这样可以利用之前迭代的策略产生的数据。

策略截断与稳定性

PPO通过clip操作进行策略截断,避免策略更新偏离上一个迭代回合的策略。相比TRPO复杂的KL散度估计,PPO采用固定截断,更为简便。

操作步骤

  1. 策略初始化:设定初始策略参数θ
  2. 采样数据:从当前策略中采样动作和状态。
  3. 计算优势:利用Critic计算优势函数Aπθk(s,a)
  4. 更新策略:根据PPO-clip优化目标更新策略参数。

常见错误

警告:在使用PPO时,需注意clip范围的设置不宜过大,否则会导致策略更新过于激进。

💡启发点

class PolicyNet(torch.nn.Module): 
    def __init__(self, state_dim, hidden_dim, action_dim): 
        super(PolicyNet, self).__init__() 
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim) 
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim) 

    def forward(self, x): 
        x = F.relu(self.fc1(x)) 
        return F.softmax(self.fc2(x), dim=1) 

class ValueNet(torch.nn.Module): 
    def __init__(self, state_dim, hidden_dim): 
        super(ValueNet, self).__init__() 
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim) 
        self.fc2 = torch.nn.Linear(hidden_dim, 1) 

    def forward(self, x): 
        x = F.relu(self.fc1(x)) 
        return self.fc2(x) 

class PPO: 
    ''' PPO算法,采用截断方式 ''' 
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, 
                 lmbda, epochs, eps, gamma, device): 
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device) 
        self.critic = ValueNet(state_dim, hidden_dim).to(device) 
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 
                                                lr=actor_lr) 
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 
                                                 lr=critic_lr) 
        self.gamma = gamma 
        self.lmbda = lmbda 
        self.epochs = epochs # 一条序列的数据用来训练轮数 
        self.eps = eps # PPO中截断范围的参数 
        self.device = device 

    def take_action(self, state): 
        state = torch.tensor([state], dtype=torch.float).to(self.device) 
        probs = self.actor(state) 
        action_dist = torch.distributions.Categorical(probs) 
        action = action_dist.sample() 
        return action.item() 

    def update(self, transition_dict): 
        states = torch.tensor(transition_dict['states'], 
                              dtype=torch.float).to(self.device) 
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to( 
            self.device) 
        rewards = torch.tensor(transition_dict['rewards'], 
                               dtype=torch.float).view(-1, 1).to(self.device) 
        next_states = torch.tensor(transition_dict['next_states'], 
                                   dtype=torch.float).to(self.device) 
        dones = torch.tensor(transition_dict['dones'], 
                             dtype=torch.float).view(-1, 1).to(self.device) 

        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        
        self.actor_optimizer.step() 
        self.critic_optimizer.step()

行动清单

📈趋势预测

随着强化学习在各领域的应用扩展,PPO因其稳定性和高效性,将在更多复杂任务中得到应用。

后续追踪

原始出处:PPO算法和代码