diff --git a/README.md b/README.md index 94f2dff..3d508bc 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,82 @@ # VI-Speaker Speaker embedding for VI-SVC and VI-SVS, alse for VITS; Use this to replace the ID to implement voice clone. + +# code from mozill_tts and Coqpit/TTS +https://github.com/mozilla/TTS/tree/master/TTS/speaker_encoder + +pip install coqpit + +# download model,or get it at **release** +https://github.com/mozilla/TTS/wiki/Released-Models + +Speaker-Encoder by @mueller91 LibriTTS + VCTK + VoxCeleb + CommonVoice + +# please read the config +https://drive.google.com/drive/folders/15oeBYf6Qn1edONkVLXe82MzdIi3O_9m3 + +# use +python vi_speaker_single.py ./saved_models/best_model.pth.tar ./saved_models/config.json -s TEST.wav -t TEST.npy + +# batch use +python vi_speaker_batch.py ./saved_models/best_model.pth.tar ./saved_models/config.json ./data/waves ./speaker_embedding + +data/ +└── waves + ├── spk1 + │   ├── 000002.wav + │   ├── 000006.wav + │   └── 000038.wav + └── spk2 + ├── 000040.wav + ├── 000044.wav + └── 000077.wav + +speaker_embedding/ +├── spk1 +│   ├── 000002.npy +│   ├── 000006.npy +│   └── 000038.npy +└── spk2 + ├── 000040.npy + ├── 000044.npy + └── 000077.npy + +# compute speaker center +input path = speaker_embedding, output path = speaker_embedding_center + +python vi_speaker_center.py + +speaker_embedding_center/ +├── spk1.npy +└── spk2.npy + + +# for VI-SVC +mv speaker_embedding_center data/spkid + +data/ +├── waves +│   ├── 10001 +│   ├── 20400 +│   │   ├── 20400_001.wav +│   │   ├── 20456_019.wav +│   │   +├── phone +│   ├── 10001 +│   ├── 20400 +│   │   ├── 20400_001.npy +│   │   ├── 20456_019.npy +│   │   +├── lable +│   ├── 10001 +│   ├── 20400 +│   │   ├── 20400_001.npy +│   │   ├── 20456_019.npy +│   │   +├── spkid +│   ├── 10001.npy +│   ├── 20400.npy +│   │   + + + diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/config/shared_configs.py b/config/shared_configs.py new file mode 100644 index 0000000..d91bf2b --- /dev/null +++ b/config/shared_configs.py @@ -0,0 +1,342 @@ +from dataclasses import asdict, dataclass +from typing import List + +from coqpit import Coqpit, check_argument + + +@dataclass +class BaseAudioConfig(Coqpit): + """Base config to definge audio processing parameters. It is used to initialize + ```TTS.utils.audio.AudioProcessor.``` + + Args: + fft_size (int): + Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024. + + win_length (int): + Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match + ```fft_size```. Defaults to 1024. + + hop_length (int): + Number of audio samples between adjacent STFT columns. Defaults to 1024. + + frame_shift_ms (int): + Set ```hop_length``` based on milliseconds and sampling rate. + + frame_length_ms (int): + Set ```win_length``` based on milliseconds and sampling rate. + + stft_pad_mode (str): + Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'. + + sample_rate (int): + Audio sampling rate. Defaults to 22050. + + resample (bool): + Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```. + + preemphasis (float): + Preemphasis coefficient. Defaults to 0.0. + + ref_level_db (int): 20 + Reference Db level to rebase the audio signal and ignore the level below. 20Db is assumed the sound of air. + Defaults to 20. + + do_sound_norm (bool): + Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False. + + log_func (str): + Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'. + + do_trim_silence (bool): + Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. + + do_amp_to_db_mel (bool, optional): + enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + + trim_db (int): + Silence threshold used for silence trimming. Defaults to 45. + + power (float): + Exponent used for expanding spectrogra levels before running Griffin Lim. It helps to reduce the + artifacts in the synthesized voice. Defaults to 1.5. + + griffin_lim_iters (int): + Number of Griffing Lim iterations. Defaults to 60. + + num_mels (int): + Number of mel-basis frames that defines the frame lengths of each mel-spectrogram frame. Defaults to 80. + + mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices. + It needs to be adjusted for a dataset. Defaults to 0. + + mel_fmax (float): + Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset. + + spec_gain (int): + Gain applied when converting amplitude to DB. Defaults to 20. + + signal_norm (bool): + enable/disable signal normalization. Defaults to True. + + min_level_db (int): + minimum db threshold for the computed melspectrograms. Defaults to -100. + + symmetric_norm (bool): + enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else + [0, k], Defaults to True. + + max_norm (float): + ```k``` defining the normalization range. Defaults to 4.0. + + clip_norm (bool): + enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. + + stats_path (str): + Path to the computed stats file. Defaults to None. + """ + + # stft parameters + fft_size: int = 1024 + win_length: int = 1024 + hop_length: int = 256 + frame_shift_ms: int = None + frame_length_ms: int = None + stft_pad_mode: str = "reflect" + # audio processing parameters + sample_rate: int = 22050 + resample: bool = False + preemphasis: float = 0.0 + ref_level_db: int = 20 + do_sound_norm: bool = False + log_func: str = "np.log10" + # silence trimming + do_trim_silence: bool = True + trim_db: int = 45 + # griffin-lim params + power: float = 1.5 + griffin_lim_iters: int = 60 + # mel-spec params + num_mels: int = 80 + mel_fmin: float = 0.0 + mel_fmax: float = None + spec_gain: int = 20 + do_amp_to_db_linear: bool = True + do_amp_to_db_mel: bool = True + # normalization params + signal_norm: bool = True + min_level_db: int = -100 + symmetric_norm: bool = True + max_norm: float = 4.0 + clip_norm: bool = True + stats_path: str = None + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + check_argument("num_mels", c, restricted=True, min_val=10, max_val=2056) + check_argument("fft_size", c, restricted=True, min_val=128, max_val=4058) + check_argument("sample_rate", c, restricted=True, min_val=512, max_val=100000) + check_argument( + "frame_length_ms", + c, + restricted=True, + min_val=10, + max_val=1000, + alternative="win_length", + ) + check_argument("frame_shift_ms", c, restricted=True, min_val=1, max_val=1000, alternative="hop_length") + check_argument("preemphasis", c, restricted=True, min_val=0, max_val=1) + check_argument("min_level_db", c, restricted=True, min_val=-1000, max_val=10) + check_argument("ref_level_db", c, restricted=True, min_val=0, max_val=1000) + check_argument("power", c, restricted=True, min_val=1, max_val=5) + check_argument("griffin_lim_iters", c, restricted=True, min_val=10, max_val=1000) + + # normalization parameters + check_argument("signal_norm", c, restricted=True) + check_argument("symmetric_norm", c, restricted=True) + check_argument("max_norm", c, restricted=True, min_val=0.1, max_val=1000) + check_argument("clip_norm", c, restricted=True) + check_argument("mel_fmin", c, restricted=True, min_val=0.0, max_val=1000) + check_argument("mel_fmax", c, restricted=True, min_val=500.0, allow_none=True) + check_argument("spec_gain", c, restricted=True, min_val=1, max_val=100) + check_argument("do_trim_silence", c, restricted=True) + check_argument("trim_db", c, restricted=True) + + +@dataclass +class BaseDatasetConfig(Coqpit): + """Base config for TTS datasets. + + Args: + name (str): + Dataset name that defines the preprocessor in use. Defaults to None. + + path (str): + Root path to the dataset files. Defaults to None. + + meta_file_train (str): + Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets. + Defaults to None. + + unused_speakers (List): + List of speakers IDs that are not used at the training. Default None. + + meta_file_val (str): + Name of the dataset meta file that defines the instances used at validation. + + meta_file_attn_mask (str): + Path to the file that lists the attention mask files used with models that require attention masks to + train the duration predictor. + """ + + name: str = "" + path: str = "" + meta_file_train: str = "" + ununsed_speakers: List[str] = None + meta_file_val: str = "" + meta_file_attn_mask: str = "" + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + check_argument("name", c, restricted=True) + check_argument("path", c, restricted=True) + check_argument("meta_file_train", c, restricted=True) + check_argument("meta_file_val", c, restricted=False) + check_argument("meta_file_attn_mask", c, restricted=False) + + +@dataclass +class BaseTrainingConfig(Coqpit): + """Base config to define the basic training parameters that are shared + among all the models. + + Args: + model (str): + Name of the model that is used in the training. + + run_name (str): + Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`. + + run_description (str): + Short description of the experiment. + + epochs (int): + Number training epochs. Defaults to 10000. + + batch_size (int): + Training batch size. + + eval_batch_size (int): + Validation batch size. + + mixed_precision (bool): + Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however + it may also cause numerical unstability in some cases. + + scheduler_after_epoch (bool): + If true, run the scheduler step after each epoch else run it after each model step. + + run_eval (bool): + Enable / Disable evaluation (validation) run. Defaults to True. + + test_delay_epochs (int): + Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful + results, hence waiting for a couple of epochs might save some time. + + print_eval (bool): + Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at + the end of the evaluation. Default to ```False```. + + print_step (int): + Number of steps required to print the next training log. + + log_dashboard (str): "tensorboard" or "wandb" + Set the experiment tracking tool + + plot_step (int): + Number of steps required to log training on Tensorboard. + + model_param_stats (bool): + Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging. + Defaults to ```False```. + + project_name (str): + Name of the project. Defaults to config.model + + wandb_entity (str): + Name of W&B entity/team. Enables collaboration across a team or org. + + log_model_step (int): + Number of steps required to log a checkpoint as W&B artifact + + save_step (int):ipt + Number of steps required to save the next checkpoint. + + checkpoint (bool): + Enable / Disable checkpointing. + + keep_all_best (bool): + Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults + to ```False```. + + keep_after (int): + Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults + to 10000. + + num_loader_workers (int): + Number of workers for training time dataloader. + + num_eval_loader_workers (int): + Number of workers for evaluation time dataloader. + + output_path (str): + Path for training output folder, either a local file path or other + URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or + S3 (s3://) paths. The nonexist part of the given path is created + automatically. All training artefacts are saved there. + """ + + model: str = None + run_name: str = "coqui_tts" + run_description: str = "" + # training params + epochs: int = 10000 + batch_size: int = None + eval_batch_size: int = None + mixed_precision: bool = False + scheduler_after_epoch: bool = False + # eval params + run_eval: bool = True + test_delay_epochs: int = 0 + print_eval: bool = False + # logging + dashboard_logger: str = "tensorboard" + print_step: int = 25 + plot_step: int = 100 + model_param_stats: bool = False + project_name: str = None + log_model_step: int = None + wandb_entity: str = None + # checkpointing + save_step: int = 10000 + checkpoint: bool = True + keep_all_best: bool = False + keep_after: int = 10000 + # dataloading + num_loader_workers: int = 0 + num_eval_loader_workers: int = 0 + use_noise_augment: bool = False + # paths + output_path: str = None + # distributed + distributed_backend: str = "nccl" + distributed_url: str = "tcp://localhost:54321" diff --git a/saved_models/config.json b/saved_models/config.json new file mode 100644 index 0000000..e330aab --- /dev/null +++ b/saved_models/config.json @@ -0,0 +1,104 @@ +{ + "model_name": "lstm", + "run_name": "mueller91", + "run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech ", + "audio":{ + // Audio processing parameters + "num_mels": 80, // size of the mel spec frame. + "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame. + "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. + "win_length": 1024, // stft window length in ms. + "hop_length": 256, // stft window hop-lengh in ms. + "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. + "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. + "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. + "min_level_db": -100, // normalization range + "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. + "power": 1.5, // value to sharpen wav signals after GL algorithm. + "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. + // Normalization parameters + "signal_norm": true, // normalize the spec values in range [0, 1] + "symmetric_norm": true, // move normalization to range [-1, 1] + "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! + "do_trim_silence": true, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "trim_db": 60 // threshold for timming silence. Set this according to your dataset. + }, + "reinit_layers": [], + "loss": "angleproto", // "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA) + "grad_clip": 3.0, // upper limit for gradients for clipping. + "epochs": 1000, // total number of epochs to train. + "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. + "lr_decay": false, // if true, Noam learning rate decaying is applied through training. + "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + "steps_plot_stats": 10, // number of steps to plot embeddings. + "num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "voice_len": 2.0, // size of the voice + "num_utters_per_speaker": 10, // + "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "wd": 0.000001, // Weight decay weight. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints. + "print_step": 20, // Number of steps to log traning on console. + "output_path": "../../OutputsMozilla/checkpoints/speaker_encoder/", // DATASET-RELATED: output path for all training outputs. + "model": { + "input_dim": 80, + "proj_dim": 256, + "lstm_dim": 768, + "num_lstm_layers": 3, + "use_lstm_with_projection": true + }, + "storage": { + "sample_from_storage_p": 0.9, // the probability with which we'll sample from the DataSet in-memory storage + "storage_size": 25, // the size of the in-memory storage with respect to a single batch + "additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness + }, + "datasets": + [ + { + "name": "vctk_slim", + "path": "../../../audio-datasets/en/VCTK-Corpus/", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "libri_tts", + "path": "../../../audio-datasets/en/LibriTTS/train-clean-100", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "libri_tts", + "path": "../../../audio-datasets/en/LibriTTS/train-clean-360", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "libri_tts", + "path": "../../../audio-datasets/en/LibriTTS/train-other-500", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "voxceleb1", + "path": "../../../audio-datasets/en/voxceleb1/", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "voxceleb2", + "path": "../../../audio-datasets/en/voxceleb2/", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "common_voice", + "path": "../../../audio-datasets/en/MozillaCommonVoice", + "meta_file_train": "train.tsv", + "meta_file_val": "test.tsv" + } + ] +} \ No newline at end of file diff --git a/speaker_encoder/README.md b/speaker_encoder/README.md new file mode 100644 index 0000000..b6f541f --- /dev/null +++ b/speaker_encoder/README.md @@ -0,0 +1,18 @@ +### Speaker Encoder + +This is an implementation of https://arxiv.org/abs/1710.10467. This model can be used for voice and speaker embedding. + +With the code here you can generate d-vectors for both multi-speaker and single-speaker TTS datasets, then visualise and explore them along with the associated audio files in an interactive chart. + +Below is an example showing embedding results of various speakers. You can generate the same plot with the provided notebook as demonstrated in [this video](https://youtu.be/KW3oO7JVa7Q). + +![](umap.png) + +Download a pretrained model from [Released Models](https://github.com/mozilla/TTS/wiki/Released-Models) page. + +To run the code, you need to follow the same flow as in TTS. + +- Define 'config.json' for your needs. Note that, audio parameters should match your TTS model. +- Example training call ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360``` +- Generate embedding vectors ```python speaker_encoder/compute_embeddings.py --use_cuda true /model/path/best_model.pth.tar model/config/path/config.json dataset/path/ output_path``` . This code parses all .wav files at the given dataset path and generates the same folder structure under the output path with the generated embedding files. +- Watch training on Tensorboard as in TTS diff --git a/speaker_encoder/__init__.py b/speaker_encoder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/speaker_encoder/configs/config.json b/speaker_encoder/configs/config.json new file mode 100644 index 0000000..e330aab --- /dev/null +++ b/speaker_encoder/configs/config.json @@ -0,0 +1,104 @@ +{ + "model_name": "lstm", + "run_name": "mueller91", + "run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech ", + "audio":{ + // Audio processing parameters + "num_mels": 80, // size of the mel spec frame. + "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame. + "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. + "win_length": 1024, // stft window length in ms. + "hop_length": 256, // stft window hop-lengh in ms. + "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. + "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. + "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. + "min_level_db": -100, // normalization range + "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. + "power": 1.5, // value to sharpen wav signals after GL algorithm. + "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. + // Normalization parameters + "signal_norm": true, // normalize the spec values in range [0, 1] + "symmetric_norm": true, // move normalization to range [-1, 1] + "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! + "do_trim_silence": true, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "trim_db": 60 // threshold for timming silence. Set this according to your dataset. + }, + "reinit_layers": [], + "loss": "angleproto", // "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA) + "grad_clip": 3.0, // upper limit for gradients for clipping. + "epochs": 1000, // total number of epochs to train. + "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. + "lr_decay": false, // if true, Noam learning rate decaying is applied through training. + "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + "steps_plot_stats": 10, // number of steps to plot embeddings. + "num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "voice_len": 2.0, // size of the voice + "num_utters_per_speaker": 10, // + "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "wd": 0.000001, // Weight decay weight. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints. + "print_step": 20, // Number of steps to log traning on console. + "output_path": "../../OutputsMozilla/checkpoints/speaker_encoder/", // DATASET-RELATED: output path for all training outputs. + "model": { + "input_dim": 80, + "proj_dim": 256, + "lstm_dim": 768, + "num_lstm_layers": 3, + "use_lstm_with_projection": true + }, + "storage": { + "sample_from_storage_p": 0.9, // the probability with which we'll sample from the DataSet in-memory storage + "storage_size": 25, // the size of the in-memory storage with respect to a single batch + "additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness + }, + "datasets": + [ + { + "name": "vctk_slim", + "path": "../../../audio-datasets/en/VCTK-Corpus/", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "libri_tts", + "path": "../../../audio-datasets/en/LibriTTS/train-clean-100", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "libri_tts", + "path": "../../../audio-datasets/en/LibriTTS/train-clean-360", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "libri_tts", + "path": "../../../audio-datasets/en/LibriTTS/train-other-500", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "voxceleb1", + "path": "../../../audio-datasets/en/voxceleb1/", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "voxceleb2", + "path": "../../../audio-datasets/en/voxceleb2/", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "common_voice", + "path": "../../../audio-datasets/en/MozillaCommonVoice", + "meta_file_train": "train.tsv", + "meta_file_val": "test.tsv" + } + ] +} \ No newline at end of file diff --git a/speaker_encoder/dataset.py b/speaker_encoder/dataset.py new file mode 100644 index 0000000..6b2b0dd --- /dev/null +++ b/speaker_encoder/dataset.py @@ -0,0 +1,253 @@ +import random + +import numpy as np +import torch +from torch.utils.data import Dataset + +from TTS.speaker_encoder.utils.generic_utils import AugmentWAV, Storage + + +class SpeakerEncoderDataset(Dataset): + def __init__( + self, + ap, + meta_data, + voice_len=1.6, + num_speakers_in_batch=64, + storage_size=1, + sample_from_storage_p=0.5, + num_utter_per_speaker=10, + skip_speakers=False, + verbose=False, + augmentation_config=None, + ): + """ + Args: + ap (TTS.tts.utils.AudioProcessor): audio processor object. + meta_data (list): list of dataset instances. + seq_len (int): voice segment length in seconds. + verbose (bool): print diagnostic information. + """ + super().__init__() + self.items = meta_data + self.sample_rate = ap.sample_rate + self.seq_len = int(voice_len * self.sample_rate) + self.num_speakers_in_batch = num_speakers_in_batch + self.num_utter_per_speaker = num_utter_per_speaker + self.skip_speakers = skip_speakers + self.ap = ap + self.verbose = verbose + self.__parse_items() + storage_max_size = storage_size * num_speakers_in_batch + self.storage = Storage( + maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch + ) + self.sample_from_storage_p = float(sample_from_storage_p) + + speakers_aux = list(self.speakers) + speakers_aux.sort() + self.speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)} + + # Augmentation + self.augmentator = None + self.gaussian_augmentation_config = None + if augmentation_config: + self.data_augmentation_p = augmentation_config["p"] + if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config): + self.augmentator = AugmentWAV(ap, augmentation_config) + + if "gaussian" in augmentation_config.keys(): + self.gaussian_augmentation_config = augmentation_config["gaussian"] + + if self.verbose: + print("\n > DataLoader initialization") + print(f" | > Speakers per Batch: {num_speakers_in_batch}") + print(f" | > Storage Size: {storage_max_size} instances, each with {num_utter_per_speaker} utters") + print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}") + print(f" | > Number of instances : {len(self.items)}") + print(f" | > Sequence length: {self.seq_len}") + print(f" | > Num speakers: {len(self.speakers)}") + + def load_wav(self, filename): + audio = self.ap.load_wav(filename, sr=self.ap.sample_rate) + return audio + + def load_data(self, idx): + text, wav_file, speaker_name = self.items[idx] + wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) + mel = self.ap.melspectrogram(wav).astype("float32") + # sample seq_len + + assert text.size > 0, self.items[idx][1] + assert wav.size > 0, self.items[idx][1] + + sample = { + "mel": mel, + "item_idx": self.items[idx][1], + "speaker_name": speaker_name, + } + return sample + + def __parse_items(self): + self.speaker_to_utters = {} + for i in self.items: + path_ = i[1] + speaker_ = i[2] + if speaker_ in self.speaker_to_utters.keys(): + self.speaker_to_utters[speaker_].append(path_) + else: + self.speaker_to_utters[speaker_] = [ + path_, + ] + + if self.skip_speakers: + self.speaker_to_utters = { + k: v for (k, v) in self.speaker_to_utters.items() if len(v) >= self.num_utter_per_speaker + } + + self.speakers = [k for (k, v) in self.speaker_to_utters.items()] + + def __len__(self): + return int(1e10) + + def get_num_speakers(self): + return len(self.speakers) + + def __sample_speaker(self, ignore_speakers=None): + speaker = random.sample(self.speakers, 1)[0] + # if list of speakers_id is provide make sure that it's will be ignored + if ignore_speakers and self.speakerid_to_classid[speaker] in ignore_speakers: + while True: + speaker = random.sample(self.speakers, 1)[0] + if self.speakerid_to_classid[speaker] not in ignore_speakers: + break + + if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]): + utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker) + else: + utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker) + return speaker, utters + + def __sample_speaker_utterances(self, speaker): + """ + Sample all M utterances for the given speaker. + """ + wavs = [] + labels = [] + for _ in range(self.num_utter_per_speaker): + # TODO:dummy but works + while True: + # remove speakers that have num_utter less than 2 + if len(self.speaker_to_utters[speaker]) > 1: + utter = random.sample(self.speaker_to_utters[speaker], 1)[0] + else: + if speaker in self.speakers: + self.speakers.remove(speaker) + + speaker, _ = self.__sample_speaker() + continue + + wav = self.load_wav(utter) + if wav.shape[0] - self.seq_len > 0: + break + + if utter in self.speaker_to_utters[speaker]: + self.speaker_to_utters[speaker].remove(utter) + + if self.augmentator is not None and self.data_augmentation_p: + if random.random() < self.data_augmentation_p: + wav = self.augmentator.apply_one(wav) + + wavs.append(wav) + labels.append(self.speakerid_to_classid[speaker]) + return wavs, labels + + def __getitem__(self, idx): + speaker, _ = self.__sample_speaker() + speaker_id = self.speakerid_to_classid[speaker] + return speaker, speaker_id + + def __load_from_disk_and_storage(self, speaker): + # don't sample from storage, but from HDD + wavs_, labels_ = self.__sample_speaker_utterances(speaker) + # put the newly loaded item into storage + self.storage.append((wavs_, labels_)) + return wavs_, labels_ + + def collate_fn(self, batch): + # get the batch speaker_ids + batch = np.array(batch) + speakers_id_in_batch = set(batch[:, 1].astype(np.int32)) + + labels = [] + feats = [] + speakers = set() + + for speaker, speaker_id in batch: + speaker_id = int(speaker_id) + + # ensure that an speaker appears only once in the batch + if speaker_id in speakers: + + # remove current speaker + if speaker_id in speakers_id_in_batch: + speakers_id_in_batch.remove(speaker_id) + + speaker, _ = self.__sample_speaker(ignore_speakers=speakers_id_in_batch) + speaker_id = self.speakerid_to_classid[speaker] + speakers_id_in_batch.add(speaker_id) + + if random.random() < self.sample_from_storage_p and self.storage.full(): + # sample from storage (if full) + wavs_, labels_ = self.storage.get_random_sample_fast() + + # force choose the current speaker or other not in batch + # It's necessary for ideal training with AngleProto and GE2E losses + if labels_[0] in speakers_id_in_batch and labels_[0] != speaker_id: + attempts = 0 + while True: + wavs_, labels_ = self.storage.get_random_sample_fast() + if labels_[0] == speaker_id or labels_[0] not in speakers_id_in_batch: + break + + attempts += 1 + # Try 5 times after that load from disk + if attempts >= 5: + wavs_, labels_ = self.__load_from_disk_and_storage(speaker) + break + else: + # don't sample from storage, but from HDD + wavs_, labels_ = self.__load_from_disk_and_storage(speaker) + + # append speaker for control + speakers.add(labels_[0]) + + # remove current speaker and append other + if speaker_id in speakers_id_in_batch: + speakers_id_in_batch.remove(speaker_id) + + speakers_id_in_batch.add(labels_[0]) + + # get a random subset of each of the wavs and extract mel spectrograms. + feats_ = [] + for wav in wavs_: + offset = random.randint(0, wav.shape[0] - self.seq_len) + wav = wav[offset : offset + self.seq_len] + # add random gaussian noise + if self.gaussian_augmentation_config and self.gaussian_augmentation_config["p"]: + if random.random() < self.gaussian_augmentation_config["p"]: + wav += np.random.normal( + self.gaussian_augmentation_config["min_amplitude"], + self.gaussian_augmentation_config["max_amplitude"], + size=len(wav), + ) + mel = self.ap.melspectrogram(wav) + feats_.append(torch.FloatTensor(mel)) + + labels.append(torch.LongTensor(labels_)) + feats.extend(feats_) + + feats = torch.stack(feats) + labels = torch.stack(labels) + + return feats.transpose(1, 2), labels diff --git a/speaker_encoder/losses.py b/speaker_encoder/losses.py new file mode 100644 index 0000000..8ba917b --- /dev/null +++ b/speaker_encoder/losses.py @@ -0,0 +1,220 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +# adapted from https://github.com/cvqluu/GE2E-Loss +class GE2ELoss(nn.Module): + def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"): + """ + Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1] + Accepts an input of size (N, M, D) + where N is the number of speakers in the batch, + M is the number of utterances per speaker, + and D is the dimensionality of the embedding vector (e.g. d-vector) + Args: + - init_w (float): defines the initial value of w in Equation (5) of [1] + - init_b (float): definies the initial value of b in Equation (5) of [1] + """ + super().__init__() + # pylint: disable=E1102 + self.w = nn.Parameter(torch.tensor(init_w)) + # pylint: disable=E1102 + self.b = nn.Parameter(torch.tensor(init_b)) + self.loss_method = loss_method + + print(" > Initialized Generalized End-to-End loss") + + assert self.loss_method in ["softmax", "contrast"] + + if self.loss_method == "softmax": + self.embed_loss = self.embed_loss_softmax + if self.loss_method == "contrast": + self.embed_loss = self.embed_loss_contrast + + # pylint: disable=R0201 + def calc_new_centroids(self, dvecs, centroids, spkr, utt): + """ + Calculates the new centroids excluding the reference utterance + """ + excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :])) + excl = torch.mean(excl, 0) + new_centroids = [] + for i, centroid in enumerate(centroids): + if i == spkr: + new_centroids.append(excl) + else: + new_centroids.append(centroid) + return torch.stack(new_centroids) + + def calc_cosine_sim(self, dvecs, centroids): + """ + Make the cosine similarity matrix with dims (N,M,N) + """ + cos_sim_matrix = [] + for spkr_idx, speaker in enumerate(dvecs): + cs_row = [] + for utt_idx, utterance in enumerate(speaker): + new_centroids = self.calc_new_centroids(dvecs, centroids, spkr_idx, utt_idx) + # vector based cosine similarity for speed + cs_row.append( + torch.clamp( + torch.mm( + utterance.unsqueeze(1).transpose(0, 1), + new_centroids.transpose(0, 1), + ) + / (torch.norm(utterance) * torch.norm(new_centroids, dim=1)), + 1e-6, + ) + ) + cs_row = torch.cat(cs_row, dim=0) + cos_sim_matrix.append(cs_row) + return torch.stack(cos_sim_matrix) + + # pylint: disable=R0201 + def embed_loss_softmax(self, dvecs, cos_sim_matrix): + """ + Calculates the loss on each embedding $L(e_{ji})$ by taking softmax + """ + N, M, _ = dvecs.shape + L = [] + for j in range(N): + L_row = [] + for i in range(M): + L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j]) + L_row = torch.stack(L_row) + L.append(L_row) + return torch.stack(L) + + # pylint: disable=R0201 + def embed_loss_contrast(self, dvecs, cos_sim_matrix): + """ + Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid + """ + N, M, _ = dvecs.shape + L = [] + for j in range(N): + L_row = [] + for i in range(M): + centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i]) + excl_centroids_sigmoids = torch.cat((centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])) + L_row.append(1.0 - torch.sigmoid(cos_sim_matrix[j, i, j]) + torch.max(excl_centroids_sigmoids)) + L_row = torch.stack(L_row) + L.append(L_row) + return torch.stack(L) + + def forward(self, x, _label=None): + """ + Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + assert x.size()[1] >= 2 + + centroids = torch.mean(x, 1) + cos_sim_matrix = self.calc_cosine_sim(x, centroids) + torch.clamp(self.w, 1e-6) + cos_sim_matrix = self.w * cos_sim_matrix + self.b + L = self.embed_loss(x, cos_sim_matrix) + return L.mean() + + +# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py +class AngleProtoLoss(nn.Module): + """ + Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982 + Accepts an input of size (N, M, D) + where N is the number of speakers in the batch, + M is the number of utterances per speaker, + and D is the dimensionality of the embedding vector + Args: + - init_w (float): defines the initial value of w + - init_b (float): definies the initial value of b + """ + + def __init__(self, init_w=10.0, init_b=-5.0): + super().__init__() + # pylint: disable=E1102 + self.w = nn.Parameter(torch.tensor(init_w)) + # pylint: disable=E1102 + self.b = nn.Parameter(torch.tensor(init_b)) + self.criterion = torch.nn.CrossEntropyLoss() + + print(" > Initialized Angular Prototypical loss") + + def forward(self, x, _label=None): + """ + Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + assert x.size()[1] >= 2 + + out_anchor = torch.mean(x[:, 1:, :], 1) + out_positive = x[:, 0, :] + num_speakers = out_anchor.size()[0] + + cos_sim_matrix = F.cosine_similarity( + out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), + out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2), + ) + torch.clamp(self.w, 1e-6) + cos_sim_matrix = cos_sim_matrix * self.w + self.b + label = torch.arange(num_speakers).to(cos_sim_matrix.device) + L = self.criterion(cos_sim_matrix, label) + return L + + +class SoftmaxLoss(nn.Module): + """ + Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982 + Args: + - embedding_dim (float): speaker embedding dim + - n_speakers (float): number of speakers + """ + + def __init__(self, embedding_dim, n_speakers): + super().__init__() + + self.criterion = torch.nn.CrossEntropyLoss() + self.fc = nn.Linear(embedding_dim, n_speakers) + + print("Initialised Softmax Loss") + + def forward(self, x, label=None): + # reshape for compatibility + x = x.reshape(-1, x.size()[-1]) + label = label.reshape(-1) + + x = self.fc(x) + L = self.criterion(x, label) + + return L + + +class SoftmaxAngleProtoLoss(nn.Module): + """ + Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153 + Args: + - embedding_dim (float): speaker embedding dim + - n_speakers (float): number of speakers + - init_w (float): defines the initial value of w + - init_b (float): definies the initial value of b + """ + + def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0): + super().__init__() + + self.softmax = SoftmaxLoss(embedding_dim, n_speakers) + self.angleproto = AngleProtoLoss(init_w, init_b) + + print("Initialised SoftmaxAnglePrototypical Loss") + + def forward(self, x, label=None): + """ + Calculates the SoftmaxAnglePrototypical loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + Lp = self.angleproto(x) + + Ls = self.softmax(x, label) + + return Ls + Lp diff --git a/speaker_encoder/models/lstm.py b/speaker_encoder/models/lstm.py new file mode 100644 index 0000000..7430e72 --- /dev/null +++ b/speaker_encoder/models/lstm.py @@ -0,0 +1,131 @@ +import numpy as np +import torch +from torch import nn + +from utils.io import load_fsspec + + +class LSTMWithProjection(nn.Module): + def __init__(self, input_size, hidden_size, proj_size): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.proj_size = proj_size + self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) + self.linear = nn.Linear(hidden_size, proj_size, bias=False) + + def forward(self, x): + self.lstm.flatten_parameters() + o, (_, _) = self.lstm(x) + return self.linear(o) + + +class LSTMWithoutProjection(nn.Module): + def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers): + super().__init__() + self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True) + self.linear = nn.Linear(lstm_dim, proj_dim, bias=True) + self.relu = nn.ReLU() + + def forward(self, x): + _, (hidden, _) = self.lstm(x) + return self.relu(self.linear(hidden[-1])) + + +class LSTMSpeakerEncoder(nn.Module): + def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True): + super().__init__() + self.use_lstm_with_projection = use_lstm_with_projection + layers = [] + # choise LSTM layer + if use_lstm_with_projection: + layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim)) + for _ in range(num_lstm_layers - 1): + layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim)) + self.layers = nn.Sequential(*layers) + else: + self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers) + + self._init_layers() + + def _init_layers(self): + for name, param in self.layers.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0.0) + elif "weight" in name: + nn.init.xavier_normal_(param) + + def forward(self, x): + # TODO: implement state passing for lstms + d = self.layers(x) + if self.use_lstm_with_projection: + d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) + else: + d = torch.nn.functional.normalize(d, p=2, dim=1) + return d + + @torch.no_grad() + def inference(self, x): + d = self.layers.forward(x) + if self.use_lstm_with_projection: + d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) + else: + d = torch.nn.functional.normalize(d, p=2, dim=1) + return d + + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): + """ + Generate embeddings for a batch of utterances + x: 1xTxD + """ + max_len = x.shape[1] + + if max_len < num_frames: + num_frames = max_len + + offsets = np.linspace(0, max_len - num_frames, num=num_eval) + + frames_batch = [] + for offset in offsets: + offset = int(offset) + end_offset = int(offset + num_frames) + frames = x[:, offset:end_offset] + frames_batch.append(frames) + + frames_batch = torch.cat(frames_batch, dim=0) + embeddings = self.inference(frames_batch) + + if return_mean: + embeddings = torch.mean(embeddings, dim=0, keepdim=True) + + return embeddings + + def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5): + """ + Generate embeddings for a batch of utterances + x: BxTxD + """ + num_overlap = num_frames * overlap + max_len = x.shape[1] + embed = None + num_iters = seq_lens / (num_frames - num_overlap) + cur_iter = 0 + for offset in range(0, max_len, num_frames - num_overlap): + cur_iter += 1 + end_offset = min(x.shape[1], offset + num_frames) + frames = x[:, offset:end_offset] + if embed is None: + embed = self.inference(frames) + else: + embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :]) + return embed / num_iters + + # pylint: disable=unused-argument, redefined-builtin + def load_checkpoint(self, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if use_cuda: + self.cuda() + if eval: + self.eval() + assert not self.training diff --git a/speaker_encoder/models/resnet.py b/speaker_encoder/models/resnet.py new file mode 100644 index 0000000..fcc850d --- /dev/null +++ b/speaker_encoder/models/resnet.py @@ -0,0 +1,212 @@ +import numpy as np +import torch +from torch import nn + +from TTS.utils.io import load_fsspec + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=8): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class SEBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): + super(SEBasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se = SELayer(planes, reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + return out + + +class ResNetSpeakerEncoder(nn.Module): + """Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153 + Adapted from: https://github.com/clovaai/voxceleb_trainer + """ + + # pylint: disable=W0102 + def __init__( + self, + input_dim=64, + proj_dim=512, + layers=[3, 4, 6, 3], + num_filters=[32, 64, 128, 256], + encoder_type="ASP", + log_input=False, + ): + super(ResNetSpeakerEncoder, self).__init__() + + self.encoder_type = encoder_type + self.input_dim = input_dim + self.log_input = log_input + self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) + self.relu = nn.ReLU(inplace=True) + self.bn1 = nn.BatchNorm2d(num_filters[0]) + + self.inplanes = num_filters[0] + self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0]) + self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2)) + self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2)) + self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2)) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + outmap_size = int(self.input_dim / 8) + + self.attention = nn.Sequential( + nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), + nn.Softmax(dim=2), + ) + + if self.encoder_type == "SAP": + out_dim = num_filters[3] * outmap_size + elif self.encoder_type == "ASP": + out_dim = num_filters[3] * outmap_size * 2 + else: + raise ValueError("Undefined encoder") + + self.fc = nn.Linear(out_dim, proj_dim) + + self._init_layers() + + def _init_layers(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def create_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + # pylint: disable=R0201 + def new_parameter(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + def forward(self, x, l2_norm=False): + x = x.transpose(1, 2) + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + if self.log_input: + x = (x + 1e-6).log() + x = self.instancenorm(x).unsqueeze(1) + + x = self.conv1(x) + x = self.relu(x) + x = self.bn1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = x.reshape(x.size()[0], -1, x.size()[-1]) + + w = self.attention(x) + + if self.encoder_type == "SAP": + x = torch.sum(x * w, dim=2) + elif self.encoder_type == "ASP": + mu = torch.sum(x * w, dim=2) + sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5)) + x = torch.cat((mu, sg), 1) + + x = x.view(x.size()[0], -1) + x = self.fc(x) + + if l2_norm: + x = torch.nn.functional.normalize(x, p=2, dim=1) + return x + + @torch.no_grad() + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): + """ + Generate embeddings for a batch of utterances + x: 1xTxD + """ + max_len = x.shape[1] + + if max_len < num_frames: + num_frames = max_len + + offsets = np.linspace(0, max_len - num_frames, num=num_eval) + + frames_batch = [] + for offset in offsets: + offset = int(offset) + end_offset = int(offset + num_frames) + frames = x[:, offset:end_offset] + frames_batch.append(frames) + + frames_batch = torch.cat(frames_batch, dim=0) + embeddings = self.forward(frames_batch, l2_norm=True) + + if return_mean: + embeddings = torch.mean(embeddings, dim=0, keepdim=True) + + return embeddings + + def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if use_cuda: + self.cuda() + if eval: + self.eval() + assert not self.training diff --git a/speaker_encoder/requirements.txt b/speaker_encoder/requirements.txt new file mode 100644 index 0000000..a486cc4 --- /dev/null +++ b/speaker_encoder/requirements.txt @@ -0,0 +1,2 @@ +umap-learn +numpy>=1.17.0 diff --git a/speaker_encoder/speaker_encoder_config.py b/speaker_encoder/speaker_encoder_config.py new file mode 100644 index 0000000..f953052 --- /dev/null +++ b/speaker_encoder/speaker_encoder_config.py @@ -0,0 +1,65 @@ +from dataclasses import asdict, dataclass, field +from typing import Dict, List + +from coqpit import MISSING + +from config.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig + + +@dataclass +class SpeakerEncoderConfig(BaseTrainingConfig): + """Defines parameters for Speaker Encoder model.""" + + model: str = "speaker_encoder" + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # model params + model_params: Dict = field( + default_factory=lambda: { + "model_name": "lstm", + "input_dim": 80, + "proj_dim": 256, + "lstm_dim": 768, + "num_lstm_layers": 3, + "use_lstm_with_projection": True, + } + ) + + audio_augmentation: Dict = field(default_factory=lambda: {}) + + storage: Dict = field( + default_factory=lambda: { + "sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage + "storage_size": 15, # the size of the in-memory storage with respect to a single batch + } + ) + + # training params + max_train_step: int = 1000000 # end training when number of training steps reaches this value. + loss: str = "angleproto" + grad_clip: float = 3.0 + lr: float = 0.0001 + lr_decay: bool = False + warmup_steps: int = 4000 + wd: float = 1e-6 + + # logging params + tb_model_param_stats: bool = False + steps_plot_stats: int = 10 + checkpoint: bool = True + save_step: int = 1000 + print_step: int = 20 + + # data loader + num_speakers_in_batch: int = MISSING + num_utters_per_speaker: int = MISSING + num_loader_workers: int = MISSING + skip_speakers: bool = False + voice_len: float = 1.6 + + def check_values(self): + super().check_values() + c = asdict(self) + assert ( + c["model_params"]["input_dim"] == self.audio.num_mels + ), " [!] model input dimendion must be equal to melspectrogram dimension." diff --git a/speaker_encoder/umap.png b/speaker_encoder/umap.png new file mode 100644 index 0000000..ca8aefe Binary files /dev/null and b/speaker_encoder/umap.png differ diff --git a/speaker_encoder/utils/__init__.py b/speaker_encoder/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/speaker_encoder/utils/generic_utils.py b/speaker_encoder/utils/generic_utils.py new file mode 100644 index 0000000..1981fbe --- /dev/null +++ b/speaker_encoder/utils/generic_utils.py @@ -0,0 +1,220 @@ +import datetime +import glob +import os +import random +import re +from multiprocessing import Manager + +import numpy as np +from scipy import signal + +from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder +from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder +from TTS.utils.io import save_fsspec + + +class Storage(object): + def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8): + # use multiprocessing for threading safe + self.storage = Manager().list() + self.maxsize = maxsize + self.num_speakers_in_batch = num_speakers_in_batch + self.num_threads = num_threads + self.ignore_last_batch = False + + if storage_batchs >= 3: + self.ignore_last_batch = True + + # used for fast random sample + self.safe_storage_size = self.maxsize - self.num_threads + if self.ignore_last_batch: + self.safe_storage_size -= self.num_speakers_in_batch + + def __len__(self): + return len(self.storage) + + def full(self): + return len(self.storage) >= self.maxsize + + def append(self, item): + # if storage is full, remove an item + if self.full(): + self.storage.pop(0) + + self.storage.append(item) + + def get_random_sample(self): + # safe storage size considering all threads remove one item from storage in same time + storage_size = len(self.storage) - self.num_threads + + if self.ignore_last_batch: + storage_size -= self.num_speakers_in_batch + + return self.storage[random.randint(0, storage_size)] + + def get_random_sample_fast(self): + """Call this method only when storage is full""" + return self.storage[random.randint(0, self.safe_storage_size)] + + +class AugmentWAV(object): + def __init__(self, ap, augmentation_config): + + self.ap = ap + self.use_additive_noise = False + + if "additive" in augmentation_config.keys(): + self.additive_noise_config = augmentation_config["additive"] + additive_path = self.additive_noise_config["sounds_path"] + if additive_path: + self.use_additive_noise = True + # get noise types + self.additive_noise_types = [] + for key in self.additive_noise_config.keys(): + if isinstance(self.additive_noise_config[key], dict): + self.additive_noise_types.append(key) + + additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True) + + self.noise_list = {} + + for wav_file in additive_files: + noise_dir = wav_file.replace(additive_path, "").split(os.sep)[0] + # ignore not listed directories + if noise_dir not in self.additive_noise_types: + continue + if not noise_dir in self.noise_list: + self.noise_list[noise_dir] = [] + self.noise_list[noise_dir].append(wav_file) + + print( + f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}" + ) + + self.use_rir = False + + if "rir" in augmentation_config.keys(): + self.rir_config = augmentation_config["rir"] + if self.rir_config["rir_path"]: + self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True) + self.use_rir = True + + print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances") + + self.create_augmentation_global_list() + + def create_augmentation_global_list(self): + if self.use_additive_noise: + self.global_noise_list = self.additive_noise_types + else: + self.global_noise_list = [] + if self.use_rir: + self.global_noise_list.append("RIR_AUG") + + def additive_noise(self, noise_type, audio): + + clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4) + + noise_list = random.sample( + self.noise_list[noise_type], + random.randint( + self.additive_noise_config[noise_type]["min_num_noises"], + self.additive_noise_config[noise_type]["max_num_noises"], + ), + ) + + audio_len = audio.shape[0] + noises_wav = None + for noise in noise_list: + noiseaudio = self.ap.load_wav(noise, sr=self.ap.sample_rate)[:audio_len] + + if noiseaudio.shape[0] < audio_len: + continue + + noise_snr = random.uniform( + self.additive_noise_config[noise_type]["min_snr_in_db"], + self.additive_noise_config[noise_type]["max_num_noises"], + ) + noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4) + noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio + + if noises_wav is None: + noises_wav = noise_wav + else: + noises_wav += noise_wav + + # if all possible files is less than audio, choose other files + if noises_wav is None: + return self.additive_noise(noise_type, audio) + + return audio + noises_wav + + def reverberate(self, audio): + audio_len = audio.shape[0] + + rir_file = random.choice(self.rir_files) + rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate) + rir = rir / np.sqrt(np.sum(rir ** 2)) + return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len] + + def apply_one(self, audio): + noise_type = random.choice(self.global_noise_list) + if noise_type == "RIR_AUG": + return self.reverberate(audio) + + return self.additive_noise(noise_type, audio) + + +def to_camel(text): + text = text.capitalize() + return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) + + +def setup_model(c): + if c.model_params["model_name"].lower() == "lstm": + model = LSTMSpeakerEncoder( + c.model_params["input_dim"], + c.model_params["proj_dim"], + c.model_params["lstm_dim"], + c.model_params["num_lstm_layers"], + ) + elif c.model_params["model_name"].lower() == "resnet": + model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"]) + return model + + +def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch): + checkpoint_path = "checkpoint_{}.pth.tar".format(current_step) + checkpoint_path = os.path.join(out_path, checkpoint_path) + print(" | | > Checkpoint saving : {}".format(checkpoint_path)) + + new_state_dict = model.state_dict() + state = { + "model": new_state_dict, + "optimizer": optimizer.state_dict() if optimizer is not None else None, + "criterion": criterion.state_dict(), + "step": current_step, + "epoch": epoch, + "loss": model_loss, + "date": datetime.date.today().strftime("%B %d, %Y"), + } + save_fsspec(state, checkpoint_path) + + +def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step): + if model_loss < best_loss: + new_state_dict = model.state_dict() + state = { + "model": new_state_dict, + "optimizer": optimizer.state_dict(), + "criterion": criterion.state_dict(), + "step": current_step, + "loss": model_loss, + "date": datetime.date.today().strftime("%B %d, %Y"), + } + best_loss = model_loss + bestmodel_path = "best_model.pth.tar" + bestmodel_path = os.path.join(out_path, bestmodel_path) + print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) + save_fsspec(state, bestmodel_path) + return best_loss diff --git a/speaker_encoder/utils/io.py b/speaker_encoder/utils/io.py new file mode 100644 index 0000000..7a3aadc --- /dev/null +++ b/speaker_encoder/utils/io.py @@ -0,0 +1,38 @@ +import datetime +import os + +from TTS.utils.io import save_fsspec + + +def save_checkpoint(model, optimizer, model_loss, out_path, current_step): + checkpoint_path = "checkpoint_{}.pth.tar".format(current_step) + checkpoint_path = os.path.join(out_path, checkpoint_path) + print(" | | > Checkpoint saving : {}".format(checkpoint_path)) + + new_state_dict = model.state_dict() + state = { + "model": new_state_dict, + "optimizer": optimizer.state_dict() if optimizer is not None else None, + "step": current_step, + "loss": model_loss, + "date": datetime.date.today().strftime("%B %d, %Y"), + } + save_fsspec(state, checkpoint_path) + + +def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step): + if model_loss < best_loss: + new_state_dict = model.state_dict() + state = { + "model": new_state_dict, + "optimizer": optimizer.state_dict(), + "step": current_step, + "loss": model_loss, + "date": datetime.date.today().strftime("%B %d, %Y"), + } + best_loss = model_loss + bestmodel_path = "best_model.pth.tar" + bestmodel_path = os.path.join(out_path, bestmodel_path) + print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) + save_fsspec(state, bestmodel_path) + return best_loss diff --git a/speaker_encoder/utils/prepare_voxceleb.py b/speaker_encoder/utils/prepare_voxceleb.py new file mode 100644 index 0000000..b93baf9 --- /dev/null +++ b/speaker_encoder/utils/prepare_voxceleb.py @@ -0,0 +1,219 @@ +# coding=utf-8 +# Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Only support eager mode and TF>=2.0.0 +# pylint: disable=no-member, invalid-name, relative-beyond-top-level +# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes +""" voxceleb 1 & 2 """ + +import hashlib +import os +import subprocess +import sys +import zipfile + +import pandas +import soundfile as sf +from absl import logging + +SUBSETS = { + "vox1_dev_wav": [ + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad", + ], + "vox1_test_wav": ["https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"], + "vox2_dev_aac": [ + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah", + ], + "vox2_test_aac": ["https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"], +} + +MD5SUM = { + "vox1_dev_wav": "ae63e55b951748cc486645f532ba230b", + "vox2_dev_aac": "bbc063c46078a602ca71605645c2a402", + "vox1_test_wav": "185fdc63c3c739954633d50379a3d102", + "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312", +} + +USER = {"user": "", "password": ""} + +speaker_id_dict = {} + + +def download_and_extract(directory, subset, urls): + """Download and extract the given split of dataset. + + Args: + directory: the directory where to put the downloaded data. + subset: subset name of the corpus. + urls: the list of urls to download the data file. + """ + os.makedirs(directory, exist_ok=True) + + try: + for url in urls: + zip_filepath = os.path.join(directory, url.split("/")[-1]) + if os.path.exists(zip_filepath): + continue + logging.info("Downloading %s to %s" % (url, zip_filepath)) + subprocess.call( + "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath), + shell=True, + ) + + statinfo = os.stat(zip_filepath) + logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)) + + # concatenate all parts into zip files + if ".zip" not in zip_filepath: + zip_filepath = "_".join(zip_filepath.split("_")[:-1]) + subprocess.call("cat %s* > %s.zip" % (zip_filepath, zip_filepath), shell=True) + zip_filepath += ".zip" + extract_path = zip_filepath.strip(".zip") + + # check zip file md5sum + with open(zip_filepath, "rb") as f_zip: + md5 = hashlib.md5(f_zip.read()).hexdigest() + if md5 != MD5SUM[subset]: + raise ValueError("md5sum of %s mismatch" % zip_filepath) + + with zipfile.ZipFile(zip_filepath, "r") as zfile: + zfile.extractall(directory) + extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename) + subprocess.call("mv %s %s" % (extract_path_ori, extract_path), shell=True) + finally: + # os.remove(zip_filepath) + pass + + +def exec_cmd(cmd): + """Run a command in a subprocess. + Args: + cmd: command line to be executed. + Return: + int, the return code. + """ + try: + retcode = subprocess.call(cmd, shell=True) + if retcode < 0: + logging.info(f"Child was terminated by signal {retcode}") + except OSError as e: + logging.info(f"Execution failed: {e}") + retcode = -999 + return retcode + + +def decode_aac_with_ffmpeg(aac_file, wav_file): + """Decode a given AAC file into WAV using ffmpeg. + Args: + aac_file: file path to input AAC file. + wav_file: file path to output WAV file. + Return: + bool, True if success. + """ + cmd = f"ffmpeg -i {aac_file} {wav_file}" + logging.info(f"Decoding aac file using command line: {cmd}") + ret = exec_cmd(cmd) + if ret != 0: + logging.error(f"Failed to decode aac file with retcode {ret}") + logging.error("Please check your ffmpeg installation.") + return False + return True + + +def convert_audio_and_make_label(input_dir, subset, output_dir, output_file): + """Optionally convert AAC to WAV and make speaker labels. + Args: + input_dir: the directory which holds the input dataset. + subset: the name of the specified subset. e.g. vox1_dev_wav + output_dir: the directory to place the newly generated csv files. + output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv + """ + + logging.info("Preprocessing audio and label for subset %s" % subset) + source_dir = os.path.join(input_dir, subset) + + files = [] + # Convert all AAC file into WAV format. At the same time, generate the csv + for root, _, filenames in os.walk(source_dir): + for filename in filenames: + name, ext = os.path.splitext(filename) + if ext.lower() == ".wav": + _, ext2 = os.path.splitext(name) + if ext2: + continue + wav_file = os.path.join(root, filename) + elif ext.lower() == ".m4a": + # Convert AAC to WAV. + aac_file = os.path.join(root, filename) + wav_file = aac_file + ".wav" + if not os.path.exists(wav_file): + if not decode_aac_with_ffmpeg(aac_file, wav_file): + raise RuntimeError("Audio decoding failed.") + else: + continue + speaker_name = root.split(os.path.sep)[-2] + if speaker_name not in speaker_id_dict: + num = len(speaker_id_dict) + speaker_id_dict[speaker_name] = num + # wav_filesize = os.path.getsize(wav_file) + wav_length = len(sf.read(wav_file)[0]) + files.append((os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name)) + + # Write to CSV file which contains four columns: + # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name". + csv_file_path = os.path.join(output_dir, output_file) + df = pandas.DataFrame(data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]) + df.to_csv(csv_file_path, index=False, sep="\t") + logging.info("Successfully generated csv file {}".format(csv_file_path)) + + +def processor(directory, subset, force_process): + """download and process""" + urls = SUBSETS + if subset not in urls: + raise ValueError(subset, "is not in voxceleb") + + subset_csv = os.path.join(directory, subset + ".csv") + if not force_process and os.path.exists(subset_csv): + return subset_csv + + logging.info("Downloading and process the voxceleb in %s", directory) + logging.info("Preparing subset %s", subset) + download_and_extract(directory, subset, urls[subset]) + convert_audio_and_make_label(directory, subset, directory, subset + ".csv") + logging.info("Finished downloading and processing") + return subset_csv + + +if __name__ == "__main__": + logging.set_verbosity(logging.INFO) + if len(sys.argv) != 4: + print("Usage: python prepare_data.py save_directory user password") + sys.exit() + + DIR, USER["user"], USER["password"] = sys.argv[1], sys.argv[2], sys.argv[3] + for SUBSET in SUBSETS: + processor(DIR, SUBSET, False) diff --git a/speaker_encoder/utils/visual.py b/speaker_encoder/utils/visual.py new file mode 100644 index 0000000..4f40f68 --- /dev/null +++ b/speaker_encoder/utils/visual.py @@ -0,0 +1,46 @@ +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import umap + +matplotlib.use("Agg") + + +colormap = ( + np.array( + [ + [76, 255, 0], + [0, 127, 70], + [255, 0, 0], + [255, 217, 38], + [0, 135, 255], + [165, 0, 165], + [255, 167, 255], + [0, 255, 255], + [255, 96, 38], + [142, 76, 0], + [33, 0, 127], + [0, 0, 0], + [183, 183, 183], + ], + dtype=np.float, + ) + / 255 +) + + +def plot_embeddings(embeddings, num_utter_per_speaker): + embeddings = embeddings[: 10 * num_utter_per_speaker] + model = umap.UMAP() + projection = model.fit_transform(embeddings) + num_speakers = embeddings.shape[0] // num_utter_per_speaker + ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_speaker) + colors = [colormap[i] for i in ground_truth] + + fig, ax = plt.subplots(figsize=(16, 10)) + _ = ax.scatter(projection[:, 0], projection[:, 1], c=colors) + plt.gca().set_aspect("equal", "datalim") + plt.title("UMAP projection") + plt.tight_layout() + plt.savefig("umap") + return fig diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/audio.py b/utils/audio.py new file mode 100644 index 0000000..99e00b6 --- /dev/null +++ b/utils/audio.py @@ -0,0 +1,822 @@ +from typing import Dict, Tuple + +import librosa +import numpy as np +import pyworld as pw +import scipy.io.wavfile +import scipy.signal +import soundfile as sf +import torch +from torch import nn + +class StandardScaler: + """StandardScaler for mean-scale normalization with the given mean and scale values.""" + + def __init__(self, mean: np.ndarray = None, scale: np.ndarray = None) -> None: + self.mean_ = mean + self.scale_ = scale + + def set_stats(self, mean, scale): + self.mean_ = mean + self.scale_ = scale + + def reset_stats(self): + delattr(self, "mean_") + delattr(self, "scale_") + + def transform(self, X): + X = np.asarray(X) + X -= self.mean_ + X /= self.scale_ + return X + + def inverse_transform(self, X): + X = np.asarray(X) + X *= self.scale_ + X += self.mean_ + return X + +class TorchSTFT(nn.Module): # pylint: disable=abstract-method + """Some of the audio processing funtions using Torch for faster batch processing. + + TODO: Merge this with audio.py + """ + + def __init__( + self, + n_fft, + hop_length, + win_length, + pad_wav=False, + window="hann_window", + sample_rate=None, + mel_fmin=0, + mel_fmax=None, + n_mels=80, + use_mel=False, + do_amp_to_db=False, + spec_gain=1.0, + ): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.pad_wav = pad_wav + self.sample_rate = sample_rate + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.n_mels = n_mels + self.use_mel = use_mel + self.do_amp_to_db = do_amp_to_db + self.spec_gain = spec_gain + self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) + self.mel_basis = None + if use_mel: + self._build_mel_basis() + + def __call__(self, x): + """Compute spectrogram frames by torch based stft. + + Args: + x (Tensor): input waveform + + Returns: + Tensor: spectrogram frames. + + Shapes: + x: [B x T] or [:math:`[B, 1, T]`] + """ + if x.ndim == 2: + x = x.unsqueeze(1) + if self.pad_wav: + padding = int((self.n_fft - self.hop_length) / 2) + x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") + # B x D x T x 2 + o = torch.stft( + x.squeeze(1), + self.n_fft, + self.hop_length, + self.win_length, + self.window, + center=True, + pad_mode="reflect", # compatible with audio.py + normalized=False, + onesided=True, + return_complex=False, + ) + M = o[:, :, :, 0] + P = o[:, :, :, 1] + S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8)) + if self.use_mel: + S = torch.matmul(self.mel_basis.to(x), S) + if self.do_amp_to_db: + S = self._amp_to_db(S, spec_gain=self.spec_gain) + return S + + def _build_mel_basis(self): + mel_basis = librosa.filters.mel( + self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax + ) + self.mel_basis = torch.from_numpy(mel_basis).float() + + @staticmethod + def _amp_to_db(x, spec_gain=1.0): + return torch.log(torch.clamp(x, min=1e-5) * spec_gain) + + @staticmethod + def _db_to_amp(x, spec_gain=1.0): + return torch.exp(x) / spec_gain + + +# pylint: disable=too-many-public-methods +class AudioProcessor(object): + """Audio Processor for TTS used by all the data pipelines. + + Note: + All the class arguments are set to default values to enable a flexible initialization + of the class with the model config. They are not meaningful for all the arguments. + + Args: + sample_rate (int, optional): + target audio sampling rate. Defaults to None. + + resample (bool, optional): + enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False. + + num_mels (int, optional): + number of melspectrogram dimensions. Defaults to None. + + log_func (int, optional): + log exponent used for converting spectrogram aplitude to DB. + + min_level_db (int, optional): + minimum db threshold for the computed melspectrograms. Defaults to None. + + frame_shift_ms (int, optional): + milliseconds of frames between STFT columns. Defaults to None. + + frame_length_ms (int, optional): + milliseconds of STFT window length. Defaults to None. + + hop_length (int, optional): + number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None. + + win_length (int, optional): + STFT window length. Used if ```frame_length_ms``` is None. Defaults to None. + + ref_level_db (int, optional): + reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None. + + fft_size (int, optional): + FFT window size for STFT. Defaults to 1024. + + power (int, optional): + Exponent value applied to the spectrogram before GriffinLim. Defaults to None. + + preemphasis (float, optional): + Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0. + + signal_norm (bool, optional): + enable/disable signal normalization. Defaults to None. + + symmetric_norm (bool, optional): + enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None. + + max_norm (float, optional): + ```k``` defining the normalization range. Defaults to None. + + mel_fmin (int, optional): + minimum filter frequency for computing melspectrograms. Defaults to None. + + mel_fmax (int, optional): + maximum filter frequency for computing melspectrograms.. Defaults to None. + + spec_gain (int, optional): + gain applied when converting amplitude to DB. Defaults to 20. + + stft_pad_mode (str, optional): + Padding mode for STFT. Defaults to 'reflect'. + + clip_norm (bool, optional): + enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. + + griffin_lim_iters (int, optional): + Number of GriffinLim iterations. Defaults to None. + + do_trim_silence (bool, optional): + enable/disable silence trimming when loading the audio signal. Defaults to False. + + trim_db (int, optional): + DB threshold used for silence trimming. Defaults to 60. + + do_sound_norm (bool, optional): + enable/disable signal normalization. Defaults to False. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. + + do_amp_to_db_mel (bool, optional): + enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + + stats_path (str, optional): + Path to the computed stats file. Defaults to None. + + verbose (bool, optional): + enable/disable logging. Defaults to True. + + """ + + def __init__( + self, + sample_rate=None, + resample=False, + num_mels=None, + log_func="np.log10", + min_level_db=None, + frame_shift_ms=None, + frame_length_ms=None, + hop_length=None, + win_length=None, + ref_level_db=None, + fft_size=1024, + power=None, + preemphasis=0.0, + signal_norm=None, + symmetric_norm=None, + max_norm=None, + mel_fmin=None, + mel_fmax=None, + spec_gain=20, + stft_pad_mode="reflect", + clip_norm=True, + griffin_lim_iters=None, + do_trim_silence=False, + trim_db=60, + do_sound_norm=False, + do_amp_to_db_linear=True, + do_amp_to_db_mel=True, + stats_path=None, + verbose=True, + **_, + ): + + # setup class attributed + self.sample_rate = sample_rate + self.resample = resample + self.num_mels = num_mels + self.log_func = log_func + self.min_level_db = min_level_db or 0 + self.frame_shift_ms = frame_shift_ms + self.frame_length_ms = frame_length_ms + self.ref_level_db = ref_level_db + self.fft_size = fft_size + self.power = power + self.preemphasis = preemphasis + self.griffin_lim_iters = griffin_lim_iters + self.signal_norm = signal_norm + self.symmetric_norm = symmetric_norm + self.mel_fmin = mel_fmin or 0 + self.mel_fmax = mel_fmax + self.spec_gain = float(spec_gain) + self.stft_pad_mode = stft_pad_mode + self.max_norm = 1.0 if max_norm is None else float(max_norm) + self.clip_norm = clip_norm + self.do_trim_silence = do_trim_silence + self.trim_db = trim_db + self.do_sound_norm = do_sound_norm + self.do_amp_to_db_linear = do_amp_to_db_linear + self.do_amp_to_db_mel = do_amp_to_db_mel + self.stats_path = stats_path + # setup exp_func for db to amp conversion + if log_func == "np.log": + self.base = np.e + elif log_func == "np.log10": + self.base = 10 + else: + raise ValueError(" [!] unknown `log_func` value.") + # setup stft parameters + if hop_length is None: + # compute stft parameters from given time values + self.hop_length, self.win_length = self._stft_parameters() + else: + # use stft parameters from config file + self.hop_length = hop_length + self.win_length = win_length + assert min_level_db != 0.0, " [!] min_level_db is 0" + assert self.win_length <= self.fft_size, " [!] win_length cannot be larger than fft_size" + members = vars(self) + if verbose: + print(" > Setting up Audio Processor...") + for key, value in members.items(): + print(" | > {}:{}".format(key, value)) + # create spectrogram utils + self.mel_basis = self._build_mel_basis() + self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis()) + # setup scaler + if stats_path and signal_norm: + mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path) + self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std) + self.signal_norm = True + self.max_norm = None + self.clip_norm = None + self.symmetric_norm = None + + ### setting up the parameters ### + def _build_mel_basis( + self, + ) -> np.ndarray: + """Build melspectrogram basis. + + Returns: + np.ndarray: melspectrogram basis. + """ + if self.mel_fmax is not None: + assert self.mel_fmax <= self.sample_rate // 2 + return librosa.filters.mel( + self.sample_rate, self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax + ) + + def _stft_parameters( + self, + ) -> Tuple[int, int]: + """Compute the real STFT parameters from the time values. + + Returns: + Tuple[int, int]: hop length and window length for STFT. + """ + factor = self.frame_length_ms / self.frame_shift_ms + assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" + hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate) + win_length = int(hop_length * factor) + return hop_length, win_length + + ### normalization ### + def normalize(self, S: np.ndarray) -> np.ndarray: + """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]` + + Args: + S (np.ndarray): Spectrogram to normalize. + + Raises: + RuntimeError: Mean and variance is computed from incompatible parameters. + + Returns: + np.ndarray: Normalized spectrogram. + """ + # pylint: disable=no-else-return + S = S.copy() + if self.signal_norm: + # mean-var scaling + if hasattr(self, "mel_scaler"): + if S.shape[0] == self.num_mels: + return self.mel_scaler.transform(S.T).T + elif S.shape[0] == self.fft_size / 2: + return self.linear_scaler.transform(S.T).T + else: + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") + # range normalization + S -= self.ref_level_db # discard certain range of DB assuming it is air noise + S_norm = (S - self.min_level_db) / (-self.min_level_db) + if self.symmetric_norm: + S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm + if self.clip_norm: + S_norm = np.clip( + S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) + return S_norm + else: + S_norm = self.max_norm * S_norm + if self.clip_norm: + S_norm = np.clip(S_norm, 0, self.max_norm) + return S_norm + else: + return S + + def denormalize(self, S: np.ndarray) -> np.ndarray: + """Denormalize spectrogram values. + + Args: + S (np.ndarray): Spectrogram to denormalize. + + Raises: + RuntimeError: Mean and variance are incompatible. + + Returns: + np.ndarray: Denormalized spectrogram. + """ + # pylint: disable=no-else-return + S_denorm = S.copy() + if self.signal_norm: + # mean-var scaling + if hasattr(self, "mel_scaler"): + if S_denorm.shape[0] == self.num_mels: + return self.mel_scaler.inverse_transform(S_denorm.T).T + elif S_denorm.shape[0] == self.fft_size / 2: + return self.linear_scaler.inverse_transform(S_denorm.T).T + else: + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") + if self.symmetric_norm: + if self.clip_norm: + S_denorm = np.clip( + S_denorm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) + S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db + return S_denorm + self.ref_level_db + else: + if self.clip_norm: + S_denorm = np.clip(S_denorm, 0, self.max_norm) + S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db + return S_denorm + self.ref_level_db + else: + return S_denorm + + ### Mean-STD scaling ### + def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]: + """Loading mean and variance statistics from a `npy` file. + + Args: + stats_path (str): Path to the `npy` file containing + + Returns: + Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to + compute them. + """ + stats = np.load(stats_path, allow_pickle=True).item() # pylint: disable=unexpected-keyword-arg + mel_mean = stats["mel_mean"] + mel_std = stats["mel_std"] + linear_mean = stats["linear_mean"] + linear_std = stats["linear_std"] + stats_config = stats["audio_config"] + # check all audio parameters used for computing stats + skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"] + for key in stats_config.keys(): + if key in skip_parameters: + continue + if key not in ["sample_rate", "trim_db"]: + assert ( + stats_config[key] == self.__dict__[key] + ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" + return mel_mean, mel_std, linear_mean, linear_std, stats_config + + # pylint: disable=attribute-defined-outside-init + def setup_scaler( + self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray + ) -> None: + """Initialize scaler objects used in mean-std normalization. + + Args: + mel_mean (np.ndarray): Mean for melspectrograms. + mel_std (np.ndarray): STD for melspectrograms. + linear_mean (np.ndarray): Mean for full scale spectrograms. + linear_std (np.ndarray): STD for full scale spectrograms. + """ + self.mel_scaler = StandardScaler() + self.mel_scaler.set_stats(mel_mean, mel_std) + self.linear_scaler = StandardScaler() + self.linear_scaler.set_stats(linear_mean, linear_std) + + ### DB and AMP conversion ### + # pylint: disable=no-self-use + def _amp_to_db(self, x: np.ndarray) -> np.ndarray: + """Convert amplitude values to decibels. + + Args: + x (np.ndarray): Amplitude spectrogram. + + Returns: + np.ndarray: Decibels spectrogram. + """ + return self.spec_gain * _log(np.maximum(1e-5, x), self.base) + + # pylint: disable=no-self-use + def _db_to_amp(self, x: np.ndarray) -> np.ndarray: + """Convert decibels spectrogram to amplitude spectrogram. + + Args: + x (np.ndarray): Decibels spectrogram. + + Returns: + np.ndarray: Amplitude spectrogram. + """ + return _exp(x / self.spec_gain, self.base) + + ### Preemphasis ### + def apply_preemphasis(self, x: np.ndarray) -> np.ndarray: + """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. + + Args: + x (np.ndarray): Audio signal. + + Raises: + RuntimeError: Preemphasis coeff is set to 0. + + Returns: + np.ndarray: Decorrelated audio signal. + """ + if self.preemphasis == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1, -self.preemphasis], [1], x) + + def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray: + """Reverse pre-emphasis.""" + if self.preemphasis == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1], [1, -self.preemphasis], x) + + ### SPECTROGRAMs ### + def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray: + """Project a full scale spectrogram to a melspectrogram. + + Args: + spectrogram (np.ndarray): Full scale spectrogram. + + Returns: + np.ndarray: Melspectrogram + """ + return np.dot(self.mel_basis, spectrogram) + + def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray: + """Convert a melspectrogram to full scale spectrogram.""" + return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec)) + + def spectrogram(self, y: np.ndarray) -> np.ndarray: + """Compute a spectrogram from a waveform. + + Args: + y (np.ndarray): Waveform. + + Returns: + np.ndarray: Spectrogram. + """ + if self.preemphasis != 0: + D = self._stft(self.apply_preemphasis(y)) + else: + D = self._stft(y) + if self.do_amp_to_db_linear: + S = self._amp_to_db(np.abs(D)) + else: + S = np.abs(D) + return self.normalize(S).astype(np.float32) + + def melspectrogram(self, y: np.ndarray) -> np.ndarray: + """Compute a melspectrogram from a waveform.""" + if self.preemphasis != 0: + D = self._stft(self.apply_preemphasis(y)) + else: + D = self._stft(y) + if self.do_amp_to_db_mel: + S = self._amp_to_db(self._linear_to_mel(np.abs(D))) + else: + S = self._linear_to_mel(np.abs(D)) + return self.normalize(S).astype(np.float32) + + def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray: + """Convert a spectrogram to a waveform using Griffi-Lim vocoder.""" + S = self.denormalize(spectrogram) + S = self._db_to_amp(S) + # Reconstruct phase + if self.preemphasis != 0: + return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) + return self._griffin_lim(S ** self.power) + + def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray: + """Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" + D = self.denormalize(mel_spectrogram) + S = self._db_to_amp(D) + S = self._mel_to_linear(S) # Convert back to linear + if self.preemphasis != 0: + return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) + return self._griffin_lim(S ** self.power) + + def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: + """Convert a full scale linear spectrogram output of a network to a melspectrogram. + + Args: + linear_spec (np.ndarray): Normalized full scale linear spectrogram. + + Returns: + np.ndarray: Normalized melspectrogram. + """ + S = self.denormalize(linear_spec) + S = self._db_to_amp(S) + S = self._linear_to_mel(np.abs(S)) + S = self._amp_to_db(S) + mel = self.normalize(S) + return mel + + ### STFT and ISTFT ### + def _stft(self, y: np.ndarray) -> np.ndarray: + """Librosa STFT wrapper. + + Args: + y (np.ndarray): Audio signal. + + Returns: + np.ndarray: Complex number array. + """ + return librosa.stft( + y=y, + n_fft=self.fft_size, + hop_length=self.hop_length, + win_length=self.win_length, + pad_mode=self.stft_pad_mode, + window="hann", + center=True, + ) + + def _istft(self, y: np.ndarray) -> np.ndarray: + """Librosa iSTFT wrapper.""" + return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length) + + def _griffin_lim(self, S): + angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) + S_complex = np.abs(S).astype(np.complex) + y = self._istft(S_complex * angles) + if not np.isfinite(y).all(): + print(" [!] Waveform is not finite everywhere. Skipping the GL.") + return np.array([0.0]) + for _ in range(self.griffin_lim_iters): + angles = np.exp(1j * np.angle(self._stft(y))) + y = self._istft(S_complex * angles) + return y + + def compute_stft_paddings(self, x, pad_sides=1): + """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding + (first and final frames)""" + assert pad_sides in (1, 2) + pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0] + if pad_sides == 1: + return 0, pad + return pad // 2, pad // 2 + pad % 2 + + def compute_f0(self, x: np.ndarray) -> np.ndarray: + """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. + + Args: + x (np.ndarray): Waveform. + + Returns: + np.ndarray: Pitch. + + Examples: + >>> WAV_FILE = filename = librosa.util.example_audio_file() + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig(mel_fmax=8000) + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050] + >>> pitch = ap.compute_f0(wav) + """ + f0, t = pw.dio( + x.astype(np.double), + fs=self.sample_rate, + f0_ceil=self.mel_fmax, + frame_period=1000 * self.hop_length / self.sample_rate, + ) + f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate) + # pad = int((self.win_length / self.hop_length) / 2) + # f0 = [0.0] * pad + f0 + [0.0] * pad + # f0 = np.pad(f0, (pad, pad), mode="constant", constant_values=0) + # f0 = np.array(f0, dtype=np.float32) + + # f01, _, _ = librosa.pyin( + # x, + # fmin=65 if self.mel_fmin == 0 else self.mel_fmin, + # fmax=self.mel_fmax, + # frame_length=self.win_length, + # sr=self.sample_rate, + # fill_na=0.0, + # ) + + # spec = self.melspectrogram(x) + return f0 + + ### Audio Processing ### + def find_endpoint(self, wav: np.ndarray, threshold_db=-40, min_silence_sec=0.8) -> int: + """Find the last point without silence at the end of a audio signal. + + Args: + wav (np.ndarray): Audio signal. + threshold_db (int, optional): Silence threshold in decibels. Defaults to -40. + min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8. + + Returns: + int: Last point without silence. + """ + window_length = int(self.sample_rate * min_silence_sec) + hop_length = int(window_length / 4) + threshold = self._db_to_amp(threshold_db) + for x in range(hop_length, len(wav) - window_length, hop_length): + if np.max(wav[x : x + window_length]) < threshold: + return x + hop_length + return len(wav) + + def trim_silence(self, wav): + """Trim silent parts with a threshold and 0.01 sec margin""" + margin = int(self.sample_rate * 0.01) + wav = wav[margin:-margin] + return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[ + 0 + ] + + @staticmethod + def sound_norm(x: np.ndarray) -> np.ndarray: + """Normalize the volume of an audio signal. + + Args: + x (np.ndarray): Raw waveform. + + Returns: + np.ndarray: Volume normalized waveform. + """ + return x / abs(x).max() * 0.95 + + ### save and load ### + def load_wav(self, filename: str, sr: int = None) -> np.ndarray: + """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. + + Args: + filename (str): Path to the wav file. + sr (int, optional): Sampling rate for resampling. Defaults to None. + + Returns: + np.ndarray: Loaded waveform. + """ + if self.resample: + x, sr = librosa.load(filename, sr=self.sample_rate) + elif sr is None: + x, sr = sf.read(filename) + assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr) + else: + x, sr = librosa.load(filename, sr=sr) + if self.do_trim_silence: + try: + x = self.trim_silence(x) + except ValueError: + print(f" [!] File cannot be trimmed for silence - {filename}") + if self.do_sound_norm: + x = self.sound_norm(x) + return x + + def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None: + """Save a waveform to a file using Scipy. + + Args: + wav (np.ndarray): Waveform to save. + path (str): Path to a output file. + sr (int, optional): Sampling rate used for saving to the file. Defaults to None. + """ + wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) + scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm.astype(np.int16)) + + @staticmethod + def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray: + mu = 2 ** qc - 1 + # wav_abs = np.minimum(np.abs(wav), 1.0) + signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) + # Quantize signal to the specified number of levels. + signal = (signal + 1) / 2 * mu + 0.5 + return np.floor( + signal, + ) + + @staticmethod + def mulaw_decode(wav, qc): + """Recovers waveform from quantized values.""" + mu = 2 ** qc - 1 + x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) + return x + + @staticmethod + def encode_16bits(x): + return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16) + + @staticmethod + def quantize(x: np.ndarray, bits: int) -> np.ndarray: + """Quantize a waveform to a given number of bits. + + Args: + x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`. + bits (int): Number of quantization bits. + + Returns: + np.ndarray: Quantized waveform. + """ + return (x + 1.0) * (2 ** bits - 1) / 2 + + @staticmethod + def dequantize(x, bits): + """Dequantize a waveform from the given number of bits.""" + return 2 * x / (2 ** bits - 1) - 1 + + +def _log(x, base): + if base == 10: + return np.log10(x) + return np.log(x) + + +def _exp(x, base): + if base == 10: + return np.power(10, x) + return np.exp(x) diff --git a/utils/io.py b/utils/io.py new file mode 100644 index 0000000..e4a068c --- /dev/null +++ b/utils/io.py @@ -0,0 +1,198 @@ +import datetime +import json +import os +import pickle as pickle_tts +import shutil +from typing import Any, Callable, Dict, Union + +import fsspec +import torch +from coqpit import Coqpit + + +class RenamingUnpickler(pickle_tts.Unpickler): + """Overload default pickler to solve module renaming problem""" + + def find_class(self, module, name): + return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name) + + +class AttrDict(dict): + """A custom dict which converts dict keys + to class attributes""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +def copy_model_files(config: Coqpit, out_path, new_fields): + """Copy config.json and other model files to training folder and add + new fields. + + Args: + config (Coqpit): Coqpit config defining the training run. + out_path (str): output path to copy the file. + new_fields (dict): new fileds to be added or edited + in the config file. + """ + copy_config_path = os.path.join(out_path, "config.json") + # add extra information fields + config.update(new_fields, allow_new=True) + # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths. + with fsspec.open(copy_config_path, "w", encoding="utf8") as f: + json.dump(config.to_dict(), f, indent=4) + + # copy model stats file if available + if config.audio.stats_path is not None: + copy_stats_path = os.path.join(out_path, "scale_stats.npy") + filesystem = fsspec.get_mapper(copy_stats_path).fs + if not filesystem.exists(copy_stats_path): + with fsspec.open(config.audio.stats_path, "rb") as source_file: + with fsspec.open(copy_stats_path, "wb") as target_file: + shutil.copyfileobj(source_file, target_file) + + +def load_fsspec( + path: str, + map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, + **kwargs, +) -> Any: + """Like torch.load but can load from other locations (e.g. s3:// , gs://). + + Args: + path: Any path or url supported by fsspec. + map_location: torch.device or str. + **kwargs: Keyword arguments forwarded to torch.load. + + Returns: + Object stored in path. + """ + with fsspec.open(path, "rb") as f: + return torch.load(f, map_location=map_location, **kwargs) + + +def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin + try: + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + except ModuleNotFoundError: + pickle_tts.Unpickler = RenamingUnpickler + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) + model.load_state_dict(state["model"]) + if use_cuda: + model.cuda() + if eval: + model.eval() + return model, state + + +def save_fsspec(state: Any, path: str, **kwargs): + """Like torch.save but can save to other locations (e.g. s3:// , gs://). + + Args: + state: State object to save + path: Any path or url supported by fsspec. + **kwargs: Keyword arguments forwarded to torch.save. + """ + with fsspec.open(path, "wb") as f: + torch.save(state, f, **kwargs) + + +def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs): + if hasattr(model, "module"): + model_state = model.module.state_dict() + else: + model_state = model.state_dict() + if isinstance(optimizer, list): + optimizer_state = [optim.state_dict() for optim in optimizer] + else: + optimizer_state = optimizer.state_dict() if optimizer is not None else None + + if isinstance(scaler, list): + scaler_state = [s.state_dict() for s in scaler] + else: + scaler_state = scaler.state_dict() if scaler is not None else None + + if isinstance(config, Coqpit): + config = config.to_dict() + + state = { + "config": config, + "model": model_state, + "optimizer": optimizer_state, + "scaler": scaler_state, + "step": current_step, + "epoch": epoch, + "date": datetime.date.today().strftime("%B %d, %Y"), + } + state.update(kwargs) + save_fsspec(state, output_path) + + +def save_checkpoint( + config, + model, + optimizer, + scaler, + current_step, + epoch, + output_folder, + **kwargs, +): + file_name = "checkpoint_{}.pth.tar".format(current_step) + checkpoint_path = os.path.join(output_folder, file_name) + print("\n > CHECKPOINT : {}".format(checkpoint_path)) + save_model( + config, + model, + optimizer, + scaler, + current_step, + epoch, + checkpoint_path, + **kwargs, + ) + + +def save_best_model( + current_loss, + best_loss, + config, + model, + optimizer, + scaler, + current_step, + epoch, + out_path, + keep_all_best=False, + keep_after=10000, + **kwargs, +): + if current_loss < best_loss: + best_model_name = f"best_model_{current_step}.pth.tar" + checkpoint_path = os.path.join(out_path, best_model_name) + print(" > BEST MODEL : {}".format(checkpoint_path)) + save_model( + config, + model, + optimizer, + scaler, + current_step, + epoch, + checkpoint_path, + model_loss=current_loss, + **kwargs, + ) + fs = fsspec.get_mapper(out_path).fs + # only delete previous if current is saved successfully + if not keep_all_best or (current_step < keep_after): + model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar")) + for model_name in model_names: + if os.path.basename(model_name) != best_model_name: + fs.rm(model_name) + # create a shortcut which always points to the currently best model + shortcut_name = "best_model.pth.tar" + shortcut_path = os.path.join(out_path, shortcut_name) + fs.copy(checkpoint_path, shortcut_path) + best_loss = current_loss + return best_loss diff --git a/vi_speaker_batch.py b/vi_speaker_batch.py new file mode 100644 index 0000000..503c739 --- /dev/null +++ b/vi_speaker_batch.py @@ -0,0 +1,88 @@ +import os +import re +import json +import fsspec +import torch +import numpy as np +import argparse + +from tqdm import tqdm +from argparse import RawTextHelpFormatter +from speaker_encoder.models.lstm import LSTMSpeakerEncoder +from speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig + +from utils.audio import AudioProcessor +from vi_speaker_single import read_json + + +def get_spk_wavs(dataset_path, output_path): + wav_files = [] + os.makedirs(f"./{output_path}") + for spks in os.listdir(dataset_path): + if os.path.isdir(f"./{dataset_path}/{spks}"): + os.makedirs(f"./{output_path}/{spks}") + for file in os.listdir(f"./{dataset_path}/{spks}"): + if file.endswith(".wav"): + wav_files.append(f"./{dataset_path}/{spks}/{file}") + elif spks.endswith(".wav"): + wav_files.append(f"./{dataset_path}/{spks}") + return wav_files + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="""Compute embedding vectors for each wav file in a dataset.""", + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("model_path", type=str, help="Path to model checkpoint file.") + parser.add_argument("config_path", type=str, help="Path to model config file.") + parser.add_argument("dataset_path", type=str, help="Path to dataset waves.") + parser.add_argument( + "output_path", type=str, help="path for output speaker/speaker_wavs.npy." + ) + parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) + parser.add_argument("--eval", type=bool, help="compute eval.", default=True) + args = parser.parse_args() + dataset_path = args.dataset_path + output_path = args.output_path + + # config + config_dict = read_json(args.config_path) + + # model + config = SpeakerEncoderConfig(config_dict) + config.from_dict(config_dict) + + speaker_encoder = LSTMSpeakerEncoder( + config.model_params["input_dim"], + config.model_params["proj_dim"], + config.model_params["lstm_dim"], + config.model_params["num_lstm_layers"], + ) + + speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda) + + # preprocess + speaker_encoder_ap = AudioProcessor(**config.audio) + # normalize the input audio level and trim silences + speaker_encoder_ap.do_sound_norm = True + speaker_encoder_ap.do_trim_silence = True + + wav_files = get_spk_wavs(dataset_path, output_path) + + # compute speaker embeddings + for idx, wav_file in enumerate(tqdm(wav_files)): + waveform = speaker_encoder_ap.load_wav( + wav_file, sr=speaker_encoder_ap.sample_rate + ) + spec = speaker_encoder_ap.melspectrogram(waveform) + spec = torch.from_numpy(spec.T) + if args.use_cuda: + spec = spec.cuda() + spec = spec.unsqueeze(0) + embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy() + embed = embed.squeeze() + embed_path = wav_file.replace(dataset_path, output_path) + embed_path = embed_path.replace(".wav", ".npy") + np.save(embed_path, embed, allow_pickle=False) diff --git a/vi_speaker_center.py b/vi_speaker_center.py new file mode 100644 index 0000000..62a1e44 --- /dev/null +++ b/vi_speaker_center.py @@ -0,0 +1,21 @@ +import os +import numpy as np + +single_id_path = "speaker_embedding" +center_id_path = "speaker_embedding_center" + +os.makedirs(f"./{center_id_path}") + +for speaker in os.listdir(single_id_path): + if os.path.isdir(f"./{single_id_path}/{speaker}"): + print(f"---->{speaker}<----") + subfile_num = 0 + speaker_cen = 0 + for file in os.listdir(f"./{single_id_path}/{speaker}"): + if file.endswith(".npy"): + source_embed = np.load(f"./{single_id_path}/{speaker}/{file}") + source_embed = source_embed.astype(np.float32) + speaker_cen = speaker_cen + source_embed + subfile_num = subfile_num + 1 + speaker_cen = speaker_cen / subfile_num + np.save(f"./{center_id_path}/{speaker}.npy", speaker_cen, allow_pickle=False) diff --git a/vi_speaker_single.py b/vi_speaker_single.py new file mode 100644 index 0000000..7260f53 --- /dev/null +++ b/vi_speaker_single.py @@ -0,0 +1,109 @@ +import re +import json +import fsspec +import torch +import numpy as np +import argparse + +from argparse import RawTextHelpFormatter +from speaker_encoder.models.lstm import LSTMSpeakerEncoder +from speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig + +from utils.audio import AudioProcessor + + +def read_json(json_path): + config_dict = {} + try: + with fsspec.open(json_path, "r", encoding="utf-8") as f: + data = json.load(f) + except json.decoder.JSONDecodeError: + # backwards compat. + data = read_json_with_comments(json_path) + config_dict.update(data) + return config_dict + + +def read_json_with_comments(json_path): + """for backward compat.""" + # fallback to json + with fsspec.open(json_path, "r", encoding="utf-8") as f: + input_str = f.read() + # handle comments + input_str = re.sub(r"\\\n", "", input_str) + input_str = re.sub(r"//.*\n", "\n", input_str) + data = json.loads(input_str) + return data + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="""Compute embedding vectors for each wav file in a dataset.""", + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("model_path", type=str, help="Path to model checkpoint file.") + parser.add_argument( + "config_path", + type=str, + help="Path to model config file.", + ) + + parser.add_argument("-s", "--source", help="input wave", dest="source") + parser.add_argument( + "-t", "--target", help="output 256d speaker embeddimg", dest="target" + ) + + parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) + parser.add_argument("--eval", type=bool, help="compute eval.", default=True) + + args = parser.parse_args() + source_file = args.source + target_file = args.target + + # config + config_dict = read_json(args.config_path) + # print(config_dict) + + # model + config = SpeakerEncoderConfig(config_dict) + config.from_dict(config_dict) + + speaker_encoder = LSTMSpeakerEncoder( + config.model_params["input_dim"], + config.model_params["proj_dim"], + config.model_params["lstm_dim"], + config.model_params["num_lstm_layers"], + ) + + speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda) + + # preprocess + speaker_encoder_ap = AudioProcessor(**config.audio) + # normalize the input audio level and trim silences + speaker_encoder_ap.do_sound_norm = True + speaker_encoder_ap.do_trim_silence = True + + # compute speaker embeddings + + # extract the embedding + waveform = speaker_encoder_ap.load_wav( + source_file, sr=speaker_encoder_ap.sample_rate + ) + spec = speaker_encoder_ap.melspectrogram(waveform) + spec = torch.from_numpy(spec.T) + if args.use_cuda: + spec = spec.cuda() + spec = spec.unsqueeze(0) + embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy() + embed = embed.squeeze() + # print(embed) + # print(embed.size) + np.save(target_file, embed, allow_pickle=False) + + + if hasattr(speaker_encoder, 'module'): + state_dict = speaker_encoder.module.state_dict() + else: + state_dict = speaker_encoder.state_dict() + torch.save({'model': state_dict}, "model_small.pth")