Commit 40ffbbb

Fix PPO ratio clamping flag
1 parent 9752e5a commit 40ffbbb

File tree

1 file changed: +6 −5 lines

amp_rsl_rl/algorithms/amp_ppo.py

Lines changed: 6 additions & 5 deletions
@@ -59,8 +59,8 @@ class AMP_PPO:
         Maximum gradient norm for clipping gradients during backpropagation.
     use_clipped_value_loss : bool, default=True
         Flag indicating whether to use a clipped value loss, as in the original PPO implementation.
-    use_smooth_clamping : bool, default=False
-        Flag indicating whether to use exponential clamping on the value loos.
+    use_smooth_ratio_clipping : bool, default=False
+        Flag indicating whether to apply smooth (exponential) clipping to the PPO policy ratio.
     schedule : str, default="fixed"
         Learning rate schedule mode ("fixed" or "adaptive" based on KL divergence).
     desired_kl : float, default=0.01
@@ -92,7 +92,7 @@ def __init__(
         schedule: str = "fixed",
         desired_kl: float = 0.01,
         amp_replay_buffer_size: int = 100000,
-        use_smooth_clamping: bool = False,
+        use_smooth_ratio_clipping: bool = False,
         device: str = "cpu",
     ) -> None:
         # Set device and learning hyperparameters
@@ -149,7 +149,7 @@ def __init__(
         self.lam: float = lam
         self.max_grad_norm: float = max_grad_norm
         self.use_clipped_value_loss: bool = use_clipped_value_loss
-        self.use_smooth_clamped_loss = use_smooth_clamped_loss
+        self.use_smooth_ratio_clipping: bool = use_smooth_ratio_clipping

     def init_storage(
         self,
@@ -460,7 +460,8 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
         min_ = 1.0 - self.clip_param
         max_ = 1.0 + self.clip_param

-        if self.use_smooth_clamping:
+        # Smooth clamping for the ratio if enabled.
+        if self.use_smooth_ratio_clipping:
             clipped_ratio = (
                 1
                 / (1 + torch.exp((-(ratio - min_) / (max_ - min_) + 0.5) * 4))
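
For context, the renamed flag selects a sigmoid-based soft clamp of the PPO importance ratio instead of the hard torch.clamp used by standard PPO. The sketch below is a minimal, self-contained reconstruction, not the repository's exact code: the helper name smooth_clip is hypothetical, and the final rescaling step (squashed * (max_ - min_) + min_) is an assumption, since the rendered diff is cut off after the sigmoid expression.

import torch

def smooth_clip(ratio: torch.Tensor, clip_param: float) -> torch.Tensor:
    # Hypothetical helper illustrating the smooth (exponential) ratio
    # clipping enabled by use_smooth_ratio_clipping.
    min_ = 1.0 - clip_param
    max_ = 1.0 + clip_param
    # Map the ratio onto [0, 1] across the clip interval and squash it with
    # a sigmoid; the factor 4 gives the curve unit slope at ratio = 1, so it
    # tracks the identity near the center and saturates smoothly at the bounds.
    squashed = 1 / (1 + torch.exp((-(ratio - min_) / (max_ - min_) + 0.5) * 4))
    # ASSUMPTION: rescale from (0, 1) back to (min_, max_); the diff shown
    # above is truncated after the sigmoid term.
    return squashed * (max_ - min_) + min_

# Usage sketch: soft clamp vs. hard clamp of the same ratios.
ratio = torch.tensor([0.5, 1.0, 1.5])
print(smooth_clip(ratio, clip_param=0.2))  # values confined to (0.8, 1.2)
print(torch.clamp(ratio, 0.8, 1.2))        # hard clamp for comparison

Unlike the hard clamp, the soft version keeps a nonzero gradient everywhere, which is presumably the motivation for offering it as an option.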
