Commit
fix: parallel inference
fengyizhu committed Sep 10, 2024
1 parent 1e39fdc commit 6feb586
Showing 7 changed files with 234 additions and 101 deletions.
68 changes: 49 additions & 19 deletions ChatTTS/core.py
@@ -174,17 +174,15 @@ class RefineTextParams:
min_new_token: int = 0
show_tqdm: bool = True
ensure_non_empty: bool = True
manual_seed: Optional[int] = 0
manual_seed: Optional[int] = None

@dataclass(repr=False, eq=False)
class InferCodeParams(RefineTextParams):
prompt: str = "[speed_5]"
spk_emb: Optional[str] = None
spk_smp: Optional[str] = None
txt_smp: Optional[str] = None
top_P: float = 1
top_K: int = 1
temperature: float = 0.01
temperature: float = 0.3
repetition_penalty: float = 1.05
max_new_token: int = 2048
stream_batch: int = 24
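
The defaults above also shift sampling behaviour: manual_seed now falls back to None (previously 0) and temperature to 0.3 (previously 0.01), so inference is no longer deterministic out of the box. A minimal sketch of pinning the seed again, assuming the README-style `ChatTTS.Chat` / `load()` / `ChatTTS.Chat.InferCodeParams` usage (those names are not part of this diff):

```python
import ChatTTS

chat = ChatTTS.Chat()
chat.load()  # load model weights; arguments omitted here

# With this commit, manual_seed defaults to None (was 0) and temperature to
# 0.3 (was 0.01). Pinning the seed restores reproducible runs.
params = ChatTTS.Chat.InferCodeParams(manual_seed=42, temperature=0.01)
wavs = chat.infer(["Deterministic output, please."], params_infer_code=params)
```
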
@@ -196,13 +194,13 @@ def infer(
text,
stream=False,
lang=None,
skip_refine_text=True,
skip_refine_text=False,
refine_text_only=False,
use_decoder=True,
do_text_normalization=True,
do_homophone_replacement=True,
params_refine_text=None,
params_infer_code=None,
params_refine_text=RefineTextParams(),
params_infer_code=InferCodeParams(),
stream_batch_size=16,
):
self.context.set(False)
@@ -273,7 +271,7 @@ def _load(
vq_config=asdict(self.config.dvae.vq),
dim=self.config.dvae.decoder.idim,
coef=coef,
device=device,
device=self.device,
)
.to(device)
.eval()
@@ -290,8 +288,8 @@ def _load(
self.config.embed.num_text_tokens,
self.config.embed.num_vq,
)
embed.from_pretrained(embed_path, device=device)
self.embed = embed.to(device)
embed.from_pretrained(embed_path, device=self.device)
self.embed = embed.to(self.device)
self.logger.log(logging.INFO, "embed loaded.")

gpt = GPT(
@@ -343,15 +341,15 @@ def _load(
async def _infer(
self,
text,
stream=True,
stream=False,
lang=None,
skip_refine_text=True,
skip_refine_text=False,
refine_text_only=False,
use_decoder=True,
do_text_normalization=True,
do_homophone_replacement=True,
params_refine_text=None,
params_infer_code=None,
params_refine_text=RefineTextParams(),
params_infer_code=InferCodeParams(),
stream_batch_size=16,
):

@@ -399,13 +397,11 @@ async def _infer(
result.hiddens if use_decoder else result.ids,
use_decoder,
)

if result.finished:
yield wavs[:, length:]
else:
# Hack: check whether there are any silent segments; if so, cut at the last one, otherwise wait for another loop.
import librosa

silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
silence_left = 0
if len(silence_intervals) == 0:
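
This hunk is truncated by the diff view. For context, a minimal sketch of the librosa call it relies on: librosa.effects.split() returns [start, end] sample indices of non-silent runs, so an empty result means the analysed tail is entirely below the top_db threshold. The names mirror the visible lines; the continuation is an assumption, not the committed logic.

```python
from typing import Optional

import librosa
import numpy as np

def find_cut_point(segment: np.ndarray, top_db: int = 10) -> Optional[int]:
    """Return a sample index lying on a silent boundary, or None to keep waiting.

    librosa.effects.split() yields [start, end] indices of *non-silent* runs;
    an empty result means the whole segment is quieter than top_db.
    """
    intervals = librosa.effects.split(segment, top_db=top_db)
    if len(intervals) == 0:
        return None                    # all silence so far: wait another loop
    return int(intervals[-1][0])       # start of the last non-silent run
```
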
@@ -504,8 +500,8 @@ async def _infer_code(
repetition_penalty=params.repetition_penalty,
)

speaker_embedding_param = gpt(input_ids, text_mask)

speaker_embedding_param = self.embed(input_ids, text_mask)
del text_mask
if params.spk_emb is not None:
self.speaker.apply(
speaker_embedding_param,
@@ -536,7 +532,7 @@ async def _infer_code(
async for i in results_generator:
token_ids = []
hidden_states = []
if len(i.outputs[0].token_ids) % stream_batch_size == 0 or i.finished:
if (stream and len(i.outputs[0].token_ids) % stream_batch_size == 0) or i.finished:
token_ids.append(torch.tensor(i.outputs[0].token_ids))
hidden_states.append(
i.outputs[0].hidden_states.to(torch.float32).to(self.device)
@@ -547,6 +543,40 @@ async def _infer_code(
hiddens=hidden_states,
attentions=[],
)
else:
results_generator = gpt.generate(
speaker_embedding_param,
input_ids,
temperature=torch.tensor(temperature, device=device),
eos_token=num_code,
attention_mask=attention_mask,
max_new_token=params.max_new_token,
min_new_token=params.min_new_token,
logits_processors=(*logits_processors, *logits_warpers),
infer_text=False,
return_hidden=return_hidden,
stream=stream,
show_tqdm=params.show_tqdm,
ensure_non_empty=params.ensure_non_empty,
stream_batch=params.stream_batch,
manual_seed=params.manual_seed,
context=self.context,
)
del speaker_embedding_param, input_ids
async for i in results_generator:
token_ids = []
hidden_states = []
if (stream and len(i.ids[0]) % stream_batch_size == 0) or i.finished:
token_ids.append(i.ids[0])
hidden_states.append(
i.hiddens[0].to(torch.float32).to(self.device)
)
yield GPT.GenerationOutputs(
ids=token_ids,
finished=i.finished,
hiddens=hidden_states,
attentions=[],
)

@torch.no_grad()
def _refine_text(
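Both generation branches added in _infer_code now gate their yields the same way: nothing is emitted mid-run unless streaming is enabled and a full batch of tokens has accumulated, while the final output is always emitted. A minimal sketch of that predicate (stream_batch_size defaults to 16 in the signatures above):

```python
def should_yield(num_tokens: int, finished: bool, stream: bool, stream_batch_size: int = 16) -> bool:
    # Chunks are emitted only while streaming, every stream_batch_size tokens;
    # otherwise nothing is yielded until generation finishes.
    return (stream and num_tokens % stream_batch_size == 0) or finished

assert should_yield(24, finished=False, stream=False) is False   # non-streaming: wait
assert should_yield(37, finished=True, stream=False) is True     # final result always emitted
assert should_yield(32, finished=False, stream=True) is True     # streaming: every 16 tokens
```
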
55 changes: 11 additions & 44 deletions ChatTTS/model/gpt.py
@@ -46,15 +46,16 @@ def __init__(
self.is_te_llama = False
self.is_vllm = use_vllm

self.emb_code = [ec.__call__ for ec in embed.emb_code]
self.emb_text = embed.emb_text.__call__
self.head_text = embed.head_text.__call__
self.head_code = [hc.__call__ for hc in embed.head_code]
if self.is_vllm:
return

self.llama_config = self._build_llama_config(gpt_config)

self.emb_code = [ec.__call__ for ec in embed.emb_code]
self.emb_text = embed.emb_text.__call__
self.head_text = embed.head_text.__call__
self.head_code = [hc.__call__ for hc in embed.head_code]

def from_pretrained(
self, gpt_folder: str, embed_file_path: str, experimental=False
):
@@ -67,7 +68,7 @@ def from_pretrained(
num_audio_tokens=self.num_audio_tokens,
num_text_tokens=self.num_text_tokens,
post_model_path=embed_file_path,
dtype="float32",
dtype="float32"
)
self.logger.info("vLLM model loaded")
return
@@ -138,44 +139,6 @@ def prepare(self, compile=False):
except RuntimeError as e:
self.logger.warning(f"compile failed: {e}. fallback to normal mode.")

def __call__(
self, input_ids: torch.Tensor, text_mask: torch.Tensor
) -> torch.Tensor:
"""
get_emb
"""
return super().__call__(input_ids, text_mask)

def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
"""
get_emb
"""
input_ids = input_ids.clone()
text_mask = text_mask.clone()
emb_text: torch.Tensor = self.emb_text(
input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(self.device_gpt)
)

text_mask_inv = text_mask.logical_not().to(self.device_gpt)
masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(self.device_gpt)

emb_code = [
self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
]
emb_code = torch.stack(emb_code, 2).sum(2)

emb = torch.zeros(
(input_ids.shape[:-1]) + (emb_text.shape[-1],),
device=emb_text.device,
dtype=emb_text.dtype,
)
emb[text_mask] = emb_text
emb[text_mask_inv] = emb_code.to(emb.dtype)

del emb_text, emb_code, text_mask_inv

return emb

@dataclass(repr=False, eq=False)
class _GenerationInputs:
position_ids: torch.Tensor
@@ -327,6 +290,7 @@ def _prepare_generation_outputs(
attentions: List[Optional[Tuple[torch.FloatTensor, ...]]],
hiddens: List[torch.Tensor],
infer_text: bool,
finished: bool,
) -> GenerationOutputs:
inputs_ids = [
inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx)
@@ -344,10 +308,11 @@ def _prepare_generation_outputs(
ids=inputs_ids,
attentions=attentions,
hiddens=hiddens,
finished=finished,
)

@torch.no_grad()
def generate(
async def generate(
self,
emb: torch.Tensor,
inputs_ids: torch.Tensor,
@@ -620,6 +585,7 @@ def generate(
attentions,
hiddens,
infer_text,
False
)
del not_finished

@@ -649,4 +615,5 @@ def generate(
attentions,
hiddens,
infer_text,
True
)
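
generate() is now an async generator and GenerationOutputs carries the new finished flag, so callers consume it with async for, as _infer_code does above. A hedged consumption sketch, with the model-specific arguments elided behind **generate_kwargs (their values are assumptions, not part of this diff):

```python
import torch

async def collect_final(gpt, emb: torch.Tensor, inputs_ids: torch.Tensor, **generate_kwargs):
    """Drain GPT.generate() and keep only the terminal output."""
    final = None
    async for out in gpt.generate(emb, inputs_ids, **generate_kwargs):
        # Intermediate yields occur while streaming; the `finished` flag added
        # by this commit marks the last, complete GenerationOutputs.
        if out.finished:
            final = out
    return final
```
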
1 change: 0 additions & 1 deletion ChatTTS/model/velocity/llm.py
@@ -6,7 +6,6 @@

from .async_llm_engine import AsyncLLMEngine
from .configs import EngineArgs
from .llm_engine import LLMEngine
from .output import RequestOutput
from .sampling_params import SamplingParams

1 change: 0 additions & 1 deletion ChatTTS/model/velocity/llm_engine.py
@@ -22,7 +22,6 @@
)
from vllm.transformers_utils.tokenizer import detokenize_incrementally, get_tokenizer
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port
import numpy as np
import torch

if ray:
(Diffs for the remaining changed files are not shown.)
