-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Flow Mirror-s model inference code (#3)
* chore: Update README.md * Minor Update README.md * Update README.md for error fix. * chore: Update README.md * Update README.md for demo site deployment and formatting improvements * chore: Update README.md for formatting improvements * Update README.md for consistent formatting and language improvements * Update README.md for consistent formatting and demo site URL * Update flow_mirror-s model inference code. * Update README.md for release inference code.
- Loading branch information
1 parent
bd8fbc3
commit 344c185
Showing
30 changed files
with
4,533 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Package initializer for the flow_mirror model.
# Re-exports the public configs/models and registers the DAC codec with the
# transformers Auto* factories so `model_type: "dac"` resolves automatically.
__version__ = "0.1"

from .configuration_flow_mirror import FlowmirrorConfig, FlowmirrorDecoderConfig
from .modeling_flow_mirror import (
    FlowmirrorForCausalLM,
    FlowmirrorForConditionalGeneration,
    apply_delay_pattern_mask,
    build_delay_pattern_mask,
)

from .dac_wrapper import DACConfig, DACModel
from transformers import AutoConfig, AutoModel

# Side effect at import time: after this, AutoConfig.from_pretrained /
# AutoModel.from_pretrained can instantiate the DAC wrapper by model_type.
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from torch import nn | ||
|
||
class AdapterMLP(nn.Module):
    """Gated MLP adapter: ``down(GELU(gate(x)) * up(x))``.

    Maps ``hidden_size`` -> ``output_size`` through a gated
    ``intermediate_size`` layer, using bias-free linear projections.
    """

    def __init__(self, hidden_size, intermediate_size, output_size):
        super().__init__()
        # Keep the raw dimensions around for introspection.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.output_size = output_size
        # Submodule creation order matters for parameter-init reproducibility:
        # gate -> up -> down, all without bias.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
        self.act_fn = nn.GELU()

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        candidate = self.up_proj(x)
        return self.down_proj(gate * candidate)
113 changes: 113 additions & 0 deletions
113
flow_mirror_s/flow_mirror_model/configuration_flow_mirror.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# coding=utf-8 | ||
from transformers import AutoConfig, logging | ||
from transformers.configuration_utils import PretrainedConfig | ||
from flow_mirror_model.speaker_rec_cam import SpeakerRecCAMPP, SpeakerRecCAMPPConfig | ||
|
||
logger = logging.get_logger(__name__) | ||
|
||
|
||
class FlowmirrorDecoderConfig(PretrainedConfig):
    """Configuration of the Flowmirror decoder language model.

    Defaults describe a 24-layer, 1024-dim transformer decoder predicting
    ``num_codebooks`` parallel codebooks over a 2049-token audio vocabulary
    (token 2048 doubles as pad/eos, 2049 as bos).
    """

    model_type = "flow_mirror_decoder"
    # Cached past key/values are runtime state, not serializable config.
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=2049,
        prompt_vocab_size=152456,
        max_position_embeddings=2048,
        num_hidden_layers=24,
        ffn_dim=4096,
        num_attention_heads=16,
        layerdrop=0.0,
        use_cache=True,
        activation_function="gelu",
        hidden_size=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        initializer_factor=0.02,
        scale_embedding=False,
        num_codebooks=4,
        pad_token_id=2048,
        bos_token_id=2049,
        eos_token_id=2048,
        tie_word_embeddings=False,
        **kwargs,
    ):
        # Vocabulary / embedding geometry.
        self.vocab_size = vocab_size
        self.prompt_vocab_size = prompt_vocab_size
        self.max_position_embeddings = max_position_embeddings
        # Transformer width and depth.
        self.hidden_size = hidden_size
        self.ffn_dim = ffn_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Regularization knobs.
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.initializer_factor = initializer_factor
        self.layerdrop = layerdrop
        self.use_cache = use_cache
        # When True, token embeddings are scaled by sqrt(hidden_size).
        self.scale_embedding = scale_embedding
        self.num_codebooks = num_codebooks

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
||
|
||
class FlowmirrorConfig(PretrainedConfig):
    """Composite configuration tying together the three Flowmirror sub-models:
    a speaker encoder, an audio codec (audio encoder) and the decoder LM.
    """

    model_type = "flow_mirror"
    is_composition = True

    def __init__(self, vocab_size=1024, **kwargs):
        super().__init__(**kwargs)
        required = ("audio_encoder", "decoder", "speaker_encoder")
        if any(name not in kwargs for name in required):
            raise ValueError("Config has to be initialized with speaker_encoder, audio_encoder and decoder config")

        # Pop the sub-configs out of kwargs; the audio encoder is resolved
        # through AutoConfig via its own model_type.
        audio_encoder_config = kwargs.pop("audio_encoder")
        audio_encoder_model_type = audio_encoder_config.pop("model_type")
        speaker_encoder_config = kwargs.pop("speaker_encoder")
        decoder_config = kwargs.pop("decoder")

        self.vocab_size = vocab_size
        self.speaker_encoder = SpeakerRecCAMPPConfig(**speaker_encoder_config)
        self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
        self.decoder = FlowmirrorDecoderConfig(**decoder_config)
        self.is_encoder_decoder = False

    @classmethod
    def from_sub_models_config(
        cls,
        audio_encoder_config: PretrainedConfig,
        decoder_config: FlowmirrorDecoderConfig,
        speaker_encoder_config: SpeakerRecCAMPPConfig,
        **kwargs,
    ):
        r"""
        Instantiate a [`FlowmirrorConfig`] (or a derived class) from the speaker
        encoder, audio encoder and decoder configurations.
        Returns:
            [`FlowmirrorConfig`]: An instance of a configuration object
        """
        return cls(
            audio_encoder=audio_encoder_config.to_dict(),
            speaker_encoder=speaker_encoder_config.to_dict(),
            decoder=decoder_config.to_dict(),
            **kwargs,
        )

    @property
    # A property because the codec model may be swapped on the fly.
    def sampling_rate(self):
        return self.audio_encoder.sampling_rate
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .configuration_dac import DACConfig | ||
from .modeling_dac import DACModel |
29 changes: 29 additions & 0 deletions
29
flow_mirror_s/flow_mirror_model/dac_wrapper/configuration_dac.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from typing import List, Optional

