From 6feb58695f73fec9dff74548c5cb597283f20ebb Mon Sep 17 00:00:00 2001
From: zhufengyi
Date: Tue, 10 Sep 2024 17:11:34 +0800
Subject: [PATCH] fix: parallel inference

---
 ChatTTS/core.py                        | 68 +++++++++++++-----
 ChatTTS/model/gpt.py                   | 55 +++------------
 ChatTTS/model/velocity/llm.py          |  1 -
 ChatTTS/model/velocity/llm_engine.py   |  1 -
 ChatTTS/model/velocity/model_runner.py | 84 +++++++++++++++-------
 examples/api/main.py                   | 29 +++++---
 examples/api/openai.py                 | 97 ++++++++++++++++++++++++++
 7 files changed, 234 insertions(+), 101 deletions(-)
 create mode 100644 examples/api/openai.py

diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index 3c18a1604..ceabd70fa 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -174,7 +174,7 @@ class RefineTextParams:
         min_new_token: int = 0
         show_tqdm: bool = True
         ensure_non_empty: bool = True
-        manual_seed: Optional[int] = 0
+        manual_seed: Optional[int] = None
 
     @dataclass(repr=False, eq=False)
     class InferCodeParams(RefineTextParams):
@@ -182,9 +182,7 @@ class InferCodeParams(RefineTextParams):
         prompt: str = "[speed_5]"
         spk_emb: Optional[str] = None
         spk_smp: Optional[str] = None
         txt_smp: Optional[str] = None
-        top_P: float = 1
-        top_K: int = 1
-        temperature: float = 0.01
+        temperature: float = 0.3
         repetition_penalty: float = 1.05
         max_new_token: int = 2048
         stream_batch: int = 24
@@ -196,13 +194,13 @@ def infer(
         text,
         stream=False,
         lang=None,
-        skip_refine_text=True,
+        skip_refine_text=False,
         refine_text_only=False,
         use_decoder=True,
         do_text_normalization=True,
         do_homophone_replacement=True,
-        params_refine_text=None,
-        params_infer_code=None,
+        params_refine_text=RefineTextParams(),
+        params_infer_code=InferCodeParams(),
         stream_batch_size=16,
     ):
         self.context.set(False)
@@ -273,7 +271,7 @@ def _load(
                 vq_config=asdict(self.config.dvae.vq),
                 dim=self.config.dvae.decoder.idim,
                 coef=coef,
-                device=device,
+                device=self.device,
             )
             .to(device)
             .eval()
@@ -290,8 +288,8 @@ def _load(
             self.config.embed.num_text_tokens,
             self.config.embed.num_vq,
         )
-        embed.from_pretrained(embed_path, device=device)
-        self.embed = embed.to(device)
+        embed.from_pretrained(embed_path, device=self.device)
+        self.embed = embed.to(self.device)
         self.logger.log(logging.INFO, "embed loaded.")
 
         gpt = GPT(
@@ -343,15 +341,15 @@ async def _infer(
         self,
         text,
-        stream=True,
+        stream=False,
         lang=None,
-        skip_refine_text=True,
+        skip_refine_text=False,
         refine_text_only=False,
         use_decoder=True,
         do_text_normalization=True,
         do_homophone_replacement=True,
-        params_refine_text=None,
-        params_infer_code=None,
+        params_refine_text=RefineTextParams(),
+        params_infer_code=InferCodeParams(),
         stream_batch_size=16,
     ):
 
@@ -399,13 +397,11 @@ async def _infer(
                 result.hiddens if use_decoder else result.ids,
                 use_decoder,
             )
-
             if result.finished:
                 yield wavs[:, length:]
             else:
                 # Hacker:Check if there are any silent segments; if so, take the last segment. Otherwise, try waiting for another loop.
                 import librosa
-
                 silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
                 silence_left = 0
                 if len(silence_intervals) == 0:
                     silence_left = len(wavs[0])
                 else:
                     silence_left = silence_intervals[0][0]
                 if silence_left <= 0:
                     continue
@@ -504,8 +500,8 @@ async def _infer_code(
             repetition_penalty=params.repetition_penalty,
         )
 
-        speaker_embedding_param = gpt(input_ids, text_mask)
-
+        speaker_embedding_param = self.embed(input_ids, text_mask)
+        del text_mask
        if params.spk_emb is not None:
             self.speaker.apply(
                 speaker_embedding_param,
@@ -536,7 +532,7 @@ async def _infer_code(
             async for i in results_generator:
                 token_ids = []
                 hidden_states = []
-                if len(i.outputs[0].token_ids) % stream_batch_size == 0 or i.finished:
+                if (stream and len(i.outputs[0].token_ids) % stream_batch_size == 0) or i.finished:
                     token_ids.append(torch.tensor(i.outputs[0].token_ids))
                     hidden_states.append(
                         i.outputs[0].hidden_states.to(torch.float32).to(self.device)
@@ -547,6 +543,40 @@ async def _infer_code(
                         hiddens=hidden_states,
                         attentions=[],
                     )
+        else:
+            results_generator = gpt.generate(
+                speaker_embedding_param,
+                input_ids,
+                temperature=torch.tensor(temperature, device=device),
+                eos_token=num_code,
+                attention_mask=attention_mask,
+                max_new_token=params.max_new_token,
+                min_new_token=params.min_new_token,
+                logits_processors=(*logits_processors, *logits_warpers),
+                infer_text=False,
+                return_hidden=return_hidden,
+                stream=stream,
+                show_tqdm=params.show_tqdm,
+                ensure_non_empty=params.ensure_non_empty,
+                stream_batch=params.stream_batch,
+                manual_seed=params.manual_seed,
+                context=self.context,
+            )
+            del speaker_embedding_param, input_ids
+            async for i in results_generator:
+                token_ids = []
+                hidden_states = []
+                if (stream and len(i.ids[0]) % stream_batch_size == 0) or i.finished:
+                    token_ids.append(i.ids[0])
+                    hidden_states.append(
+                        i.hiddens[0].to(torch.float32).to(self.device)
+                    )
+                    yield GPT.GenerationOutputs(
+                        ids=token_ids,
+                        finished=i.finished,
+                        hiddens=hidden_states,
+                        attentions=[],
+                    )
 
     @torch.no_grad()
     def _refine_text(
diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index f14b4f591..4abeb241b 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -46,15 +46,16 @@ def __init__(
         self.is_te_llama = False
         self.is_vllm = use_vllm
 
-        self.emb_code = [ec.__call__ for ec in embed.emb_code]
-        self.emb_text = embed.emb_text.__call__
-        self.head_text = embed.head_text.__call__
-        self.head_code = [hc.__call__ for hc in embed.head_code]
         if self.is_vllm:
             return
 
         self.llama_config = self._build_llama_config(gpt_config)
 
+        self.emb_code = [ec.__call__ for ec in embed.emb_code]
+        self.emb_text = embed.emb_text.__call__
+        self.head_text = embed.head_text.__call__
+        self.head_code = [hc.__call__ for hc in embed.head_code]
+
     def from_pretrained(
         self, gpt_folder: str, embed_file_path: str, experimental=False
     ):
@@ -67,7 +68,7 @@ def from_pretrained(
                 num_audio_tokens=self.num_audio_tokens,
                 num_text_tokens=self.num_text_tokens,
                 post_model_path=embed_file_path,
-                dtype="float32",
+                dtype="float32"
             )
             self.logger.info("vLLM model loaded")
             return
@@ -138,44 +139,6 @@ def prepare(self, compile=False):
             except RuntimeError as e:
                 self.logger.warning(f"compile failed: {e}. fallback to normal mode.")
 
-    def __call__(
-        self, input_ids: torch.Tensor, text_mask: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        get_emb
-        """
-        return super().__call__(input_ids, text_mask)
-
-    def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
-        """
-        get_emb
-        """
-        input_ids = input_ids.clone()
-        text_mask = text_mask.clone()
-        emb_text: torch.Tensor = self.emb_text(
-            input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(self.device_gpt)
-        )
-
-        text_mask_inv = text_mask.logical_not().to(self.device_gpt)
-        masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(self.device_gpt)
-
-        emb_code = [
-            self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
-        ]
-        emb_code = torch.stack(emb_code, 2).sum(2)
-
-        emb = torch.zeros(
-            (input_ids.shape[:-1]) + (emb_text.shape[-1],),
-            device=emb_text.device,
-            dtype=emb_text.dtype,
-        )
-        emb[text_mask] = emb_text
-        emb[text_mask_inv] = emb_code.to(emb.dtype)
-
-        del emb_text, emb_code, text_mask_inv
-
-        return emb
-
     @dataclass(repr=False, eq=False)
     class _GenerationInputs:
         position_ids: torch.Tensor
@@ -327,6 +290,7 @@ def _prepare_generation_outputs(
         attentions: List[Optional[Tuple[torch.FloatTensor, ...]]],
         hiddens: List[torch.Tensor],
         infer_text: bool,
+        finished: bool,
    ) -> GenerationOutputs:
         inputs_ids = [
             inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx)
@@ -344,10 +308,11 @@ def _prepare_generation_outputs(
             ids=inputs_ids,
             attentions=attentions,
             hiddens=hiddens,
+            finished=finished,
         )
 
     @torch.no_grad()
-    def generate(
+    async def generate(
         self,
         emb: torch.Tensor,
         inputs_ids: torch.Tensor,
@@ -620,6 +585,7 @@ def generate(
                     attentions,
                     hiddens,
                     infer_text,
+                    False
                 )
                 del not_finished
 
@@ -649,4 +615,5 @@ def generate(
             attentions,
             hiddens,
             infer_text,
+            True
         )
diff --git a/ChatTTS/model/velocity/llm.py b/ChatTTS/model/velocity/llm.py
index b9b351be9..81bb65fbf 100644
--- a/ChatTTS/model/velocity/llm.py
+++ b/ChatTTS/model/velocity/llm.py
@@ -6,7 +6,6 @@
 
 from .async_llm_engine import AsyncLLMEngine
 from .configs import EngineArgs
-from .llm_engine import LLMEngine
 from .output import RequestOutput
 from .sampling_params import SamplingParams
 
diff --git a/ChatTTS/model/velocity/llm_engine.py b/ChatTTS/model/velocity/llm_engine.py
index 7a3f2b40f..d878c7795 100644
--- a/ChatTTS/model/velocity/llm_engine.py
+++ b/ChatTTS/model/velocity/llm_engine.py
@@ -22,7 +22,6 @@
 )
 from vllm.transformers_utils.tokenizer import detokenize_incrementally, get_tokenizer
 from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port
-import numpy as np
 import torch
 
 if ray:
diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py
index 2e45dda65..afa438a5f 100644
--- a/ChatTTS/model/velocity/model_runner.py
+++ b/ChatTTS/model/velocity/model_runner.py
@@ -4,6 +4,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+from torch import Tensor
 
 from .configs import ModelConfig, ParallelConfig, SchedulerConfig
 from vllm.logger import init_logger
@@ -22,6 +23,7 @@
     SequenceOutput,
 )
 from vllm.utils import in_wsl
+
 from ..embed import Embed
 from .sampler import Sampler
 from safetensors.torch import safe_open
@@ -105,11 +107,12 @@ def set_block_size(self, block_size: int) -> None:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]:
+    ) -> tuple[list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
+        embedding: List[torch.Tensor] = []
         prompt_lens: List[int] = []
 
         for seq_group_metadata in seq_group_metadata_list:
@@ -127,7 +130,7 @@ def _prepare_prompt(
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
             input_positions.append(list(range(prompt_len)))
-
+            embedding.append(seq_group_metadata.speaker_embedding_param)
             if seq_group_metadata.block_tables is None:
                 # During memory profiling, the block tables are not initialized
                 # yet. In this case, we just use a dummy slot mapping.
@@ -166,6 +169,10 @@ def _prepare_prompt(
             slot_mapping, max_prompt_len, pad=_PAD_SLOT_ID, dtype=torch.long
         )
 
+        embedding = _make_with_pad(
+            embedding, max_prompt_len, pad=0, dtype=torch.float32
+        )
+
         input_metadata = InputMetadata(
             is_prompt=True,
             slot_mapping=slot_mapping,
@@ -174,7 +181,7 @@ def _prepare_prompt(
             block_tables=None,
             use_cuda_graph=False,
         )
-        return input_tokens, input_positions, input_metadata, prompt_lens
+        return input_tokens, input_positions, input_metadata, prompt_lens, embedding
 
     def _prepare_decode(
         self,
@@ -353,14 +360,15 @@ def _prepare_sample(
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]]:
+        speaker_embedding = None
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
             is_prompt = seq_group_metadata_list[0].is_prompt
             # Prepare input tensors.
             if is_prompt:
-                (input_tokens, input_positions, input_metadata, prompt_lens) = (
+                (input_tokens, input_positions, input_metadata, prompt_lens, speaker_embedding) = (
                     self._prepare_prompt(seq_group_metadata_list)
                 )
             else:
@@ -454,7 +462,7 @@ def get_size_or_none(x: Optional[torch.Tensor]):
             perform_sampling=False,
         )
 
-        return input_tokens, input_positions, input_metadata, sampling_metadata
+        return input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding
 
     @torch.inference_mode()
     def execute_model(
@@ -462,7 +470,8 @@ def execute_model(
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Optional[SamplerOutput]:
-        input_tokens, input_positions, input_metadata, sampling_metadata = (
+
+        input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding = (
             self.prepare_input_tensors(seq_group_metadata_list)
         )
         # print(sampling_metadata.seq_data)
@@ -493,10 +502,10 @@ def execute_model(
         # print(logits_processors, logits_warpers)
         min_new_token = sampling_metadata.seq_groups[0][1].min_new_token
         eos_token = sampling_metadata.seq_groups[0][1].eos_token
-        start_idx = sampling_metadata.seq_groups[0][1].start_idx
+        start_idx = input_tokens[0].shape[0]
         if input_tokens.shape[-2] == 1:
             if infer_text:
-                input_emb: torch.Tensor = self.post_model.emb_text(
+                speaker_embedding_params: torch.Tensor = self.post_model.emb_text(
                     input_tokens[:, :, 0]
                 )
             else:
@@ -504,17 +513,22 @@ def execute_model(
                     self.post_model.emb_code[i](input_tokens[:, :, i])
                     for i in range(self.post_model.num_vq)
                 ]
-                input_emb = torch.stack(code_emb, 3).sum(3)
+                speaker_embedding_params = torch.stack(code_emb, 3).sum(3)
         else:
-            speaker_embedding_param = seq_group_metadata_list[0].speaker_embedding_param
-            input_emb = (
-                speaker_embedding_param
-                if speaker_embedding_param is not None
-                else self.post_model(input_tokens, text_mask)
-            )
-            # print(input_emb.shape)
+            # Concatenate the per-sequence speaker embeddings into one tensor via a for loop
+            if seq_group_metadata_list[0].speaker_embedding_param is not None:
+                speaker_embedding_params = None
+                for i in range(input_tokens.shape[0]):
+                    if speaker_embedding_params is None:
+                        speaker_embedding_params = speaker_embedding[i]
+                    else:
+                        speaker_embedding_params = torch.cat((speaker_embedding_params, speaker_embedding[i]))
+
+            else:
+                speaker_embedding_params = self.post_model(input_tokens, text_mask)
+
         hidden_states = model_executable(
-            input_emb=input_emb,
+            input_emb=speaker_embedding_params,
             positions=input_positions,
             kv_caches=kv_caches,
             input_metadata=input_metadata,
@@ -524,7 +538,7 @@ def execute_model(
             input_tokens = input_tokens[:, :, :]
             hidden_states = hidden_states[:, :, :]
         idx_next, logprob, finish = self.sampler.sample(
-            inputs_ids=(input_tokens),
+            inputs_ids=input_tokens,
             hidden_states=hidden_states,
             infer_text=infer_text,
             temperature=temperture,
@@ -546,7 +560,7 @@ def execute_model(
         #     sampling_metadata=sampling_metadata,
         # )
         results = []
-        for i in range(idx_next.shape[0]):
+        for i,val in enumerate(seq_groups):
             idx_next_i = idx_next[i, 0, :].tolist()
             logprob_i = logprob[i].tolist()
             tmp_hidden_states = hidden_states[i]
@@ -592,7 +606,7 @@ def profile_run(self) -> None:
                 is_prompt=True,
                 seq_data={group_id: seq_data},
                 sampling_params=sampling_params,
-                speaker_embedding_param=None,
+                speaker_embedding_param=torch.zeros(1, seq_len, 768).to("cuda"),
                 block_tables=None,
             )
             seqs.append(seq)
@@ -748,11 +762,11 @@ def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
 
 
-def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
+def _pad_to_max(x: List[int], max_len: int, pad: List[int]) -> List[int]:
     assert len(x) <= max_len
     if len(x) == max_len:
         return list(x)
-    return list(x) + [pad] * (max_len - len(x))
+    return [pad] * (max_len - len(x)) + list(x)
 
 
 def _make_tensor_with_pad(
@@ -766,17 +780,35 @@ def _make_tensor_with_pad(
     padded_x = []
     for x_i in x:
         pad_i = pad
-        if isinstance(x[0][0], tuple):
+        if isinstance(x[0][0], list):
+            pad_i = [0,] * len(x[0][0])
+        elif isinstance(x[0][0], tuple):
             pad_i = (0,) * len(x[0][0])
         padded_x.append(_pad_to_max(x_i, max_len, pad_i))
-
     return torch.tensor(
         padded_x,
         dtype=dtype,
        device=device,
-        pin_memory=pin_memory and str(device) == "cpu",
     )
 
+
+def _make_with_pad(
+    x: List[torch.Tensor],
+    max_len: int,
+    pad: int,
+    dtype: torch.dtype,
+    device: Union[str, torch.device] = "cuda",
+) -> torch.Tensor:
+    padded_x = []
+    for x_i in x:
+        assert x_i.shape[-2] <= max_len
+        if x_i.shape[-2] == max_len:
+            padded_x.append(x_i)
+        else:
+            padded_x.append(
+                torch.cat((torch.zeros(1, max_len-x_i.shape[-2], 768).to(device), x_i), dim=1)
+            )
+
+    return padded_x
 
 def _get_graph_batch_size(batch_size: int) -> int:
     if batch_size <= 2:
diff --git a/examples/api/main.py b/examples/api/main.py
index 57a27031f..e3589c2dc 100644
--- a/examples/api/main.py
+++ b/examples/api/main.py
@@ -17,7 +17,6 @@
 
 import ChatTTS
 
-from tools.audio import pcm_arr_to_mp3_view, float_to_int16
 from tools.logger import get_logger
 
 import torch
@@ -44,11 +43,10 @@ async def startup_event():
 
 
 class ChatTTSParams(BaseModel):
-    input: str
+    text: list[str]
     stream: bool = False
     lang: Optional[str] = None
-    voice: Optional[str] = None
-    skip_refine_text: bool = True
+    skip_refine_text: bool = False
     refine_text_only: bool = False
     use_decoder: bool = True
     do_text_normalization: bool = True
@@ -58,14 +56,25 @@ class ChatTTSParams(BaseModel):
     stream_batch_size: int = 16
 
 
-@app.post("/v1/audio/speech")
-async def speech(params: ChatTTSParams):
-    logger.info("Text input: %s", str(params.input))
+@app.post("/generate_voice")
+async def generate_voice(params: ChatTTSParams):
+    logger.info("Text input: %s", str(params.text))
 
-    # text seed for text refining
+    # audio seed
+    if params.params_infer_code.manual_seed is not None:
+        torch.manual_seed(params.params_infer_code.manual_seed)
+        params.params_infer_code.spk_emb = chat.sample_random_speaker()
 
-    # no text refining
-    text = [params.input]
+    # text seed for text refining
+    if params.params_refine_text and params.skip_refine_text is False:
+        results_generator = chat.infer(
+            text=params.text, skip_refine_text=False, refine_text_only=True
+        )
+        text = await next(results_generator)
+        logger.info(f"Refined text: {text}")
+    else:
+        # no text refining
+        text = params.text
 
     logger.info("Use speaker:")
     logger.info(params.params_infer_code.spk_emb)
diff --git a/examples/api/openai.py b/examples/api/openai.py
new file mode 100644
index 000000000..646956616
--- /dev/null
+++ b/examples/api/openai.py
@@ -0,0 +1,97 @@
+import os
+import sys
+
+import numpy as np
+from fastapi import FastAPI
+from fastapi.responses import Response, StreamingResponse
+
+from tools.audio.np import pcm_to_wav_bytes
+
+if sys.platform == "darwin":
+    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+
+from typing import Optional, AsyncGenerator
+
+import ChatTTS
+
+from tools.logger import get_logger
+
+
+from pydantic import BaseModel
+
+
+logger = get_logger("Command")
+
+app = FastAPI()
+
+
+@app.on_event("startup")
+async def startup_event():
+    global chat
+
+    chat = ChatTTS.Chat(get_logger("ChatTTS"))
+    logger.info("Initializing ChatTTS...")
+    if chat.load(use_vllm=True):
+        logger.info("Models loaded successfully.")
+    else:
+        logger.error("Models load failed.")
+        sys.exit(1)
+
+
+class ChatTTSParams(BaseModel):
+    input: str
+    stream: bool = False
+    lang: Optional[str] = None
+    voice: Optional[str] = None
+    skip_refine_text: bool = True
+    refine_text_only: bool = False
+    use_decoder: bool = True
+    do_text_normalization: bool = True
+    do_homophone_replacement: bool = False
+    params_refine_text: Optional[ChatTTS.Chat.RefineTextParams] = None
+    params_infer_code: Optional[ChatTTS.Chat.InferCodeParams] = None
+    stream_batch_size: int = 16
+
+
+@app.post("/v1/audio/speech")
+async def speech(params: ChatTTSParams):
+    logger.info("Text input: %s", str(params.input))
+    text = [params.input]
+    logger.info("Use speaker:")
+    logger.info(params.params_infer_code.spk_emb)
+    logger.info("Start voice inference.")
+    results_generator = chat.infer(
+        text=text,
+        stream=params.stream,
+        lang=params.lang,
+        skip_refine_text=params.skip_refine_text,
+        use_decoder=params.use_decoder,
+        do_text_normalization=params.do_text_normalization,
+        do_homophone_replacement=params.do_homophone_replacement,
+        params_infer_code=params.params_infer_code,
+        params_refine_text=params.params_refine_text,
+    )
+
+    if params.stream:
+
+        async def stream_results() -> AsyncGenerator[bytes, None]:
+            async for output in results_generator:
+                yield pcm_to_wav_bytes(output[0])
+
+        return StreamingResponse(
+            content=stream_results(), media_type="text/event-stream"
+        )
+
+    output = None
+    async for request_output in results_generator:
+        if output is None:
+            output = request_output[0]
+        else:
+            output = np.concatenate((output, request_output[0]), axis=0)
+    output = pcm_to_wav_bytes(output)
+    return Response(
+        content=output, media_type="audio/wav", headers={"Cache-Control": "no-cache"}
+    )
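
Usage note (not part of the patch): the OpenAI-style endpoint added in examples/api/openai.py
can be exercised with a small client along the lines of the sketch below. It assumes the app is
served locally, for example with `uvicorn examples.api.openai:app --port 8000`, and that the
third-party `requests` package is installed; the host, port, payload values and output filename
are illustrative assumptions rather than anything defined by this change.

    # Minimal client sketch for the /v1/audio/speech endpoint added above.
    # Assumed setup: the FastAPI app is running on localhost:8000 (see note).
    import requests

    payload = {
        "input": "Hello from ChatTTS.",  # the server wraps this single string into a list
        "stream": False,                 # non-streaming: one complete audio/wav body is returned
        "params_infer_code": {},         # defaults of InferCodeParams; spk_emb may be set here
    }

    resp = requests.post(
        "http://localhost:8000/v1/audio/speech", json=payload, timeout=600
    )
    resp.raise_for_status()

    with open("speech.wav", "wb") as f:  # arbitrary output path
        f.write(resp.content)            # response body is WAV bytes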