TorchRL-PPO

Build reinforcement-learning training code with TorchRL APIs, especially PPO agents and examples using Gym/Gymnasium environments. Use when implementing or reviewing TorchRL code that should rely on TensorDict, GymEnv, TransformedEnv, TensorDictModule, ProbabilisticActor, ValueOperator, SyncDataCollector, ReplayBuffer, GAE, and ClipPPOLoss instead of hand-writing PPO losses, rollout storage, or actor-critic plumbing.
RUBisco

Overview

TorchRL

Core Principle

Prefer TorchRL primitives over custom RL plumbing. For PPO, do not manually implement the clipped PPO loss, log-prob ratio math, rollout buffers, or GAE unless the user explicitly asks for a from-scratch educational version. Compose:

  1. GymEnv or another TorchRL env wrapper
  2. TransformedEnv with Compose(...)
  3. TensorDictModule policy parameter network
  4. ProbabilisticActor
  5. ValueOperator
  6. SyncDataCollector or a multi-process collector
  7. ReplayBuffer with LazyTensorStorage and SamplerWithoutReplacement
  8. GAE
  9. ClipPPOLoss

TorchRL modules communicate through TensorDict. Always be explicit about in_keys and out_keys, usually reading "observation" and writing distribution parameters plus "action" and "sample_log_prob".
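The key-based wiring described above can be pictured with plain dictionaries. This is only an analogy under stated assumptions: `KeyedModule` is a hypothetical stand-in written for this sketch, not a TorchRL or TensorDict class, and the toy lambdas replace real networks and distributions.

```python
class KeyedModule:
    """Hypothetical stand-in: reads `in_keys` from a dict, writes `out_keys` back."""

    def __init__(self, fn, in_keys, out_keys):
        self.fn = fn
        self.in_keys = in_keys
        self.out_keys = out_keys

    def __call__(self, data):
        # Gather inputs by key, call the wrapped function, store outputs by key.
        outputs = self.fn(*(data[key] for key in self.in_keys))
        if len(self.out_keys) == 1:
            outputs = (outputs,)
        data.update(zip(self.out_keys, outputs))
        return data


# Toy "parameter network": maps an observation to distribution parameters.
params = KeyedModule(lambda obs: (2.0 * obs, 1.0), ["observation"], ["loc", "scale"])
# Toy "actor": deterministically uses the mean as the action.
actor = KeyedModule(lambda loc, scale: loc, ["loc", "scale"], ["action"])

data = actor(params({"observation": 0.5}))
print(sorted(data))  # ['action', 'loc', 'observation', 'scale']
```

Because each stage only names the keys it reads and writes, a mismatch between one module's out_keys and the next module's in_keys fails immediately as a missing key, which is why the keys are worth spelling out.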

PPO Workflow

Use this structure for a standard PPO training function:

from collections import defaultdict

import torch
from torch import nn
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import Compose, DoubleToFloat, ObservationNorm, StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import ExplorationType, check_env_specs, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

For a small example, keep hyperparameters modest:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_cells = 256
lr = 3e-4
max_grad_norm = 1.0
frames_per_batch = 1000
total_frames = 50_000
sub_batch_size = 64
num_epochs = 10
clip_epsilon = 0.2
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4
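
As a rough check of what these settings imply, the arithmetic below counts collector batches and optimizer steps. It assumes the inner loop draws `frames_per_batch // sub_batch_size` minibatches per epoch, as in the training loop later in this document; the derived variable names are introduced here for illustration.

```python
# Hyperparameters copied from the example above.
frames_per_batch = 1000
total_frames = 50_000
sub_batch_size = 64
num_epochs = 10

# Number of batches the collector yields before total_frames is reached.
num_batches = total_frames // frames_per_batch

# Minibatches drawn per PPO epoch; integer division drops the remainder.
minibatches_per_epoch = frames_per_batch // sub_batch_size
unused_frames_per_epoch = frames_per_batch - minibatches_per_epoch * sub_batch_size

# Total optimizer steps over the whole run.
total_updates = num_batches * num_epochs * minibatches_per_epoch

print(num_batches, minibatches_per_epoch, unused_frames_per_epoch, total_updates)
# prints: 50 15 40 7500
```

Since 1000 is not a multiple of 64, 40 frames go unused in each epoch. Choosing a sub_batch_size that divides frames_per_batch avoids that if it matters for the experiment.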

Environment Pattern

Wrap simulator environments through TorchRL:

base_env = GymEnv("InvertedDoublePendulum-v4", device=device)

env = TransformedEnv(
    base_env,
    Compose(
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        StepCounter(),
    ),
)
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)
check_env_specs(env)

Notes:

  • Use ObservationNorm for continuous-control tasks, and initialize its statistics with init_stats(...) before training, as in the snippet above.
  • Use DoubleToFloat() to convert double observations for neural networks.
  • Use StepCounter() when logging episode length or evaluating max step count.
  • For very small discrete environments such as CartPole-v1, normalization is optional; DoubleToFloat() and StepCounter() are usually enough.
  • The Gym simulator itself steps on the CPU; the device argument only controls where the resulting tensors are stored.

Policy And Value Modules

For continuous action spaces, create loc and scale, then use TanhNormal:

actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
    NormalParamExtractor(),
)

policy_params = TensorDictModule(
    actor_net,
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)

policy_module = ProbabilisticActor(
    module=policy_params,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": env.action_spec.space.low,
        "high": env.action_spec.space.high,
    },
    return_log_prob=True,
)

For discrete action spaces, output logits and use OneHotCategorical:

from torchrl.modules import OneHotCategorical

actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(env.action_spec.shape[-1], device=device),
)

policy_params = TensorDictModule(
    actor_net,
    in_keys=["observation"],
    out_keys=["logits"],
)

policy_module = ProbabilisticActor(
    module=policy_params,
    spec=env.action_spec,
    in_keys=["logits"],
    distribution_class=OneHotCategorical,
    return_log_prob=True,
)

Create the critic with ValueOperator:

value_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(1, device=device),
)

value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
)

Run a quick module sanity check before training:

print(policy_module(env.reset()))
print(value_module(env.reset()))

Collector, Buffer, Advantage, Loss

Use SyncDataCollector for the minimal single-process case:

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)

For on-policy PPO, size the replay buffer to hold exactly one collected batch, so each new batch replaces the previous one:

replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)
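
To illustrate what sampling without replacement means for minibatching, here is a standard-library sketch. It is not TorchRL's sampler implementation: `minibatch_indices` is a hypothetical helper that shuffles the indices once and cuts them into disjoint chunks, dropping the incomplete tail.

```python
import random


def minibatch_indices(buffer_size, batch_size, rng):
    """Yield disjoint index lists from one shuffled pass over the buffer."""
    order = list(range(buffer_size))
    rng.shuffle(order)
    # Stop before the last incomplete chunk so every minibatch is full.
    for start in range(0, buffer_size - batch_size + 1, batch_size):
        yield order[start:start + batch_size]


rng = random.Random(0)
batches = list(minibatch_indices(buffer_size=1000, batch_size=64, rng=rng))
seen = [index for batch in batches for index in batch]

# 15 full minibatches, 960 frames visited, none visited twice.
print(len(batches), len(seen), len(set(seen)))
```

Within one pass, no frame contributes to more than one gradient step, which is the property the sampler choice is meant to provide.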

Compute advantages with GAE, then feed the same TensorDict data to ClipPPOLoss:

advantage_module = GAE(
    gamma=gamma,
    lmbda=lmbda,
    value_network=value_module,
    average_gae=True,
    device=device,
)

loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coeff=entropy_eps,
    critic_coeff=1.0,
    loss_critic_type="smooth_l1",
)

optim = torch.optim.Adam(loss_module.parameters(), lr)

Version note: some older tutorials or translations use entropy_coef / critic_coef or min / max for distribution bounds. Prefer the installed TorchRL signature. If unsure, inspect locally with inspect.signature(ClipPPOLoss) and inspect.signature(TanhNormal).
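
One way to act on this note is to pick the keyword name at runtime from the callable's signature. The sketch below uses only `inspect` from the standard library. `pick_supported_name` and the two stand-in functions are hypothetical and exist only to make the example runnable without TorchRL; in real code the target would be ClipPPOLoss or TanhNormal.

```python
import inspect


def pick_supported_name(target, candidates):
    """Return the first name in `candidates` that `target` accepts as a parameter."""
    params = inspect.signature(target).parameters
    for name in candidates:
        if name in params:
            return name
    raise TypeError(f"{target!r} accepts none of {candidates}")


# Hypothetical stand-ins for an older and a newer constructor signature.
def old_style_loss(actor, critic, entropy_coef=0.01, critic_coef=1.0):
    return {"entropy": entropy_coef, "critic": critic_coef}


def new_style_loss(actor, critic, entropy_coeff=0.01, critic_coeff=1.0):
    return {"entropy": entropy_coeff, "critic": critic_coeff}


for loss_cls in (old_style_loss, new_style_loss):
    entropy_name = pick_supported_name(loss_cls, ["entropy_coeff", "entropy_coef"])
    result = loss_cls(None, None, **{entropy_name: 1e-4})
    print(loss_cls.__name__, entropy_name, result["entropy"])
```

This check only sees explicitly named parameters. A constructor that forwards `**kwargs` would need to be checked against its documentation instead.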

Training Loop

Use the collector as the outer loop. Recompute advantage each PPO epoch because the critic changes during optimization.

logs = defaultdict(list)

# Each iteration yields one on-policy batch of frames_per_batch frames.
for i, tensordict_data in enumerate(collector):
    for _ in range(num_epochs):
        # Recompute advantages with the current critic; GAE writes
        # "advantage" and "value_target" into tensordict_data in place.
        advantage_module(tensordict_data)
        # Flatten batch and time dimensions before storing the frames.
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())

        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size).to(device)
            loss_vals = loss_module(subdata)
            # The three terms are already scaled by their coefficients.
            loss = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    # Mean per-step training reward and the longest step count in this batch.
    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    logs["step_count"].append(tensordict_data["step_count"].max().item())

    if i % 10 == 0:
        # Evaluate with the distribution mean instead of sampled actions.
        with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
            eval_rollout = env.rollout(1000, policy_module)
            logs["eval_reward_sum"].append(eval_rollout["next", "reward"].sum().item())
            logs["eval_step_count"].append(eval_rollout["step_count"].max().item())
If the distribution does not have a meaningful mean, use ExplorationType.DETERMINISTIC when supported by the actor/distribution.
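
The training loop above recomputes the advantage every epoch because the estimate depends on the critic's current values. The plain-Python reference below makes that dependence visible on a three-step trajectory. It is for understanding only and is not a replacement for the GAE module: `gae_reference` is a hypothetical helper, it treats `done` as a true terminal state, and it omits the standardization that average_gae=True applies.

```python
def gae_reference(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Plain-Python GAE over one trajectory; returns a list of advantages."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step can reuse the advantage of the next step.
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error under the current critic.
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    return advantages


rewards = [1.0, 1.0, 1.0]
dones = [False, False, True]

# Same trajectory, two different critics (gamma = lmbda = 1 for easy arithmetic).
untrained = gae_reference(rewards, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], dones, 1.0, 1.0)
improved = gae_reference(rewards, [2.0, 1.5, 1.0], [1.5, 1.0, 0.0], dones, 1.0, 1.0)

print(untrained)  # [3.0, 2.0, 1.0]
print(improved)   # [1.0, 0.5, 0.0]
```

The rewards are identical in both calls, yet the advantages differ, so advantages computed before a critic update become stale once the critic has moved.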

Implementation Checklist

  • Verify the installed TorchRL API before relying on translated tutorial argument names.
  • Keep all rollout data in TensorDict; avoid parallel Python lists for observations, rewards, dones, values, and log probs.
  • Ensure ProbabilisticActor(..., return_log_prob=True) so PPO receives "sample_log_prob".
  • Use the env action_spec when constructing the actor; it handles bounds and one-hot discrete actions.
  • Call check_env_specs(env) after transforms when adding or changing an environment.
  • Use .reshape(-1) before adding collected rollout data to the replay buffer.
  • Move sampled minibatches back to the training device with .to(device).
  • Close or shut down collectors/environments in long-running scripts when appropriate.

Version History

1 version in total

  • v1.0.0 Initial release (current)
    2026-05-12 16:39
