Chat
Ask me anything
Ithy Logo

使用 PyTorch 和 PPO 算法训练 Gymnasium 的 CartPole-v1 环境的完整指南

Frontiers | Reinforcement Learning With Low-Complexity Liquid State ...

简介

强化学习(Reinforcement Learning, RL)是一种使代理通过与环境互动来学习最佳策略的方法。在众多强化学习算法中,近端策略优化(Proximal Policy Optimization, PPO)因其简单性和高效性而受到广泛应用。本指南将详细介绍如何使用 Python 和 PyTorch 库,通过 PPO 算法对 Gymnasium 的 CartPole-v1 环境进行训练。

安装所需库

首先,确保安装了必要的 Python 库。可以使用以下命令安装:

pip install gymnasium torch numpy

导入必要的库

在开始编写代码之前,需要导入以下库:

import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical

定义超参数

为了便于调整模型性能,我们定义了一些超参数:

# 超参数
NUM_EPISODES = 1000       # 训练回合数
MAX_STEPS = 200           # 每回合最大步数
GAMMA = 0.99              # 折扣因子
LAMBDA = 0.95             # GAE参数
CLIP_EPS = 0.2            # PPO剪切参数
LR = 0.001                # 学习率
UPDATE_EPOCHS = 10        # 更新次数
BATCH_SIZE = 64           # 批量大小

定义策略网络(Policy Network)

策略网络负责根据当前状态输出每个可能动作的概率分布。我们使用一个简单的全连接神经网络来实现策略网络。

class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.action_head = nn.Linear(128, action_dim)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        action_probs = torch.softmax(self.action_head(x), dim=-1)
        return action_probs

定义价值网络(Value Network)

价值网络用于估计当前状态的价值,以帮助计算优势函数。结构类似于策略网络,但输出为单一的价值估计。

class ValueNetwork(nn.Module):
    def __init__(self, state_dim):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.value_head = nn.Linear(128, 1)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        state_value = self.value_head(x)
        return state_value

定义 PPO 智能体

PPO 智能体包含策略网络和价值网络,并实现了选择动作、存储轨迹、计算优势以及更新网络的功能。

class PPOAgent:
    def __init__(self, state_dim, action_dim, lr, gamma, lambda_, clip_eps):
        self.gamma = gamma
        self.lambda_ = lambda_
        self.clip_eps = clip_eps
        
        self.policy_net = PolicyNetwork(state_dim, action_dim)
        self.value_net = ValueNetwork(state_dim)
        
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=lr)
        
        self.memory = []
    
    def select_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        action_probs = self.policy_net(state)
        m = Categorical(action_probs)
        action = m.sample()
        return action.item(), m.log_prob(action)
    
    def store_transition(self, transition):
        self.memory.append(transition)
    
    def compute_advantages(self, rewards, dones, values, next_value):
        advantages = []
        gae = 0
        values = values + [next_value]
        for step in reversed(range(len(rewards))):
            delta = rewards[step] + self.gamma * values[step + 1] * (1 - dones[step]) - values[step]
            gae = delta + self.gamma * self.lambda_ * (1 - dones[step]) * gae
            advantages.insert(0, gae)
        returns = [adv + val for adv, val in zip(advantages, values[:-1])]
        advantages = torch.tensor(advantages, dtype=torch.float32)
        returns = torch.tensor(returns, dtype=torch.float32)
        return advantages, returns
    
    def update(self):
        states, actions, log_probs, rewards, dones = zip(*self.memory)
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        old_log_probs = torch.stack(log_probs).detach()
        rewards = list(rewards)
        dones = list(dones)
        
        with torch.no_grad():
            values = self.value_net(states).squeeze().tolist()
            next_state = states[-1]
            next_value = self.value_net(next_state).item()
        
        advantages, returns = self.compute_advantages(rewards, dones, values, next_value)
        
        for _ in range(UPDATE_EPOCHS):
            # 采样数据
            for index in BatchSampler(SubsetRandomSampler(range(len(self.memory))), BATCH_SIZE, False):
                sampled_states = states[index]
                sampled_actions = actions[index]
                sampled_old_log_probs = old_log_probs[index]
                sampled_advantages = advantages[index]
                sampled_returns = returns[index]
                
                # 计算新的动作概率
                action_probs = self.policy_net(sampled_states)
                m = Categorical(action_probs)
                new_log_probs = m.log_prob(sampled_actions)
                
                # 计算比率
                ratios = torch.exp(new_log_probs - sampled_old_log_probs)
                
                # 计算损失
                surr1 = ratios * sampled_advantages
                surr2 = torch.clamp(ratios, 1 - self.clip_eps, 1 + self.clip_eps) * sampled_advantages
                policy_loss = -torch.min(surr1, surr2).mean()
                
                # 价值网络损失
                value_preds = self.value_net(sampled_states).squeeze()
                value_loss = nn.MSELoss()(value_preds, sampled_returns)
                
                # 总损失
                total_loss = policy_loss + 0.5 * value_loss
                
                # 更新策略网络
                self.policy_optimizer.zero_grad()
                policy_loss.backward()
                self.policy_optimizer.step()
                
                # 更新价值网络
                self.value_optimizer.zero_grad()
                value_loss.backward()
                self.value_optimizer.step()
        
        # 清空记忆
        self.memory = []

