Add modified Wasserstein loss in the Discriminator #25
base: main
Conversation
Divide the reward by the discriminator output std (to stabilize)
Co-authored-by: Copilot <[email protected]>
Pull Request Overview
This PR introduces a new Wasserstein-style loss option for the AMP Discriminator and centralizes loss computation within the Discriminator
class, simplifying the AMP PPO update logic.
- Add `loss_type` and `eta_wgan` parameters to support both BCEWithLogits and modified Wasserstein losses.
- Refactor discriminator loss and gradient-penalty computations into `Discriminator.compute_loss`.
- Update the on-policy runner and `amp_ppo` algorithm to use the new loss API.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `amp_rsl_rl/runners/amp_on_policy_runner.py` | Migrate to named args in the `Discriminator` init; pass `loss_type` from config. |
| `amp_rsl_rl/networks/discriminator.py` | Implement `loss_type`/`eta_wgan`, modularize policy/expert/WGAN losses, update gradient penalty. |
| `amp_rsl_rl/algorithms/amp_ppo.py` | Remove redundant discriminator loss methods and call `compute_loss` instead. |
Comments suppressed due to low confidence (3)
`amp_rsl_rl/runners/amp_on_policy_runner.py:175`
- The `eta_wgan` parameter is not passed from the runner's configuration to the Discriminator, preventing users from customizing the WGAN scaling. Consider adding `eta_wgan=self.discriminator_cfg.get("eta_wgan", 0.3)` when initializing the Discriminator.
```python
loss_type=self.discriminator_cfg["loss_type"],
```
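Concretely, the suggested fix would give the runner's initialization roughly this shape (the surrounding arguments are placeholders, not the repository's exact call):

```python
# Hypothetical: forwarding both loss options from the runner's config.
discriminator = Discriminator(
    # ... other named constructor arguments ...
    loss_type=discriminator_cfg["loss_type"],
    eta_wgan=discriminator_cfg.get("eta_wgan", 0.3),  # fall back to the suggested default
)
```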
`amp_rsl_rl/networks/discriminator.py:32`
- The class docstring is not updated to reflect the new `loss_type` and `eta_wgan` constructor parameters. Please update the docstring so users know how to configure these options.
```python
loss_type: str = "BCEWithLogits",
```
`amp_rsl_rl/networks/discriminator.py:182`
- [nitpick] Consider renaming the `lambda_` parameter to a more descriptive name (e.g., `penalty_coeff` or `gp_lambda`) to improve readability and avoid confusion with the Python keyword `lambda`.
```python
lambda_: float = 10,
```
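For illustration, a gradient-penalty helper with the suggested `gp_lambda` name could look like this. It is a sketch under assumed behavior (WGAN-GP-style unit-gradient penalty for the Wasserstein loss, zero-gradient penalty on expert data for BCE), not the repository's exact code:

```python
import torch

def compute_grad_pen_sketch(disc, expert_data: torch.Tensor,
                            policy_data: torch.Tensor,
                            loss_type: str = "BCEWithLogits",
                            gp_lambda: float = 10.0) -> torch.Tensor:
    """Gradient penalty covering both loss types (illustrative only)."""
    if loss_type == "Wasserstein":
        # WGAN-GP: penalize deviation from unit gradient norm on random
        # interpolations between expert and policy samples.
        alpha = torch.rand(expert_data.size(0), 1, device=expert_data.device)
        data = (alpha * expert_data + (1 - alpha) * policy_data).requires_grad_(True)
        target = 1.0
    else:
        # BCE/AMP-style: push the gradient norm on expert data toward zero.
        data = expert_data.clone().requires_grad_(True)
        target = 0.0
    scores = disc(data)
    grad = torch.autograd.grad(scores.sum(), data, create_graph=True)[0]
    return gp_lambda * (grad.norm(2, dim=1) - target).pow(2).mean()
```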
```python
elif self.loss_type == "Wasserstein":
    self.loss_fun = None
    self.eta_wgan = eta_wgan
    print("The Wasserstein-like loss is experimental")
```
[nitpick] Use a logger (e.g., Python's `logging` module) or `warnings.warn` instead of `print` for runtime warnings to ensure consistent logging behavior and better control over output formatting.

Suggested change:
```diff
- print("The Wasserstein-like loss is experimental")
+ warnings.warn("The Wasserstein-like loss is experimental", UserWarning)
```
This pull request introduces support for a new loss type (Wasserstein loss).

Note that the Wasserstein GAN formulation minimizes the discriminator loss

$$\mathcal{L}_D = \mathbb{E}_{s \sim \pi}\left[D(s)\right] - \mathbb{E}_{s \sim \mathcal{D}_{\text{expert}}}\left[D(s)\right]$$

Empirically, I found that if the last layer of the discriminator is a `tanh()`, the training is smoother and more stable.

The reward is then the discriminator output divided by its standard deviation over the batch (to stabilize it):

$$r = \frac{D(s)}{\operatorname{std}\left(D(s)\right)}$$
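A minimal PyTorch sketch of these two computations (the function names and the exact placement of `eta_wgan` are assumptions, not the repository's exact code):

```python
import torch

def wgan_loss(policy_d: torch.Tensor, expert_d: torch.Tensor,
              eta_wgan: float = 0.3) -> torch.Tensor:
    # Critic loss to minimize: lower scores on policy data, higher on expert data.
    # eta_wgan scales the objective (default 0.3 taken from the review comment above).
    return eta_wgan * (policy_d.mean() - expert_d.mean())

def wgan_reward(d_out: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Style reward: discriminator output divided by its batch std,
    # the "divide by disc output std" stabilization described above.
    return d_out / (d_out.std() + eps)
```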
**AMP PPO Algorithm Updates:**
- Removed `discriminator_policy_loss` and `discriminator_expert_loss` from `amp_ppo.py`. Loss computation is now delegated to the `Discriminator` class for better encapsulation.
- Updated the `update` method in `amp_ppo.py` to use the new `compute_loss` method from the `Discriminator` class, simplifying the logic for AMP loss and gradient penalty computation (see the sketch after this list).
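As a rough sketch, the simplified call site could take this shape; the `compute_loss` and `compute_grad_pen` signatures shown are assumptions:

```python
# Hypothetical shape of the discriminator update inside AMP_PPO.update():
policy_d = self.discriminator(policy_transitions)
expert_d = self.discriminator(expert_transitions)

amp_loss = self.discriminator.compute_loss(policy_d, expert_d)
grad_pen = self.discriminator.compute_grad_pen(expert_transitions, policy_transitions)

self.optimizer.zero_grad()
(amp_loss + grad_pen).backward()
self.optimizer.step()
```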
**Discriminator Enhancements:**
- Added `loss_type` and `eta_wgan` parameters to the `Discriminator` class to configure the loss function. [1] [2]
- Refactored the `Discriminator` class to include modular methods for `policy_loss`, `expert_loss`, and a unified `compute_loss` method. This centralizes loss computation and supports both loss types (see the sketch after this list).
- Updated the `compute_grad_pen` method to handle gradient penalty computation for both BCE and Wasserstein loss types. Added a new `wgan_loss` method for Wasserstein loss.
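A condensed, self-contained sketch of this modular structure (signatures and the BCE labeling convention are assumptions; only the dispatch pattern is the point):

```python
import torch
import torch.nn as nn

class DiscriminatorLossSketch(nn.Module):
    """Illustrative dispatch between BCE and Wasserstein losses (not the repo's exact code)."""

    def __init__(self, loss_type: str = "BCEWithLogits", eta_wgan: float = 0.3):
        super().__init__()
        self.loss_type = loss_type
        self.eta_wgan = eta_wgan
        # BCEWithLogits pairs each score with a real/fake label;
        # the Wasserstein variant works on raw scores and needs no criterion.
        self.loss_fun = nn.BCEWithLogitsLoss() if loss_type == "BCEWithLogits" else None

    def policy_loss(self, policy_d: torch.Tensor) -> torch.Tensor:
        # Policy (generated) samples are labeled as fake (0).
        return self.loss_fun(policy_d, torch.zeros_like(policy_d))

    def expert_loss(self, expert_d: torch.Tensor) -> torch.Tensor:
        # Expert (reference motion) samples are labeled as real (1).
        return self.loss_fun(expert_d, torch.ones_like(expert_d))

    def wgan_loss(self, policy_d: torch.Tensor, expert_d: torch.Tensor) -> torch.Tensor:
        # Modified Wasserstein objective, scaled by eta_wgan (placement assumed).
        return self.eta_wgan * (policy_d.mean() - expert_d.mean())

    def compute_loss(self, policy_d: torch.Tensor, expert_d: torch.Tensor) -> torch.Tensor:
        if self.loss_type == "Wasserstein":
            return self.wgan_loss(policy_d, expert_d)
        return self.policy_loss(policy_d) + self.expert_loss(expert_d)
```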
**Runner Updates:**
- Modified `amp_on_policy_runner.py` to pass the `loss_type` configuration to the `Discriminator` during initialization, enabling runtime selection of the loss function (example config below).
- Updated the `load` method for improved readability.
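For example, the discriminator section of the training config might then select the loss at runtime (key names other than `loss_type` and `eta_wgan` are assumptions):

```python
# Hypothetical discriminator config consumed by amp_on_policy_runner.py:
discriminator_cfg = {
    "loss_type": "Wasserstein",  # or "BCEWithLogits" (the default)
    "eta_wgan": 0.3,             # WGAN scaling, per the review suggestion above
}
```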