diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py
index 192b8bc86f65..f71be0b15a04 100644
--- a/nemo/export/tensorrt_llm.py
+++ b/nemo/export/tensorrt_llm.py
@@ -33,16 +33,13 @@
 from nemo.deploy import ITritonDeployable
 from nemo.export.tarutils import TarPath, unpack_tarball
 from nemo.export.trt_llm.converter.model_converter import determine_quantization_settings, model_to_trtllm_ckpt
-from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import (
-    dist_model_to_trt_llm_ckpt,
-    get_layer_prefix,
-    torch_dtype_from_precision,
-)
+from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import dist_model_to_trt_llm_ckpt, get_layer_prefix
 from nemo.export.trt_llm.converter.utils import init_model_parallel_from_nemo
 from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import (
     build_tokenizer,
     get_model_type,
     get_tokenizer,
+    get_weights_dtype,
     is_nemo_file,
     load_nemo_model,
 )
@@ -59,6 +56,7 @@
     unload_engine,
 )
 from nemo.export.trt_llm.utils import is_rank
+from nemo.export.utils import torch_dtype_from_precision
 
 use_deploy = True
 try:
@@ -170,7 +168,7 @@ def export(
         paged_kv_cache: bool = True,
         remove_input_padding: bool = True,
         paged_context_fmha: bool = False,
-        dtype: str = "bfloat16",
+        dtype: Optional[str] = None,
         load_model: bool = True,
         use_lora_plugin: str = None,
         lora_target_modules: List[str] = None,
@@ -208,7 +206,8 @@
             paged_kv_cache (bool): if True, uses kv cache feature of the TensorRT-LLM.
             paged_context_fmha (bool): whether to use paged context fmha feature of TRT-LLM or not
             remove_input_padding (bool): enables removing input padding or not.
-            dtype (str): Floating point type for model weights (Supports BFloat16/Float16).
+            dtype (Optional[str]): Floating point type for model weights (supports 'bfloat16', 'float16' or 'float32').
+                If None, the dtype is inferred from the model config.
             load_model (bool): load TensorRT-LLM model after the export.
             use_lora_plugin (str): use dynamic lora or not.
             lora_target_modules (List[str]): list of the target lora modules.
@@ -316,12 +315,24 @@
                 model_type = get_model_type(nemo_checkpoint_path)
 
                 if model_type is None:
-                    raise Exception("model_type needs to be specified, got None.")
+                    raise ValueError(
+                        "Parameter model_type needs to be provided and cannot be inferred from the checkpoint. "
+                        "Please specify it explicitly."
+                    )
 
                 if model_type not in self.get_supported_models_list:
-                    raise Exception(
-                        "Model {0} is not currently a supported model type. "
-                        "Supported model types are: {1}.".format(model_type, self.get_supported_models_list)
+                    raise ValueError(
+                        f"Model {model_type} is not currently a supported model type. "
+                        f"Supported model types are: {self.get_supported_models_list}."
+                    )
+
+                if dtype is None:
+                    dtype = get_weights_dtype(nemo_checkpoint_path)
+
+                if dtype is None:
+                    raise ValueError(
+                        "Parameter dtype needs to be provided and cannot be inferred from the checkpoint. "
+                        "Please specify it explicitly."
+                    )
 
                 model, model_config, self.tokenizer = load_nemo_model(
diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py
index ca725f74d2ef..211e804c61aa 100644
--- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py
+++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py
@@ -17,13 +17,13 @@
 import multiprocessing
 from collections import defaultdict
 from pathlib import Path
-from typing import Optional, Union
 
 import torch
 from tensorrt_llm._utils import pad_vocab_size, str_dtype_to_torch
 from tqdm import tqdm
 
 from nemo.export.trt_llm.converter.utils import save_scaling_factor, save_val, split_and_save_weight, weights_dict
+from nemo.export.utils import torch_dtype_from_precision
 
 LOGGER = logging.getLogger("NeMo")
 
@@ -36,22 +36,6 @@
 }
 
 
-def torch_dtype_from_precision(precision: Union[int, str], megatron_amp_O2: Optional[bool] = None) -> torch.dtype:
-    """Mapping from PTL precision types to corresponding PyTorch parameter datatype."""
-    # Copied from nemo.collections.nlp.parts.utils_funcs to avoid extra depenencies for NIM.
-    if megatron_amp_O2 is not None and megatron_amp_O2 is False:
-        return torch.float32
-
-    if precision in ['bf16', 'bf16-mixed']:
-        return torch.bfloat16
-    elif precision in [16, '16', '16-mixed']:
-        return torch.float16
-    elif precision in [32, '32', '32-true']:
-        return torch.float32
-    else:
-        raise ValueError(f"Could not parse the precision of `{precision}` to a valid torch.dtype")
-
-
 def extract_layers_with_prefix(model_, prefix):
     length_to_trim = len(prefix)
     model_state = model_.get("state_dict", model_)
diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
index f3c9812555bc..72f3c5c4bc25 100644
--- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
+++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
@@ -34,6 +34,7 @@
 from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer
 from nemo.export.tarutils import TarPath, ZarrPathStore
 from nemo.export.tiktoken_tokenizer import TiktokenTokenizer
+from nemo.export.utils import torch_dtype_from_precision
 
 try:
     from nemo.lightning import io
@@ -441,7 +442,7 @@
     Determine the model type from a NeMo checkpoint for TensorRT-LLM engine build.
 
     Args:
-        nemo_ckpt (str): Path to the NeMo checkpoint file.
+        nemo_ckpt (Union[str, Path]): Path to the NeMo checkpoint file.
     Returns:
         Optional[str]: The model type if it can be determined, otherwise None.
     """
@@ -480,6 +481,36 @@
     return model_type
 
 
+def get_weights_dtype(nemo_ckpt: Union[str, Path]) -> Optional[str]:
+    """Determine the weights data type from a NeMo checkpoint for TensorRT-LLM engine build.
+
+    Args:
+        nemo_ckpt (Union[str, Path]): Path to the NeMo checkpoint file.
+    Returns:
+        Optional[str]: The dtype if it can be determined, otherwise None.
+ """ + model_config = load_nemo_config(nemo_ckpt) + torch_dtype = None + dtype = None + + is_nemo2 = "_target_" in model_config + if is_nemo2: + torch_dtype = model_config["config"]["params_dtype"]["_target_"] + elif precision := model_config.get("precision", None): + torch_dtype = str(torch_dtype_from_precision(precision)) + + if torch_dtype is not None: + dtype = torch_dtype.removeprefix("torch.") + LOGGER.info(f"Determined weights dtype='{dtype}' for {nemo_ckpt} checkpoint.") + else: + LOGGER.warning( + f"Parameter dtype for model weights cannot be determined for {nemo_ckpt} checkpoint. " + "There is no 'precision' field specified in the model_config.yaml file." + ) + + return dtype + + def load_distributed_model_weights( weights_directory: Union[Path, TarPath], mcore_scales_format: bool ) -> Dict[str, Any]: diff --git a/nemo/export/utils/__init__.py b/nemo/export/utils/__init__.py index edf000b93e59..ff66149bfa30 100644 --- a/nemo/export/utils/__init__.py +++ b/nemo/export/utils/__init__.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo.export.utils.utils import is_nemo2_checkpoint +from nemo.export.utils.utils import is_nemo2_checkpoint, torch_dtype_from_precision -__all__ = ["is_nemo2_checkpoint"] +__all__ = [ + "is_nemo2_checkpoint", + "torch_dtype_from_precision", +] diff --git a/nemo/export/utils/utils.py b/nemo/export/utils/utils.py index 91208e7d7afa..705614ead296 100644 --- a/nemo/export/utils/utils.py +++ b/nemo/export/utils/utils.py @@ -13,6 +13,9 @@ # limitations under the License. from pathlib import Path +from typing import Union + +import torch def is_nemo2_checkpoint(checkpoint_path: str) -> bool: @@ -26,3 +29,28 @@ def is_nemo2_checkpoint(checkpoint_path: str) -> bool: ckpt_path = Path(checkpoint_path) return (ckpt_path / 'context').is_dir() + + +# Copied from nemo.collections.nlp.parts.utils_funcs to avoid introducing extra NeMo dependencies: +def torch_dtype_from_precision(precision: Union[int, str], megatron_amp_O2: bool = True) -> torch.dtype: + """ + Mapping from PyTorch Lighthing (PTL) precision types to corresponding PyTorch parameter data type. + + Args: + precision (Union[int, str]): The PTL precision type used. + megatron_amp_O2 (bool): A flag indicating if Megatron AMP O2 is enabled. + + Returns: + torch.dtype: The corresponding PyTorch data type based on the provided precision. + """ + if not megatron_amp_O2: + return torch.float32 + + if precision in ['bf16', 'bf16-mixed']: + return torch.bfloat16 + elif precision in [16, '16', '16-mixed']: + return torch.float16 + elif precision in [32, '32', '32-true']: + return torch.float32 + else: + raise ValueError(f"Could not parse the precision of '{precision}' to a valid torch.dtype")