训练过程

下面的代码展示了如何使用 PPO 智能体在 CartPole-v1 环境中进行训练:

def train():
    env = gym.make('CartPole-v1')
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    agent = PPOAgent(state_dim, action_dim, LR, GAMMA, LAMBDA, CLIP_EPS)
    
    for episode in range(1, NUM_EPISODES + 1):
        state, _ = env.reset()
        done = False
        total_reward = 0
        
        while not done:
            action, log_prob = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            agent.store_transition((state, action, log_prob, reward, done))
            state = next_state
            total_reward += reward
            
            if done:
                agent.update()
                print(f"Episode {episode}, Total Reward: {total_reward}")
                break
    
    env.close()
    
if __name__ == "__main__":
    train()

关键步骤解释

  1. 环境初始化: 使用 gym.make('CartPole-v1') 创建 CartPole-v1 环境。
  2. 智能体初始化: 实例化 PPOAgent,传入状态维度、动作维度、学习率、折扣因子等超参数。
  3. 选择动作: 通过智能体的 select_action 方法,根据当前状态选择动作,并获取其对数概率。
  4. 与环境交互: 执行动作后,获得下一个状态、奖励和是否终止的信号。
  5. 存储轨迹: 将当前的状态、动作、对数概率、奖励和终止信号存储到智能体的记忆中。
  6. 更新智能体: 在每回合结束后,调用智能体的 update 方法,使用存储的轨迹数据更新策略网络和价值网络。
  7. 打印进度: 输出当前回合的总奖励,便于监控训练进展。

优势函数与广义优势估计(GAE)

优势函数是评估一个动作相对于平均水平的好坏程度。广义优势估计(Generalized Advantage Estimation, GAE)是一种平衡偏差和方差的方法,用于更准确地估计优势函数。

在 PPO 算法中,优势函数用于计算策略更新的目标,从而使得更新过程中更加稳定和高效。

策略更新与裁剪

PPO 算法通过引入裁剪机制(Clipping)来限制策略更新的幅度,防止策略在更新过程中发生剧烈变化,从而保持训练的稳定性。

# 计算比率
ratios = torch.exp(new_log_probs - old_log_probs)

# 计算损失
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - clip_eps, 1 + clip_eps) * advantages
policy_loss = -torch.min(surr1, surr2).mean()

通过取两者中的最小值,PPO 保证了策略更新的变化不会超过设定的裁剪范围 clip_eps,从而防止了策略的过度更新。

总结

本指南详细介绍了如何使用 Python 和 PyTorch 库,通过 PPO 算法对 Gymnasium 的 CartPole-v1 环境进行训练。关键步骤包括定义策略网络和价值网络、实现 PPO 智能体、计算优势函数以及通过裁剪机制稳定策略更新。通过不断训练,智能体将逐渐学会在 CartPole 环境中保持杆子的平衡,最大化累计奖励。

参考资料

进一步优化与扩展

在实际应用中,可以通过以下方法进一步优化和扩展 PPO 算法的性能:

  • 使用经验回放: 存储更多的轨迹数据,以便在策略更新时使用多样化的数据源。
  • 网络架构调整: 增加神经网络的深度或宽度,以提升模型的表达能力。
  • 奖励归一化: 对奖励进行归一化,有助于加速训练过程并提高稳定性。
  • 多线程或并行环境: 使用多个环境并行采样,以提高数据采集效率。
  • 调节超参数: 根据具体任务需求,调整学习率、折扣因子、GAE参数等超参数,以获得更好的性能。

结语

强化学习,特别是 PPO 算法,是一个功能强大的工具,适用于各种控制任务。通过本指南的学习和实践,读者应能够理解 PPO 算法的基本原理,并能够在 Gymnasium 环境中实现和训练自己的强化学习代理。持续的实验和优化将有助于进一步提升模型的性能和适应性。


Last updated January 3, 2025
Ask Ithy AI
Download Article
Delete Article