From 4bdd5094c5f6b50435df58ee637dccf6ef9f5e09 Mon Sep 17 00:00:00 2001
From: Chaitanya Narisetty
Date: Sat, 5 Jun 2021 10:02:56 -0400
Subject: [PATCH] Add conv2d1 and conv2d2 transformer input layers

Adds Conv2dSubsampling1 (keeps the input length) and Conv2dSubsampling2
(subsamples to 1/2 length) for the transformer and conformer encoders,
selectable via --transformer-input-layer {conv2d1,conv2d2}.
---
 .../nets/pytorch_backend/conformer/encoder.py |  23 ++-
 .../pytorch_backend/transformer/argument.py   |   2 +-
 .../pytorch_backend/transformer/encoder.py    |  22 ++-
 .../transformer/subsampling.py                | 118 ++++++++++++++++++
 4 files changed, 161 insertions(+), 4 deletions(-)

diff --git a/espnet/nets/pytorch_backend/conformer/encoder.py b/espnet/nets/pytorch_backend/conformer/encoder.py
index 980d15a18b8..47ffcb96935 100644
--- a/espnet/nets/pytorch_backend/conformer/encoder.py
+++ b/espnet/nets/pytorch_backend/conformer/encoder.py
@@ -30,6 +30,8 @@
 )
 from espnet.nets.pytorch_backend.transformer.repeat import repeat
 from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
+from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling1
+from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling2
 
 
 class Encoder(torch.nn.Module):
@@ -112,6 +114,22 @@ def __init__(
                 torch.nn.Dropout(dropout_rate),
                 pos_enc_class(attention_dim, positional_dropout_rate),
             )
+        elif input_layer == "conv2d1":
+            self.embed = Conv2dSubsampling1(
+                idim,
+                attention_dim,
+                dropout_rate,
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+            self.conv_subsampling_factor = 1
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(
+                idim,
+                attention_dim,
+                dropout_rate,
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+            self.conv_subsampling_factor = 2
         elif input_layer == "conv2d":
             self.embed = Conv2dSubsampling(
                 idim,
@@ -231,7 +249,10 @@ def forward(self, xs, masks):
             torch.Tensor: Mask tensor (#batch, time).
 
         """
-        if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
+        if isinstance(
+            self.embed,
+            (Conv2dSubsampling1, Conv2dSubsampling2, Conv2dSubsampling, VGG2L),
+        ):
             xs, masks = self.embed(xs, masks)
         else:
             xs = self.embed(xs)
diff --git a/espnet/nets/pytorch_backend/transformer/argument.py b/espnet/nets/pytorch_backend/transformer/argument.py
index 216a68d90c3..3b8f0de0a76 100644
--- a/espnet/nets/pytorch_backend/transformer/argument.py
+++ b/espnet/nets/pytorch_backend/transformer/argument.py
@@ -26,7 +26,7 @@ def add_arguments_transformer_common(group):
         "--transformer-input-layer",
         type=str,
         default="conv2d",
-        choices=["conv2d", "linear", "embed"],
+        choices=["conv2d", "conv2d1", "conv2d2", "linear", "embed"],
         help="transformer input layer type",
     )
     group.add_argument(
diff --git a/espnet/nets/pytorch_backend/transformer/encoder.py b/espnet/nets/pytorch_backend/transformer/encoder.py
index 5b19ded7dde..a7dc3d3009d 100644
--- a/espnet/nets/pytorch_backend/transformer/encoder.py
+++ b/espnet/nets/pytorch_backend/transformer/encoder.py
@@ -23,6 +23,8 @@
 )
 from espnet.nets.pytorch_backend.transformer.repeat import repeat
 from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
+from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling1
+from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling2
 from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling6
 from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling8
 
@@ -103,6 +105,12 @@ def __init__(
                 torch.nn.ReLU(),
                 pos_enc_class(attention_dim, positional_dropout_rate),
             )
+        elif input_layer == "conv2d1":
+            self.embed = Conv2dSubsampling1(idim, attention_dim, dropout_rate)
+            self.conv_subsampling_factor = 1
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(idim, attention_dim, dropout_rate)
+            self.conv_subsampling_factor = 2
         elif input_layer == "conv2d":
             self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate)
             self.conv_subsampling_factor = 4
@@ -292,7 +300,14 @@ def forward(self, xs, masks):
         """
         if isinstance(
             self.embed,
-            (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L),
+            (
+                Conv2dSubsampling1,
+                Conv2dSubsampling2,
+                Conv2dSubsampling,
+                Conv2dSubsampling6,
+                Conv2dSubsampling8,
+                VGG2L,
+            ),
         ):
             xs, masks = self.embed(xs, masks)
         else:
@@ -316,7 +331,10 @@ def forward_one_step(self, xs, masks, cache=None):
             List[torch.Tensor]: List of new cache tensors.
 
         """
-        if isinstance(self.embed, Conv2dSubsampling):
+        if isinstance(
+            self.embed,
+            (Conv2dSubsampling1, Conv2dSubsampling2, Conv2dSubsampling),
+        ):
             xs, masks = self.embed(xs, masks)
         else:
             xs = self.embed(xs)
diff --git a/espnet/nets/pytorch_backend/transformer/subsampling.py b/espnet/nets/pytorch_backend/transformer/subsampling.py
index 1f5a736d3aa..6a61a4a25ea 100644
--- a/espnet/nets/pytorch_backend/transformer/subsampling.py
+++ b/espnet/nets/pytorch_backend/transformer/subsampling.py
@@ -39,6 +39,124 @@ def check_short_utt(ins, size):
     return False, -1
 
 
+class Conv2dSubsampling1(torch.nn.Module):
+    """Convolutional 2D subsampling (to the same length).
+
+    Args:
+        idim (int): Input dimension.
+        odim (int): Output dimension.
+        dropout_rate (float): Dropout rate.
+        pos_enc (torch.nn.Module): Custom position encoding layer.
+
+    """
+
+    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
+        """Construct a Conv2dSubsampling1 object."""
+        super(Conv2dSubsampling1, self).__init__()
+        self.conv = torch.nn.Sequential(
+            torch.nn.Conv2d(1, odim, 3, 1, 1),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(odim, odim, 3, 1, 1),
+            torch.nn.ReLU(),
+        )
+        self.out = torch.nn.Sequential(
+            torch.nn.Linear(odim * idim, odim),
+            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
+        )
+
+    def forward(self, x, x_mask):
+        """Pass x through two stride-1 convolutions (no subsampling).
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, idim).
+            x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time', odim),
+                where time' = time (length is unchanged).
+            torch.Tensor: Output mask (#batch, 1, time'),
+                where time' = time (length is unchanged).
+
+        """
+        x = x.unsqueeze(1)  # (b, c, t, f)
+        x = self.conv(x)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        if x_mask is None:
+            return x, None
+        return x, x_mask
+
+    def __getitem__(self, key):
+        """Get item.
+
+        When reset_parameters() is called, if use_scaled_pos_enc is used,
+        return the positional encoding.
+
+        """
+        if key != -1:
+            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
+        return self.out[key]
+
+
+class Conv2dSubsampling2(torch.nn.Module):
+    """Convolutional 2D subsampling (to 1/2 length).
+
+    Args:
+        idim (int): Input dimension.
+        odim (int): Output dimension.
+        dropout_rate (float): Dropout rate.
+        pos_enc (torch.nn.Module): Custom position encoding layer.
+
+    """
+
+    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
+        """Construct a Conv2dSubsampling2 object."""
+        super(Conv2dSubsampling2, self).__init__()
+        self.conv = torch.nn.Sequential(
+            torch.nn.Conv2d(1, odim, 3, 2),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(odim, odim, 3, 1, 1),
+            torch.nn.ReLU(),
+        )
+        self.out = torch.nn.Sequential(
+            torch.nn.Linear(odim * ((idim - 1) // 2), odim),
+            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
+        )
+
+    def forward(self, x, x_mask):
+        """Subsample x.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, idim).
+            x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+        Returns:
+            torch.Tensor: Subsampled tensor (#batch, time', odim),
+                where time' = (time - 1) // 2.
+            torch.Tensor: Subsampled mask (#batch, 1, time'),
+                where time' = (time - 1) // 2.
+
+        """
+        x = x.unsqueeze(1)  # (b, c, t, f)
+        x = self.conv(x)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        if x_mask is None:
+            return x, None
+        return x, x_mask[:, :, :-2:2]
+
+    def __getitem__(self, key):
+        """Get item.
+
+        When reset_parameters() is called, if use_scaled_pos_enc is used,
+        return the positional encoding.
+
+        """
+        if key != -1:
+            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
+        return self.out[key]
+
+
 class Conv2dSubsampling(torch.nn.Module):
     """Convolutional 2D subsampling (to 1/4 length).
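-- 
Note for reviewers (not part of the patch): a minimal sketch of how the two
new input layers behave, assuming a checkout of this branch. The feature
dimension (83), attention dimension (256), and batch/time sizes below are
made-up illustration values, not defaults from the recipe configs.

    import torch

    from espnet.nets.pytorch_backend.transformer.subsampling import (
        Conv2dSubsampling1,
        Conv2dSubsampling2,
    )

    idim, odim = 83, 256                             # hypothetical feature dim / attention dim
    xs = torch.randn(4, 100, idim)                   # (#batch, time, idim)
    masks = torch.ones(4, 1, 100, dtype=torch.bool)  # (#batch, 1, time)

    # conv2d1: both convs use stride 1 and padding 1, so time is unchanged.
    sub1 = Conv2dSubsampling1(idim, odim, dropout_rate=0.1)
    ys1, ms1 = sub1(xs, masks)
    print(ys1.shape, ms1.shape)  # torch.Size([4, 100, 256]) torch.Size([4, 1, 100])

    # conv2d2: the first conv has stride 2, so time' = (time - 1) // 2 = 49.
    sub2 = Conv2dSubsampling2(idim, odim, dropout_rate=0.1)
    ys2, ms2 = sub2(xs, masks)
    print(ys2.shape, ms2.shape)  # torch.Size([4, 49, 256]) torch.Size([4, 1, 49])

In training configs the same layers are selected with
--transformer-input-layer conv2d1 (or conv2d2), which also sets the encoder's
conv_subsampling_factor to 1 (or 2) as in the encoder hunks above.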