From 2733fb9c317f280a655e986b213b6d15f2f75f55 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 6 Dec 2024 20:43:22 +0000 Subject: [PATCH] chore(format): run black on main --- rvc/train/losses.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rvc/train/losses.py b/rvc/train/losses.py index 14beaec5..565ee7e8 100644 --- a/rvc/train/losses.py +++ b/rvc/train/losses.py @@ -23,7 +23,7 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs): Args: disc_real_outputs (list of torch.Tensor): List of discriminator outputs for real samples. disc_generated_outputs (list of torch.Tensor): List of discriminator outputs for generated samples. - """ + """ r_losses = [(1 - dr).pow(2).mean() for dr in disc_real_outputs] g_losses = [dg.pow(2).mean() for dg in disc_generated_outputs] loss = sum(r_losses) + sum(g_losses) @@ -36,7 +36,7 @@ def generator_loss(disc_outputs): Args: disc_outputs (list of torch.Tensor): List of discriminator outputs for generated samples. - """ + """ gen_losses = [(1 - dg).pow(2).mean() for dg in disc_outputs] loss = sum(gen_losses) return loss, gen_losses @@ -64,7 +64,7 @@ def discriminator_loss_scaled(disc_real, disc_fake, scale=1.0): return loss, None, None -def generator_loss_scaled(disc_outputs, scale=1.0): +def generator_loss_scaled(disc_outputs, scale=1.0): """ Compute the scaled generator loss based on discriminator outputs.