-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate.py
66 lines (48 loc) · 2.65 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import argparse
import torch
import torchaudio
from voiceldm import VoiceLDMPipeline
def parse_args(parser):
    """Register the generation options on ``parser`` and parse the command line.

    Args:
        parser: An ``argparse.ArgumentParser`` that receives the options.

    Returns:
        The namespace produced by ``parser.parse_args()`` (reads ``sys.argv``).
    """
    # Each entry is (flag names, add_argument keyword options); the order here
    # is the order in which the options are registered on the parser.
    option_table = (
        (("--desc_prompt", "-d"), dict(type=str, help="description prompt")),
        (("--cont_prompt", "-c"), dict(type=str, help="content prompt")),
        (
            ("--audio_prompt", "-a"),
            dict(type=str, help="path to audio file to be used as audio prompt"),
        ),
        (
            ("--model_config",),
            dict(
                type=str,
                default="m",
                help="configuration for VoiceLDM. 'm' for VoiceLDM-M and 's' for VoiceLDM-S",
            ),
        ),
        (("--ckpt_path",), dict(type=str, help="checkpoint file path for VoiceLDM")),
        (
            ("--output_dir",),
            dict(type=str, default="./outputs", help="directory to save generated audio"),
        ),
        (("--file_name",), dict(type=str, help="filename for the generated audio")),
        (
            ("--num_inference_steps",),
            dict(type=int, default=50, help="number of inference steps for DDIM sampling"),
        ),
        (
            ("--audio_length_in_s",),
            dict(type=float, default=10, help="duration of the audio for generation"),
        ),
        (
            ("--guidance_scale",),
            dict(type=float, help="guidance weight for single classifier-free guidance"),
        ),
        (
            ("--desc_guidance_scale",),
            dict(
                type=float,
                default=5,
                required=False,
                help="desc guidance weight for dual classifier-free guidance",
            ),
        ),
        (
            ("--cont_guidance_scale",),
            dict(
                type=float,
                default=7,
                required=False,
                help="cont guidance weight for dual classifier-free guidance",
            ),
        ),
        (
            ("--device",),
            dict(type=str, default="auto", help="device to use for audio generation"),
        ),
        (("--seed",), dict(type=int, help="random seed for generation")),
    )

    for flag_names, options in option_table:
        parser.add_argument(*flag_names, **options)

    return parser.parse_args()
def _default_file_name(args):
    """Build a ``.wav`` file name from the prompts when no --file_name is given."""
    # --audio_prompt is a path; keep only its stem so the generated name stays
    # a plain file name inside the output directory.
    audio_stem = None
    if args.audio_prompt:
        audio_stem = os.path.splitext(os.path.basename(args.audio_prompt))[0]

    # Skip missing prompts instead of failing on ``None + str`` after the
    # (expensive) generation has already run.
    parts = [part for part in (args.desc_prompt or audio_stem, args.cont_prompt) if part]
    return ("-".join(parts) or "output") + ".wav"


def main():
    """Parse the command line, run VoiceLDM, and save the generated audio.

    The audio is written to ``<output_dir>/<file_name>``. When ``--file_name``
    is omitted, the name is derived from the description (or audio) prompt and
    the content prompt.
    """
    parser = argparse.ArgumentParser()
    args = parse_args(parser)

    # Turn the --device string into a torch.device; "auto" prefers the first
    # CUDA device and falls back to the CPU.
    if args.device == "auto":
        args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    elif args.device is not None:
        args.device = torch.device(args.device)
    else:
        args.device = torch.device("cpu")

    pipe = VoiceLDMPipeline(
        args.model_config,
        args.ckpt_path,
        args.device,
    )

    # Every parsed option is forwarded to the pipeline call as a keyword.
    audio = pipe(**vars(args))

    file_name = args.file_name
    if file_name is None:
        file_name = _default_file_name(args)

    os.makedirs(args.output_dir, exist_ok=True)
    save_path = os.path.join(args.output_dir, file_name)
    # NOTE(review): 32000 is kept from the existing code - confirm it matches
    # the sample rate of the audio returned by the pipeline.
    torchaudio.save(save_path, src=audio, sample_rate=32000)
# Script entry point: run the generator only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()