Prefer TorchRL primitives over custom RL plumbing. For PPO, do not manually implement the clipped PPO loss, log-prob ratio math, rollout buffers, or GAE unless the user explicitly asks for a from-scratch educational version. Compose:
- GymEnv or another TorchRL env wrapper
- TransformedEnv with Compose(...)
- TensorDictModule policy parameter network
- ProbabilisticActor
- ValueOperator
- SyncDataCollector or a multi-process collector
- ReplayBuffer with LazyTensorStorage and SamplerWithoutReplacement
- GAE
- ClipPPOLoss

TorchRL modules communicate through TensorDict. Always be explicit about in_keys and out_keys, usually reading "observation" and writing distribution parameters plus "action" and "sample_log_prob".
Use this structure for a standard PPO training function:
from collections import defaultdict
import torch
from torch import nn
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import Compose, DoubleToFloat, ObservationNorm, StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import ExplorationType, check_env_specs, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
For a small example, keep hyperparameters modest:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_cells = 256
lr = 3e-4
max_grad_norm = 1.0
frames_per_batch = 1000
total_frames = 50_000
sub_batch_size = 64
num_epochs = 10
clip_epsilon = 0.2
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4
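These values fix how many gradient steps a run performs. The following is a rough arithmetic sketch in plain Python, assuming one pass of `frames_per_batch // sub_batch_size` minibatches per PPO epoch; the derived variable names are illustrative and not part of TorchRL.

```python
# Hyperparameters copied from the list above.
frames_per_batch = 1000
total_frames = 50_000
sub_batch_size = 64
num_epochs = 10

# Derived quantities (illustrative names, not TorchRL API).
num_batches = total_frames // frames_per_batch
minibatches_per_epoch = frames_per_batch // sub_batch_size
# Floor division leaves some frames undrawn in each epoch.
unused_frames_per_epoch = frames_per_batch - minibatches_per_epoch * sub_batch_size
total_updates = num_batches * num_epochs * minibatches_per_epoch

print(num_batches, minibatches_per_epoch, unused_frames_per_epoch, total_updates)
```

If the leftover frames matter, choosing a sub_batch_size that divides frames_per_batch evenly is one way to avoid them.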
Wrap simulator environments through TorchRL:
base_env = GymEnv("InvertedDoublePendulum-v4", device=device)
env = TransformedEnv(
    base_env,
    Compose(
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        StepCounter(),
    ),
)
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)
check_env_specs(env)
Notes:
- Use ObservationNorm for continuous-control examples when the tutorial pattern applies.
- Use DoubleToFloat() to convert double-precision observations to float32 for neural networks.
- Use StepCounter() when logging episode length or evaluating against a maximum step count.
- For CartPole-v1, normalization is optional; DoubleToFloat() and StepCounter() are usually enough.

For continuous action spaces, create loc and scale, then use TanhNormal:
actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
    NormalParamExtractor(),
)
policy_params = TensorDictModule(
    actor_net,
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
policy_module = ProbabilisticActor(
    module=policy_params,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": env.action_spec.space.low,
        "high": env.action_spec.space.high,
    },
    return_log_prob=True,
)
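To see why the action bounds are passed to the distribution, here is a conceptual sketch in plain Python of squashing an unbounded Gaussian sample into [low, high] with tanh followed by an affine rescale. The function name `squash_to_bounds` is hypothetical, and this is not TorchRL's implementation: the real TanhNormal also handles the log-probability correction and numerical safeguards, which are omitted here.

```python
import math

def squash_to_bounds(u: float, low: float, high: float) -> float:
    """Map an unbounded sample u into [low, high] via tanh plus an affine rescale."""
    t = math.tanh(u)  # lies in [-1, 1]
    return low + 0.5 * (t + 1.0) * (high - low)

# Unbounded pre-squash samples, e.g. draws from Normal(loc, scale).
samples = [-10.0, -1.0, 0.0, 1.0, 10.0]
actions = [squash_to_bounds(u, low=-1.0, high=1.0) for u in samples]
print(actions)
```

However extreme the raw sample, the resulting action stays inside the bounds reported by the environment's action spec.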
For discrete action spaces, output logits and use OneHotCategorical:
from torchrl.modules import OneHotCategorical
actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(env.action_spec.shape[-1], device=device),
)
policy_params = TensorDictModule(
    actor_net,
    in_keys=["observation"],
    out_keys=["logits"],
)
policy_module = ProbabilisticActor(
    module=policy_params,
    spec=env.action_spec,
    in_keys=["logits"],
    distribution_class=OneHotCategorical,
    return_log_prob=True,
)
Create the critic with ValueOperator:
value_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(1, device=device),
)
value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
)
Run a quick module sanity check before training:
print(policy_module(env.reset()))
print(value_module(env.reset()))
Use SyncDataCollector for the minimal single-process case:
collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)
Use a replay buffer sized to exactly one collected batch. It is refilled with each new batch, so training stays on-policy:
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)
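Sampling without replacement means each stored frame is drawn at most once per pass over the buffer. The idea can be sketched in plain Python as follows; `minibatch_indices` is a hypothetical helper written for illustration, not the TorchRL sampler, which has its own options and state handling.

```python
import random

def minibatch_indices(num_items: int, batch_size: int, seed: int = 0) -> list[list[int]]:
    """Shuffle indices once, then cut them into disjoint full-size minibatches."""
    rng = random.Random(seed)
    order = list(range(num_items))
    rng.shuffle(order)
    # Keep only full minibatches; the remainder is dropped in this sketch.
    usable = (num_items // batch_size) * batch_size
    return [order[i:i + batch_size] for i in range(0, usable, batch_size)]

batches = minibatch_indices(num_items=1000, batch_size=64)
flat = [i for batch in batches for i in batch]
print(len(batches), len(flat), len(set(flat)))
```

Because the minibatches are disjoint, no transition is over-represented within a single PPO epoch.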
Compute advantages with GAE, then feed the same TensorDict data to ClipPPOLoss:
advantage_module = GAE(
    gamma=gamma,
    lmbda=lmbda,
    value_network=value_module,
    average_gae=True,
    device=device,
)
loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coeff=entropy_eps,
    critic_coeff=1.0,
    loss_critic_type="smooth_l1",
)
optim = torch.optim.Adam(loss_module.parameters(), lr)
Version note: some older tutorials or translations use entropy_coef / critic_coef or min / max for distribution bounds. Prefer the installed TorchRL signature. If unsure, inspect locally with inspect.signature(ClipPPOLoss) and inspect.signature(TanhNormal).
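One way to act on the signature check is a small helper that picks whichever keyword spelling the installed callable accepts. This is a sketch using only the standard library; `first_supported_kwarg` and `fake_loss` are hypothetical names, and `fake_loss` merely stands in for ClipPPOLoss so the example runs without TorchRL installed.

```python
import inspect

def first_supported_kwarg(fn, candidates):
    """Return the first name in candidates that fn accepts as a parameter, else None."""
    params = inspect.signature(fn).parameters
    for name in candidates:
        if name in params:
            return name
    return None

# Stand-in using the newer spelling; pass ClipPPOLoss here in a real environment.
def fake_loss(actor_network, critic_network, *, clip_epsilon=0.2, entropy_coeff=0.01):
    return None

entropy_key = first_supported_kwarg(fake_loss, ["entropy_coeff", "entropy_coef"])
loss_kwargs = {entropy_key: 1e-4}
print(loss_kwargs)
```

The same helper could be pointed at TanhNormal with the candidates "low" and "min" to choose the bound keyword.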
Use the collector as the outer loop. Recompute advantage each PPO epoch because the critic changes during optimization.
logs = defaultdict(list)
for i, tensordict_data in enumerate(collector):
    for _ in range(num_epochs):
        advantage_module(tensordict_data)
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())
        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size).to(device)
            loss_vals = loss_module(subdata)
            loss = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()
    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    logs["step_count"].append(tensordict_data["step_count"].max().item())
    if i % 10 == 0:
        with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
            eval_rollout = env.rollout(1000, policy_module)
            logs["eval_reward_sum"].append(eval_rollout["next", "reward"].sum().item())
            logs["eval_step_count"].append(eval_rollout["step_count"].max().item())
If the distribution does not have a meaningful mean, use ExplorationType.DETERMINISTIC when supported by the actor/distribution.
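The reason advantages are recomputed every PPO epoch is that they depend directly on the critic's value estimates. The plain-Python sketch below is for understanding only and is not a substitute for the GAE module. The function `gae` and its list-based inputs are hypothetical, and it ignores details such as the distinction between truncation and termination and the advantage normalization enabled by average_gae.

```python
def gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Generalized advantage estimates for one trajectory of plain floats."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step can reuse the discounted tail of later steps.
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    return advantages

rewards = [1.0, 1.0, 1.0]
dones = [False, False, True]
# Same rewards, two different critics: the advantages change with the value estimates.
adv_old = gae(rewards, values=[0.0, 0.0, 0.0], next_values=[0.0, 0.0, 0.0], dones=dones)
adv_new = gae(rewards, values=[2.0, 1.5, 1.0], next_values=[1.5, 1.0, 0.0], dones=dones)
print(adv_old, adv_new)
```

Once the optimizer has updated the critic, advantages computed from the earlier value estimates are stale, which is why the loop calls the advantage module at the start of every epoch.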
- Keep rollout data in TensorDict; avoid parallel Python lists for observations, rewards, dones, values, and log probs.
- Use ProbabilisticActor(..., return_log_prob=True) so PPO receives "sample_log_prob".
- Pass action_spec when constructing the actor; it handles bounds and one-hot discrete actions.
- Run check_env_specs(env) after transforms when adding or changing an environment.
- Call .reshape(-1) before adding collected rollout data to the replay buffer.
- Move sampled sub-batches to the training device with .to(device).