
Commit 9752e5a

Add exponential policy ratio clamping (#13)
1 parent 7749b28 commit 9752e5a


amp_rsl_rl/algorithms/amp_ppo.py

Lines changed: 19 additions & 3 deletions
@@ -59,6 +59,8 @@ class AMP_PPO:
         Maximum gradient norm for clipping gradients during backpropagation.
     use_clipped_value_loss : bool, default=True
         Flag indicating whether to use a clipped value loss, as in the original PPO implementation.
+    use_smooth_clamping : bool, default=False
+        Flag indicating whether to use exponential (sigmoid-based) clamping of the policy ratio instead of a hard clamp.
     schedule : str, default="fixed"
         Learning rate schedule mode ("fixed" or "adaptive" based on KL divergence).
     desired_kl : float, default=0.01
@@ -90,6 +92,7 @@ def __init__(
         schedule: str = "fixed",
         desired_kl: float = 0.01,
         amp_replay_buffer_size: int = 100000,
+        use_smooth_clamping: bool = False,
         device: str = "cpu",
     ) -> None:
         # Set device and learning hyperparameters
@@ -146,6 +149,7 @@ def __init__(
         self.lam: float = lam
         self.max_grad_norm: float = max_grad_norm
         self.use_clipped_value_loss: bool = use_clipped_value_loss
+        self.use_smooth_clamping: bool = use_smooth_clamping
 
     def init_storage(
         self,
@@ -452,10 +456,22 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
             ratio = torch.exp(
                 actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch)
             )
+
+            min_ = 1.0 - self.clip_param
+            max_ = 1.0 + self.clip_param
+
+            if self.use_smooth_clamping:
+                clipped_ratio = (
+                    1
+                    / (1 + torch.exp((-(ratio - min_) / (max_ - min_) + 0.5) * 4))
+                    * (max_ - min_)
+                    + min_
+                )
+            else:
+                clipped_ratio = torch.clamp(ratio, min_, max_)
+
             surrogate = -torch.squeeze(advantages_batch) * ratio
-            surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
-                ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
-            )
+            surrogate_clipped = -torch.squeeze(advantages_batch) * clipped_ratio
             surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
 
             # Compute the value function loss.
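For reference, below is a minimal standalone sketch of the sigmoid-based ("exponential") clamping used in the hunk above; the helper name smooth_clamp and the example tensors are illustrative and not part of the commit.

import torch


def smooth_clamp(ratio: torch.Tensor, clip_param: float) -> torch.Tensor:
    """Smooth alternative to torch.clamp(ratio, 1 - clip_param, 1 + clip_param).

    The ratio is passed through a sigmoid centred at 1.0 and rescaled to
    (1 - clip_param, 1 + clip_param), so it saturates gradually instead of
    being cut off with a hard kink at the clip boundaries.
    """
    min_ = 1.0 - clip_param
    max_ = 1.0 + clip_param
    # Same expression as in the diff, i.e.
    # sigmoid(4 * ((ratio - min_) / (max_ - min_) - 0.5)) * (max_ - min_) + min_
    return 1 / (1 + torch.exp((-(ratio - min_) / (max_ - min_) + 0.5) * 4)) * (max_ - min_) + min_


ratio = torch.tensor([0.5, 0.9, 1.0, 1.1, 1.5])
print(smooth_clamp(ratio, clip_param=0.2))  # approx. [0.80, 0.91, 1.00, 1.09, 1.20]
print(torch.clamp(ratio, 0.8, 1.2))         # hard clamp: [0.80, 0.90, 1.00, 1.10, 1.20]

Note that, unlike torch.clamp, the smooth variant also bends ratios that are still inside (1 - clip_param, 1 + clip_param) slightly toward 1, and it keeps a nonzero gradient with respect to the ratio outside the clip range, where the hard clamp's gradient is zero.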
