-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathvi_speaker_single.py
109 lines (88 loc) · 3.33 KB
/
vi_speaker_single.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import re
import json
import fsspec
import torch
import numpy as np
import argparse
from argparse import RawTextHelpFormatter
from speaker_encoder.models.lstm import LSTMSpeakerEncoder
from speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig
from utils.audio import AudioProcessor
def read_json(json_path):
config_dict = {}
try:
with fsspec.open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
except json.decoder.JSONDecodeError:
# backwards compat.
data = read_json_with_comments(json_path)
config_dict.update(data)
return config_dict
def read_json_with_comments(json_path):
"""for backward compat."""
# fallback to json
with fsspec.open(json_path, "r", encoding="utf-8") as f:
input_str = f.read()
# handle comments
input_str = re.sub(r"\\\n", "", input_str)
input_str = re.sub(r"//.*\n", "\n", input_str)
data = json.loads(input_str)
return data
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""Compute embedding vectors for each wav file in a dataset.""",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
parser.add_argument(
"config_path",
type=str,
help="Path to model config file.",
)
parser.add_argument("-s", "--source", help="input wave", dest="source")
parser.add_argument(
"-t", "--target", help="output 256d speaker embeddimg", dest="target"
)
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
args = parser.parse_args()
source_file = args.source
target_file = args.target
# config
config_dict = read_json(args.config_path)
# print(config_dict)
# model
config = SpeakerEncoderConfig(config_dict)
config.from_dict(config_dict)
speaker_encoder = LSTMSpeakerEncoder(
config.model_params["input_dim"],
config.model_params["proj_dim"],
config.model_params["lstm_dim"],
config.model_params["num_lstm_layers"],
)
speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda)
# preprocess
speaker_encoder_ap = AudioProcessor(**config.audio)
# normalize the input audio level and trim silences
speaker_encoder_ap.do_sound_norm = True
speaker_encoder_ap.do_trim_silence = True
# compute speaker embeddings
# extract the embedding
waveform = speaker_encoder_ap.load_wav(
source_file, sr=speaker_encoder_ap.sample_rate
)
spec = speaker_encoder_ap.melspectrogram(waveform)
spec = torch.from_numpy(spec.T)
if args.use_cuda:
spec = spec.cuda()
spec = spec.unsqueeze(0)
embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy()
embed = embed.squeeze()
# print(embed)
# print(embed.size)
np.save(target_file, embed, allow_pickle=False)
if hasattr(speaker_encoder, 'module'):
state_dict = speaker_encoder.module.state_dict()
else:
state_dict = speaker_encoder.state_dict()
torch.save({'model': state_dict}, "model_small.pth")