from transformers import PretrainedConfig
|
||
|
||
class DACConfig(PretrainedConfig):
    """Configuration for the DAC (Descript Audio Codec) wrapper model.

    Mirrors the hyper-parameters of ``dac.model.DAC`` so the codec can be
    instantiated through the transformers Auto* factories.

    Args:
        num_codebooks: Number of residual VQ codebooks.
        encoder_rates: Per-stage downsampling rates of the encoder;
            defaults to ``[2, 4, 8, 8]`` when None.
        decoder_rates: Per-stage upsampling rates of the decoder;
            defaults to ``[8, 8, 4, 2]`` when None.
        model_bitrate: Target bitrate in kbps.
        codebook_size: Number of entries per codebook.
        latent_dim: Dimensionality of the latent representation.
        frame_rate: Code frames per second of audio.
        sampling_rate: Audio sampling rate in Hz.
    """

    model_type = "dac"

    def __init__(
        self,
        num_codebooks: int = 9,
        encoder_rates: Optional[List[int]] = None,
        decoder_rates: Optional[List[int]] = None,
        model_bitrate: int = 8,  # kbps
        codebook_size: int = 1024,
        latent_dim: int = 1024,
        frame_rate: int = 86,
        sampling_rate: int = 16000,
        **kwargs,
    ):
        self.codebook_size = codebook_size
        # Bug fix: the original used mutable list defaults ([2, 4, 8, 8] /
        # [8, 8, 4, 2]), which are shared across every call. Use a None
        # sentinel and materialize a fresh list per instance instead.
        self.encoder_rates = list(encoder_rates) if encoder_rates is not None else [2, 4, 8, 8]
        self.decoder_rates = list(decoder_rates) if decoder_rates is not None else [8, 8, 4, 2]
        self.model_bitrate = model_bitrate
        self.latent_dim = latent_dim
        self.num_codebooks = num_codebooks
        self.frame_rate = frame_rate
        self.sampling_rate = sampling_rate

        super().__init__(**kwargs)
