
Add modified Wasserstein loss in the Discriminator #25


Open · wants to merge 10 commits into main

Conversation

@Giulero Giulero (Contributor) commented Jul 17, 2025

This pull request introduces support for a new loss type (Wasserstein loss).

Note that the Wasserstein GAN formulation minimizes the loss

$$ L = D(\text{fake}) - D(\text{true}). $$

Empirically, I found that using tanh() as the last layer of the discriminator makes training smoother and more stable.

The reward is then

$$ r = w \cdot \exp(\eta \cdot D(\text{fake})). $$
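For reference, here is a minimal PyTorch sketch of these two formulas. The function and tensor names (d_policy, d_expert) and the default values of w and eta are illustrative assumptions, not the code in this PR.

```python
import torch


def wasserstein_discriminator_loss(d_policy: torch.Tensor, d_expert: torch.Tensor) -> torch.Tensor:
    # L = D(fake) - D(true): minimizing this pushes expert scores up
    # and policy ("fake") scores down.
    return d_policy.mean() - d_expert.mean()


def wasserstein_style_reward(d_policy: torch.Tensor, w: float = 1.0, eta: float = 0.3) -> torch.Tensor:
    # r = w * exp(eta * D(fake)); with a tanh output head, D(fake) lies in [-1, 1],
    # which keeps the reward bounded. The values of w and eta here are placeholders.
    return w * torch.exp(eta * d_policy)
```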

AMP PPO Algorithm Updates:

  • Removed redundant methods discriminator_policy_loss and discriminator_expert_loss from amp_ppo.py. Loss computation is now delegated to the Discriminator class for better encapsulation.
  • Updated the update method in amp_ppo.py to use the new compute_loss method from the Discriminator class, simplifying the logic for AMP loss and gradient penalty computation (a rough sketch of the resulting update step follows this list).
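A rough sketch of what the simplified update step could look like after this refactor; the compute_loss call mirrors the delegation described above, while the function name, optimizer handling, and argument names are assumptions for illustration.

```python
import torch


def discriminator_update_step(discriminator, optimizer: torch.optim.Optimizer,
                              policy_batch: torch.Tensor, expert_batch: torch.Tensor) -> float:
    # Loss computation (BCE or Wasserstein, per the configured loss_type) is
    # delegated to the Discriminator instead of living in amp_ppo.py.
    loss = discriminator.compute_loss(policy_batch, expert_batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```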

Discriminator Enhancements:

  • Added support for a new loss type, "Wasserstein", alongside the existing "BCEWithLogits". Introduced the parameters loss_type and eta_wgan in the Discriminator class to configure the loss function.
  • Refactored the Discriminator class to include modular methods for policy_loss, expert_loss, and a unified compute_loss method. This centralizes loss computation and supports both loss types.
  • Reintroduced and updated the compute_grad_pen method to handle gradient penalty computation for both BCE and Wasserstein loss types. Added a new wgan_loss method for Wasserstein loss (see the skeleton sketched after this list).
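To make the new structure concrete, here is an illustrative skeleton of how the modular methods could fit together. The network architecture, the label convention in the BCE branch, and the method signatures are assumptions; only the names loss_type, eta_wgan, policy_loss, expert_loss, wgan_loss, and compute_loss come from this PR.

```python
import torch
import torch.nn as nn


class DiscriminatorSketch(nn.Module):
    """Illustrative skeleton only, not the repository's Discriminator."""

    def __init__(self, input_dim: int, loss_type: str = "BCEWithLogits", eta_wgan: float = 0.3):
        super().__init__()
        self.loss_type = loss_type
        self.eta_wgan = eta_wgan
        # Per the PR description, a tanh head makes the Wasserstein variant train more smoothly.
        head = nn.Tanh() if loss_type == "Wasserstein" else nn.Identity()
        self.net = nn.Sequential(nn.Linear(input_dim, 256), nn.ReLU(), nn.Linear(256, 1), head)
        self.loss_fun = nn.BCEWithLogitsLoss() if loss_type == "BCEWithLogits" else None

    def policy_loss(self, d_policy: torch.Tensor) -> torch.Tensor:
        # BCE branch: policy ("fake") samples are labelled 0 (label choice assumed here).
        return self.loss_fun(d_policy, torch.zeros_like(d_policy))

    def expert_loss(self, d_expert: torch.Tensor) -> torch.Tensor:
        # BCE branch: expert ("true") samples are labelled 1 (label choice assumed here).
        return self.loss_fun(d_expert, torch.ones_like(d_expert))

    def wgan_loss(self, d_policy: torch.Tensor, d_expert: torch.Tensor) -> torch.Tensor:
        # Modified Wasserstein objective: L = D(fake) - D(true).
        return d_policy.mean() - d_expert.mean()

    def compute_loss(self, policy_batch: torch.Tensor, expert_batch: torch.Tensor) -> torch.Tensor:
        d_policy, d_expert = self.net(policy_batch), self.net(expert_batch)
        if self.loss_type == "Wasserstein":
            return self.wgan_loss(d_policy, d_expert)
        return 0.5 * (self.policy_loss(d_policy) + self.expert_loss(d_expert))
```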

Runner Updates:

  • Updated the amp_on_policy_runner.py to pass the loss_type configuration to the Discriminator during initialization, enabling runtime selection of the loss function (a usage sketch follows this list).
  • Minor formatting fix in the load method for improved readability.
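For context, selecting the loss function at runtime could look roughly like this, reusing the DiscriminatorSketch skeleton above; the config dict, its keys other than loss_type, and the eta_wgan forwarding (suggested in the review below) are assumptions, not the actual runner code.

```python
# Sketch: build the discriminator from the runner's configuration dict.
discriminator_cfg = {"loss_type": "Wasserstein", "eta_wgan": 0.3}  # illustrative values

discriminator = DiscriminatorSketch(
    input_dim=64,  # placeholder; in the runner this would come from the AMP observation size
    loss_type=discriminator_cfg.get("loss_type", "BCEWithLogits"),
    eta_wgan=discriminator_cfg.get("eta_wgan", 0.3),
)
```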


Giulero and others added 2 commits July 17, 2025 17:13
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Copilot <[email protected]>
@Giulero Giulero requested a review from Copilot July 17, 2025 15:16
@Copilot Copilot AI (Contributor) left a comment


Pull Request Overview

This PR introduces a new Wasserstein-style loss option for the AMP Discriminator and centralizes loss computation within the Discriminator class, simplifying the AMP PPO update logic.

  • Add loss_type and eta_wgan parameters to support both BCEWithLogits and modified Wasserstein losses.
  • Refactor discriminator loss and gradient‐penalty computations into Discriminator.compute_loss.
  • Update the on-policy runner and amp_ppo algorithm to use the new loss API.

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File and description:

  • amp_rsl_rl/runners/amp_on_policy_runner.py: Migrate to named args in the Discriminator init and pass loss_type from config.
  • amp_rsl_rl/networks/discriminator.py: Implement loss_type/eta_wgan, modularize policy/expert/WGAN losses, and update the gradient penalty.
  • amp_rsl_rl/algorithms/amp_ppo.py: Remove redundant discriminator loss methods and call compute_loss instead.

Comments suppressed due to low confidence (3)

amp_rsl_rl/runners/amp_on_policy_runner.py:175

  • The eta_wgan parameter is not passed from the runner’s configuration to the Discriminator, preventing users from customizing the WGAN scaling. Consider adding eta_wgan=self.discriminator_cfg.get("eta_wgan", 0.3) when initializing the Discriminator.
            loss_type=self.discriminator_cfg["loss_type"],

amp_rsl_rl/networks/discriminator.py:32

  • The class docstring is not updated to reflect the new loss_type and eta_wgan constructor parameters. Please update the docstring so users know how to configure these options.
        loss_type: str = "BCEWithLogits",

amp_rsl_rl/networks/discriminator.py:182

  • [nitpick] Consider renaming the lambda_ parameter to a more descriptive name (e.g., penalty_coeff or gp_lambda) to improve readability and avoid confusion with the Python keyword lambda.
        lambda_: float = 10,

elif self.loss_type == "Wasserstein":
    self.loss_fun = None
    self.eta_wgan = eta_wgan
    print("The Wasserstein-like loss is experimental")

Copilot AI Jul 17, 2025


[nitpick] Use a logger (e.g., Python's logging module) or warnings.warn instead of print for runtime warnings to ensure consistent logging behavior and better control over output formatting.

Suggested change
-    print("The Wasserstein-like loss is experimental")
+    warnings.warn("The Wasserstein-like loss is experimental", UserWarning)

