继续浏览精彩内容
慕课网APP
程序员的梦工厂
打开
继续
感谢您的支持,我会继续努力的
赞赏金额会直接到老师账户
将二维码发送给自己后长按识别
微信支付
支付宝支付

深度强化学习基础:Actor-Critic模型解析,附Pytorch完整代码

斯蒂芬大帝
关注TA
已关注
手记 257
粉丝 7
获赞 21

本文将深入解析深度强化学习中的关键组件——Actor-Critic(行动者评论家)算法,并通过实践示例展示如何使用Pytorch实现一个完整的Actor-Critic模型。我们将利用Pytorch库操作,具体展示策略网络(PolicyNet)与价值网络(ValueNet)的构造、以及如何实现模型的更新和训练过程。最终,我们将基于OpenAI Gym的CartPole-v1环境,演示如何将Actor-Critic算法应用于实际问题中,并展示学习曲线以及每回合的回报展示。

1. 算法原理与推导

Actor-Critic算法是强化学习中的一种集成策略,结合了策略迭代与价值迭代的优点。它将强化学习的问题分解为两个部分:

1.1 行动者(Actor)

  • 角色:负责选择动作。
  • 目标:最大化累积奖励。
  • 输出:动作的概率分布或连续动作值。

1.2 评论家(Critic)

  • 角色:评估当前策略的好坏,给出动作的价值。
  • 目标:优化行动者的行为,通过评估动作价值来指导策略更新。
  • 输出:状态价值或动作价值。

1.3 组合

Actor-Critic算法通过评论家的反馈来更新行动者的学习。评论家的输出(状态价值或动作价值)驱动了行动者模型的更新,从而调整策略以优化累积奖励。

2. 公式推导与关键公式

2.1 策略优化目标函数

策略优化的目标是最大化累积奖励,通过优化策略函数$\pi(a|s)$,使得$E[R] = E[\sum_{t=0}^{\infty} \gamma^t r_t | \pi]$最大,其中$r_t$是时间步$t$的奖励,$\gamma$是折扣因子。

2.2 评论家损失函数

评论家网络通常采用状态价值函数或动作价值函数的形式。在Actor-Critic中,我们通常关注状态价值函数$V(s)$,其损失函数可以表示为TD误差的平方均值:

$$
L_{critic} = E[(y - V(s))^2]
$$

其中$y = r + \gamma V(s')$,$s'$是下一个状态。

2.3 Actor损失函数

Actor网络的目标是通过最大化策略的对数似然来优化策略参数,即:

$$
\max{\theta} \sum{s,a,r,s'} \pi(a|s)^{\pi} (r + \gamma V(s') - V(s))
$$

或更简洁地:

$$
\min{\theta} -\sum{s,a,r,s'} \pi(a|s) (r + \gamma V(s') - V(s))
$$

2.4 关键公式

关键公式包括:

  • 策略的参数更新$\theta = \theta - \alpha \nabla{\theta} L{actor}$
  • 评论家的参数更新$\phi = \phi - \beta \nabla{\phi} L{critic}$

其中$\alpha$和$\beta$分别是Actor和Critic的学习率。

3. Pytorch实现

3.1 定义策略网络(PolicyNet)

import torch
from torch import nn

class PolicyNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.softmax(self.fc2(x), dim=1)
        return x

3.2 定义价值网络(ValueNet)

class ValueNet(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

3.3 实现ActorCritic

class ActorCritic:
    def __init__(self, policy_net, value_net, gamma):
        self.actor = policy_net
        self.critic = value_net
        self.gamma = gamma
        self.optimizer_actor = torch.optim.Adam(self.actor.parameters(), lr=1e-3)
        self.optimizer_critic = torch.optim.Adam(self.critic.parameters(), lr=1e-2)

    def take_action(self, state):
        state = torch.tensor(state, dtype=torch.float).unsqueeze(0)
        prob = self.actor(state)
        action_dist = torch.distributions.Categorical(probs=prob)
        action = action_dist.sample().item()
        return action

    def update(self, states, actions, rewards, next_states, dones):
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        td_error = td_target - self.critic(states)
        advantage = td_error.detach()
        log_probs = torch.log(self.actor(states).gather(1, torch.tensor(actions, dtype=torch.long).unsqueeze(1)))
        policy_loss = -torch.mean(log_probs * advantage)
        value_loss = 0.5 * torch.mean(torch.pow(td_error, 2))
        self.optimizer_actor.zero_grad()
        self.optimizer_critic.zero_grad()
        policy_loss.backward()
        value_loss.backward()
        self.optimizer_actor.step()
        self.optimizer_critic.step()

3.4 使用环境与训练

CartPole-v1环境编写函数,设置参数、训练,并可视化结果。

import gym
import matplotlib.pyplot as plt

env = gym.make('CartPole-v1')
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n

ac = ActorCritic(PolicyNet(n_states, 24, n_actions), ValueNet(n_states, 24), gamma=0.99)
returns = []

for episode in range(100):
    state = env.reset()[0]
    total_reward = 0
    transitions = {'states': [], 'actions': [], 'rewards': [], 'next_states': [], 'dones': []}
    while True:
        action = ac.take_action(state)
        next_state, reward, done, _ = env.step(action)
        transitions['states'].append(state)
        transitions['actions'].append(action)
        transitions['rewards'].append(reward)
        transitions['next_states'].append(next_state)
        transitions['dones'].append(done)
        state = next_state
        total_reward += reward
        if done:
            break
    ac.update(transitions['states'], transitions['actions'], transitions['rewards'],
              transitions['next_states'], transitions['dones'])
    returns.append(total_reward)

plt.figure()
plt.plot(returns)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Learning Curve')
plt.show()

通过上面的代码,我们构建了一个基本的Actor-Critic模型,并将其应用于CartPole-v1环境中。模型在训练过程中能够学习到如何控制小车以保持杆子竖直,同时绘制的曲线展示了学习过程中累积奖励的变化趋势。

打开App,阅读手记
0人推荐
发表评论
随时随地看视频慕课网APP