diff --git a/rvc/lib/algorithm/residuals.py b/rvc/lib/algorithm/residuals.py
index 7a934285..f970657c 100644
--- a/rvc/lib/algorithm/residuals.py
+++ b/rvc/lib/algorithm/residuals.py
@@ -1,5 +1,6 @@
-from typing import Optional, Tuple
 import torch
+from itertools import chain
+from typing import Optional, Tuple
 from torch.nn.utils import remove_weight_norm
 from torch.nn.utils.parametrizations import weight_norm
 
@@ -26,41 +27,70 @@ def apply_mask(tensor, mask):
     return tensor * mask if mask is not None else tensor
 
 
-class ResBlockBase(torch.nn.Module):
-    def __init__(self, channels: int, kernel_size: int, dilations: Tuple[int]):
-        super(ResBlockBase, self).__init__()
-        self.convs1 = torch.nn.ModuleList(
+class ResBlock(torch.nn.Module):
+    """
+    A residual block module that applies a series of 1D convolutional layers with residual connections.
+    """
+
+    def __init__(
+        self, channels: int, kernel_size: int = 3, dilations: Tuple[int] = (1, 3, 5)
+    ):
+        """
+        Initializes the ResBlock.
+
+        Args:
+            channels (int): Number of input and output channels for the convolution layers.
+            kernel_size (int): Size of the convolution kernel. Defaults to 3.
+            dilations (Tuple[int]): Tuple of dilation rates for the convolution layers in the first set.
+        """
+        super().__init__()
+        # Create convolutional layers with specified dilations and initialize weights
+        self.convs1 = self._create_convs(channels, kernel_size, dilations)
+        self.convs2 = self._create_convs(channels, kernel_size, [1] * len(dilations))
+
+    @staticmethod
+    def _create_convs(
+        channels: int, kernel_size: int, dilations: Tuple[int]
+    ):
+        """
+        Creates a list of 1D convolutional layers with specified dilations.
+
+        Args:
+            channels (int): Number of input and output channels for the convolution layers.
+            kernel_size (int): Size of the convolution kernel.
+            dilations (Tuple[int]): Tuple of dilation rates for each convolution layer.
+        """
+        layers = torch.nn.ModuleList(
             [create_conv1d_layer(channels, kernel_size, d) for d in dilations]
         )
-        self.convs1.apply(init_weights)
+        layers.apply(init_weights)
+        return layers
 
-        self.convs2 = torch.nn.ModuleList(
-            [create_conv1d_layer(channels, kernel_size, 1) for _ in dilations]
-        )
-        self.convs2.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
-            xt = apply_mask(xt, x_mask)
-            xt = torch.nn.functional.leaky_relu(c1(xt), LRELU_SLOPE)
-            xt = apply_mask(xt, x_mask)
-            xt = c2(xt)
-            x = xt + x
+    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
+        """Forward pass.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (batch_size, channels, sequence_length).
+            x_mask (torch.Tensor, optional): Optional mask to apply to the input and output tensors.
+        """
+        for conv1, conv2 in zip(self.convs1, self.convs2):
+            x_residual = x
+            x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
+            x = apply_mask(x, x_mask)
+            x = torch.nn.functional.leaky_relu(conv1(x), LRELU_SLOPE)
+            x = apply_mask(x, x_mask)
+            x = conv2(x)
+            x = x + x_residual
         return apply_mask(x, x_mask)
 
     def remove_weight_norm(self):
-        for conv in self.convs1 + self.convs2:
+        """
+        Removes weight normalization from all convolutional layers in the block.
+        """
+        for conv in chain(self.convs1, self.convs2):
             remove_weight_norm(conv)
 
 
-class ResBlock(ResBlockBase):
-    def __init__(
-        self, channels: int, kernel_size: int = 3, dilation: Tuple[int] = (1, 3, 5)
-    ):
-        super(ResBlock, self).__init__(channels, kernel_size, dilation)
-
-
 class Flip(torch.nn.Module):
     """Flip module for flow-based models.
 
@@ -115,7 +145,7 @@ def __init__(
         self.gin_channels = gin_channels
 
         self.flows = torch.nn.ModuleList()
-        for i in range(n_flows):
+        for _ in range(n_flows):
            self.flows.append(
                 ResidualCouplingLayer(
                     channels,