140 changes: 140 additions & 0 deletions
140
flow_mirror_s/flow_mirror_model/dac_wrapper/modeling_dac.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import torch | ||
|
||
from transformers import PreTrainedModel | ||
from transformers.models.encodec.modeling_encodec import EncodecEncoderOutput, EncodecDecoderOutput | ||
from .configuration_dac import DACConfig | ||
|
||
from dac.model import DAC | ||
|
||
|
||
# model doesn't support batching yet | ||
|
||
|
||
class DACModel(PreTrainedModel):
    """HF-compatible wrapper around the Descript Audio Codec (DAC).

    Exposes ``encode``/``decode`` with the same interface as the transformers
    Encodec model so the codec can be swapped in transparently.
    NOTE: the underlying model doesn't support batching yet.
    """

    config_class = DACConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = DAC(
            n_codebooks=config.num_codebooks,
            encoder_rates=config.encoder_rates,
            decoder_rates=config.decoder_rates,
            latent_dim=config.latent_dim,
            codebook_size=config.codebook_size,
            sample_rate=config.sampling_rate,
        )

    def encode(
        self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None, **kwargs
    ):
        """
        Encodes the input audio waveform into discrete codes.
        Args:
            input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`.
                NOTE(review): currently accepted only for interface parity with
                Encodec; the DAC encoder does not apply it yet.
            bandwidth (`float`, *optional*):
                Not used, kept to have the same interface as HF encodec.
            n_quantizers (`int`, *optional*) :
                Number of quantizers to use, by default None
                If None, all quantizers are used.
            sample_rate (`int`, *optional*) :
                Signal sampling_rate
        Returns:
            A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
            factors for each chunk when `normalize` is True. Each frames is a tuple `(codebook, scale)`, with
            `codebook` of shape `[batch_size, num_codebooks, frames]`.
            Scale is not used here.
        """
        _, channels, input_length = input_values.shape

        if channels < 1 or channels > 2:
            raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")

        # Resample/pad the waveform to the codec's expected input format.
        audio_data = self.model.preprocess(input_values, sample_rate)

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # TODO: chunked encoding is not implemented yet; the whole input is
        # treated as a single chunk (chunk_length == stride == input_length).
        chunk_length = None  # self.config.chunk_length
        if chunk_length is None:
            chunk_length = input_length
            stride = input_length
        else:
            stride = self.config.chunk_stride

        encoded_frames = []
        scales = []

        # With a single chunk, step == 0 and this check always passes; it
        # guards the future chunked-decoding path.
        step = chunk_length - stride
        if (input_length % stride) - step != 0:
            raise ValueError(
                "The input length is not properly padded for batched chunked decoding. Make sure to pad the input correctly."
            )

        for offset in range(0, input_length - step, stride):
            frame = audio_data[:, :, offset : offset + chunk_length]

            # DAC has no per-chunk rescaling; keep None to mirror Encodec's
            # (codebook, scale) frame structure.
            scale = None

            _, encoded_frame, _, _, _ = self.model.encode(frame, n_quantizers=n_quantizers)
            encoded_frames.append(encoded_frame)
            scales.append(scale)

        encoded_frames = torch.stack(encoded_frames)

        if not return_dict:
            return (encoded_frames, scales)

        return EncodecEncoderOutput(encoded_frames, scales)

    def decode(
        self,
        audio_codes,
        audio_scales,
        padding_mask=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Decodes the given frames into an output audio waveform.
        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.
        Args:
            audio_codes (`torch.FloatTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
                Discrete code embeddings computed using `model.encode`.
            audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
                Not used, kept to have the same interface as HF encodec.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`.
                Not used yet, kept to have the same interface as HF encodec.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Bug fix: use an explicit `is None` check (as `encode` does) so an
        # explicitly passed `return_dict=False` is honored; the previous
        # `return_dict or self.config.return_dict` ignored a caller's False.
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # TODO: for now, no chunk length
        if len(audio_codes) != 1:
            raise ValueError(f"Expected one frame, got {len(audio_codes)}")

        # Codes -> continuous latents -> waveform.
        audio_values = self.model.quantizer.from_codes(audio_codes.squeeze(0))[0]
        audio_values = self.model.decode(audio_values)
        if not return_dict:
            return (audio_values,)
        return EncodecDecoderOutput(audio_values)

    def forward(self, tensor):
        # The wrapper is only used through `encode`/`decode`. Keep ValueError
        # (callers may catch it); dropped the pointless f-string prefix.
        raise ValueError("`DACModel.forward` not implemented yet")
Oops, something went wrong.