Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added conv2d1, conv2d2 #3

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion espnet/nets/pytorch_backend/conformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
)
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling1
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling2


class Encoder(torch.nn.Module):
Expand Down Expand Up @@ -112,6 +114,22 @@ def __init__(
torch.nn.Dropout(dropout_rate),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "conv2d1":
self.embed = Conv2dSubsampling1(
idim,
attention_dim,
dropout_rate,
pos_enc_class(attention_dim, positional_dropout_rate),
)
self.conv_subsampling_factor = 1
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(
idim,
attention_dim,
dropout_rate,
pos_enc_class(attention_dim, positional_dropout_rate),
)
self.conv_subsampling_factor = 2
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(
idim,
Expand Down Expand Up @@ -231,7 +249,7 @@ def forward(self, xs, masks):
torch.Tensor: Mask tensor (#batch, time).

"""
if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
if isinstance(self.embed, (Conv2dSubsampling1, Conv2dSubsampling2, Conv2dSubsampling, VGG2L)):
xs, masks = self.embed(xs, masks)
else:
xs = self.embed(xs)
Expand Down
2 changes: 1 addition & 1 deletion espnet/nets/pytorch_backend/transformer/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def add_arguments_transformer_common(group):
"--transformer-input-layer",
type=str,
default="conv2d",
choices=["conv2d", "linear", "embed"],
choices=["conv2d", "conv2d1", "conv2d2", "linear", "embed"],
help="transformer input layer type",
)
group.add_argument(
Expand Down
12 changes: 10 additions & 2 deletions espnet/nets/pytorch_backend/transformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
)
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling1
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling2
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling6
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling8

Expand Down Expand Up @@ -103,6 +105,12 @@ def __init__(
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "conv2d1":
self.embed = Conv2dSubsampling1(idim, attention_dim, dropout_rate)
self.conv_subsampling_factor = 1
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(idim, attention_dim, dropout_rate)
self.conv_subsampling_factor = 2
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate)
self.conv_subsampling_factor = 4
Expand Down Expand Up @@ -292,7 +300,7 @@ def forward(self, xs, masks):
"""
if isinstance(
self.embed,
(Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L),
(Conv2dSubsampling1, Conv2dSubsampling2, Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L),
):
xs, masks = self.embed(xs, masks)
else:
Expand All @@ -316,7 +324,7 @@ def forward_one_step(self, xs, masks, cache=None):
List[torch.Tensor]: List of new cache tensors.

"""
if isinstance(self.embed, Conv2dSubsampling):
if isinstance(self.embed, Conv2dSubsampling1, Conv2dSubsampling2, Conv2dSubsampling):
xs, masks = self.embed(xs, masks)
else:
xs = self.embed(xs)
Expand Down
118 changes: 118 additions & 0 deletions espnet/nets/pytorch_backend/transformer/subsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,124 @@ def check_short_utt(ins, size):
return False, -1


class Conv2dSubsampling1(torch.nn.Module):
    """Convolutional 2D front-end that keeps the input length (factor 1).

    Two 3x3 convolutions with stride 1 and padding 1 are applied, so neither
    the time axis nor the feature axis is shortened. The result is projected
    to ``odim`` and passed through a positional encoding layer.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (only used to build the default
            positional encoding when ``pos_enc`` is None).
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling1 object."""
        super(Conv2dSubsampling1, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 1, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 1, 1),
            torch.nn.ReLU(),
        )
        # kernel 3 / stride 1 / padding 1 leaves the feature axis at idim, so
        # the flattened (channel, feature) size fed to the projection is
        # simply odim * idim.
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * idim, odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Apply the convolutional front-end to x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time), or None.

        Returns:
            torch.Tensor: Output tensor (#batch, time, odim).
            torch.Tensor: The mask (#batch, 1, time), unchanged because the
                time axis is not subsampled; None if ``x_mask`` is None.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        # Move time before channels, then merge channels and features.
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # No time reduction happens, so the mask is valid as it is.
        return x, x_mask

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding.

        Raises:
            NotImplementedError: If ``key`` is anything other than ``-1``.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]


class Conv2dSubsampling2(torch.nn.Module):
    """Convolutional 2D subsampling (to about 1/2 length).

    The first 3x3 convolution uses stride 2 without padding, which maps a
    length ``n`` axis to ``(n - 1) // 2``. The second 3x3 convolution uses
    stride 1 with padding 1 and keeps the size. The result is projected to
    ``odim`` and passed through a positional encoding layer.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (only used to build the default
            positional encoding when ``pos_enc`` is None).
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling2 object."""
        super(Conv2dSubsampling2, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 1, 1),
            torch.nn.ReLU(),
        )
        # Only the first (stride 2, unpadded) convolution shrinks the feature
        # axis: idim -> (idim - 1) // 2.
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * ((idim - 1) // 2), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time), or None.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = (time - 1) // 2.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = (time - 1) // 2; None if ``x_mask`` is None.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        # Move time before channels, then merge channels and features.
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # Keep every second frame and drop the tail so the mask length
        # matches the (time - 1) // 2 frames produced by the first conv.
        return x, x_mask[:, :, :-2:2]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding.

        Raises:
            NotImplementedError: If ``key`` is anything other than ``-1``.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]


class Conv2dSubsampling(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).

Expand Down