@@ -59,6 +59,8 @@ class AMP_PPO:
Maximum gradient norm for clipping gradients during backpropagation.
use_clipped_value_loss : bool, default=True
Flag indicating whether to use a clipped value loss, as in the original PPO implementation.
+ use_smooth_clamping : bool, default=False
+ Flag indicating whether to use smooth exponential clamping of the probability ratio in the surrogate loss instead of a hard clamp.
schedule : str, default="fixed"
Learning rate schedule mode ("fixed" or "adaptive" based on KL divergence).
desired_kl : float, default=0.01
@@ -90,6 +92,7 @@ def __init__(
schedule: str = "fixed",
desired_kl: float = 0.01,
amp_replay_buffer_size: int = 100000,
+ use_smooth_clamping: bool = False,
device: str = "cpu",
) -> None:
# Set device and learning hyperparameters
@@ -146,6 +149,7 @@ def __init__(
self.lam: float = lam
self.max_grad_norm: float = max_grad_norm
self.use_clipped_value_loss: bool = use_clipped_value_loss
+ self.use_smooth_clamping: bool = use_smooth_clamping

def init_storage(
self,
@@ -452,10 +456,22 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
ratio = torch.exp(
actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch)
)
+
+ min_ = 1.0 - self.clip_param
+ max_ = 1.0 + self.clip_param
+
+ if self.use_smooth_clamping:
+ clipped_ratio = (
+ 1
+ / (1 + torch.exp((-(ratio - min_) / (max_ - min_) + 0.5) * 4))
+ * (max_ - min_)
+ + min_
+ )
+ else:
+ clipped_ratio = torch.clamp(ratio, min_, max_)
+
surrogate = -torch.squeeze(advantages_batch) * ratio
- surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
- ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
- )
+ surrogate_clipped = -torch.squeeze(advantages_batch) * clipped_ratio
surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

# Compute the value function loss.
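For reference, a minimal, self-contained sketch of the sigmoid-based clamp this hunk introduces, pulled out into a standalone function; the name smooth_clamp, the default clip_param = 0.2, and the test values below are illustrative, not part of the change.

import torch


def smooth_clamp(ratio: torch.Tensor, clip_param: float = 0.2) -> torch.Tensor:
    """Squash `ratio` smoothly into (1 - clip_param, 1 + clip_param) with a sigmoid."""
    min_ = 1.0 - clip_param
    max_ = 1.0 + clip_param
    # Rescale the ratio over the clip window, pass it through a sigmoid
    # (the factor 4 gives slope 1 at ratio = 1, so the map is roughly the
    # identity near the centre), then map back to the clip window.
    return 1 / (1 + torch.exp((-(ratio - min_) / (max_ - min_) + 0.5) * 4)) * (max_ - min_) + min_


ratio = torch.tensor([0.5, 0.8, 1.0, 1.2, 1.5])
print(torch.clamp(ratio, 0.8, 1.2))  # hard clamp:   [0.8000, 0.8000, 1.0000, 1.2000, 1.2000]
print(smooth_clamp(ratio))           # smooth clamp: ~[0.8027, 0.8477, 1.0000, 1.1523, 1.1973]

Unlike torch.clamp, whose gradient is zero outside the clip range, the sigmoid variant saturates towards the bounds while keeping a small nonzero gradient everywhere, which appears to be the motivation for offering it as an option.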