Autodetect dtype on exporting to TensorRT-LLM (#11907)
* Autodetect dtype from NeMo checkpoint for TRT-LLM export

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Attempting parallel build

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Revert "Attempting parallel build"

This reverts commit b3c2db0.

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Update logs and error messages

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Try autodetecting dtype parameter by default

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

---------

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>
Co-authored-by: Onur Yilmaz <35306097+oyilmaz-nvidia@users.noreply.github.com>
janekl and oyilmaz-nvidia authored Jan 23, 2025
1 parent 5042b79 commit 116b5cd
Showing 5 changed files with 88 additions and 31 deletions.
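
In practice, this change lets callers omit `dtype` entirely. A minimal sketch of the new default behavior (the checkpoint path and `model_type` value are hypothetical; the exporter API is taken from the diff below):

```python
from nemo.export.tensorrt_llm import TensorRTLLM

exporter = TensorRTLLM(model_dir="/tmp/trt_llm_engine")

# dtype now defaults to None and is autodetected from the checkpoint config;
# passing dtype="bfloat16" explicitly still works as before.
exporter.export(
    nemo_checkpoint_path="/models/llama-8b.nemo",
    model_type="llama",
)
```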
33 changes: 22 additions & 11 deletions nemo/export/tensorrt_llm.py
@@ -33,16 +33,13 @@
from nemo.deploy import ITritonDeployable
from nemo.export.tarutils import TarPath, unpack_tarball
from nemo.export.trt_llm.converter.model_converter import determine_quantization_settings, model_to_trtllm_ckpt
from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import (
dist_model_to_trt_llm_ckpt,
get_layer_prefix,
torch_dtype_from_precision,
)
from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import dist_model_to_trt_llm_ckpt, get_layer_prefix
from nemo.export.trt_llm.converter.utils import init_model_parallel_from_nemo
from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import (
build_tokenizer,
get_model_type,
get_tokenizer,
get_weights_dtype,
is_nemo_file,
load_nemo_model,
)
@@ -59,6 +56,7 @@
unload_engine,
)
from nemo.export.trt_llm.utils import is_rank
from nemo.export.utils import torch_dtype_from_precision

use_deploy = True
try:
@@ -170,7 +168,7 @@ def export(
paged_kv_cache: bool = True,
remove_input_padding: bool = True,
paged_context_fmha: bool = False,
dtype: str = "bfloat16",
dtype: Optional[str] = None,
load_model: bool = True,
use_lora_plugin: str = None,
lora_target_modules: List[str] = None,
@@ -208,7 +206,8 @@ def export(
paged_kv_cache (bool): if True, uses kv cache feature of the TensorRT-LLM.
paged_context_fmha (bool): whether to use paged context fmha feature of TRT-LLM or not
remove_input_padding (bool): enables removing input padding or not.
dtype (str): Floating point type for model weights (Supports BFloat16/Float16).
dtype (Optional[str]): Floating point type for model weights (supports 'bfloat16', 'float16' or 'float32').
If None, the dtype is autodetected from the model config.
load_model (bool): load TensorRT-LLM model after the export.
use_lora_plugin (str): use dynamic lora or not.
lora_target_modules (List[str]): list of the target lora modules.
@@ -316,12 +315,24 @@ def export(
model_type = get_model_type(nemo_checkpoint_path)

if model_type is None:
raise Exception("model_type needs to be specified, got None.")
raise ValueError(
"Parameter model_type needs to be provided and cannot be inferred from the checkpoint. "
"Please specify it explicitely."
)

if model_type not in self.get_supported_models_list:
raise Exception(
"Model {0} is not currently a supported model type. "
"Supported model types are: {1}.".format(model_type, self.get_supported_models_list)
raise ValueError(
f"Model {model_type} is not currently a supported model type. "
f"Supported model types are: {self.get_supported_models_list}."
)

if dtype is None:
dtype = get_weights_dtype(nemo_checkpoint_path)

if dtype is None:
raise ValueError(
"Parameter dtype needs to be provided and cannot be inferred from the checkpoint. "
"Please specify it explicitely."
)

model, model_config, self.tokenizer = load_nemo_model(
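
Continuing the sketch above, the new default has two possible outcomes (paths hypothetical; the log and error texts come from this diff):

```python
exporter.export(nemo_checkpoint_path="/models/llama-8b.nemo", model_type="llama")
# If the checkpoint config carries precision info, the exporter logs e.g.
#   Determined weights dtype='bfloat16' for /models/llama-8b.nemo checkpoint.
# If it does not, export raises:
#   ValueError: Parameter dtype needs to be provided and cannot be inferred
#   from the checkpoint. Please specify it explicitly.
```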
18 changes: 1 addition & 17 deletions nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py
@@ -17,13 +17,13 @@
import multiprocessing
from collections import defaultdict
from pathlib import Path
from typing import Optional, Union

import torch
from tensorrt_llm._utils import pad_vocab_size, str_dtype_to_torch
from tqdm import tqdm

from nemo.export.trt_llm.converter.utils import save_scaling_factor, save_val, split_and_save_weight, weights_dict
from nemo.export.utils import torch_dtype_from_precision

LOGGER = logging.getLogger("NeMo")

@@ -36,22 +36,6 @@
}


def torch_dtype_from_precision(precision: Union[int, str], megatron_amp_O2: Optional[bool] = None) -> torch.dtype:
"""Mapping from PTL precision types to corresponding PyTorch parameter datatype."""
# Copied from nemo.collections.nlp.parts.utils_funcs to avoid extra dependencies for NIM.
if megatron_amp_O2 is not None and megatron_amp_O2 is False:
return torch.float32

if precision in ['bf16', 'bf16-mixed']:
return torch.bfloat16
elif precision in [16, '16', '16-mixed']:
return torch.float16
elif precision in [32, '32', '32-true']:
return torch.float32
else:
raise ValueError(f"Could not parse the precision of `{precision}` to a valid torch.dtype")


def extract_layers_with_prefix(model_, prefix):
length_to_trim = len(prefix)
model_state = model_.get("state_dict", model_)
33 changes: 32 additions & 1 deletion nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
@@ -34,6 +34,7 @@
from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer
from nemo.export.tarutils import TarPath, ZarrPathStore
from nemo.export.tiktoken_tokenizer import TiktokenTokenizer
from nemo.export.utils import torch_dtype_from_precision

try:
from nemo.lightning import io
@@ -441,7 +442,7 @@ def get_model_type(nemo_ckpt: Union[str, Path]) -> Optional[str]:
Determine the model type from a NeMo checkpoint for TensorRT-LLM engine build.
Args:
nemo_ckpt (str): Path to the NeMo checkpoint file.
nemo_ckpt (Union[str, Path]): Path to the NeMo checkpoint file.
Returns:
Optional[str]: The model type if it can be determined, otherwise None.
"""
@@ -480,6 +481,36 @@ def get_model_type(nemo_ckpt: Union[str, Path]) -> Optional[str]:
return model_type


def get_weights_dtype(nemo_ckpt: Union[str, Path]) -> Optional[str]:
"""Determine the weights data type from a NeMo checkpoint for TensorRT-LLM engine build.
Args:
nemo_ckpt (Union[str, Path]): Path to the NeMo checkpoint file.
Returns:
Optional[str]: The dtype if it can be determined, otherwise None.
"""
model_config = load_nemo_config(nemo_ckpt)
torch_dtype = None
dtype = None

is_nemo2 = "_target_" in model_config
if is_nemo2:
torch_dtype = model_config["config"]["params_dtype"]["_target_"]
elif precision := model_config.get("precision", None):
torch_dtype = str(torch_dtype_from_precision(precision))

if torch_dtype is not None:
dtype = torch_dtype.removeprefix("torch.")
LOGGER.info(f"Determined weights dtype='{dtype}' for {nemo_ckpt} checkpoint.")
else:
LOGGER.warning(
f"Parameter dtype for model weights cannot be determined for {nemo_ckpt} checkpoint. "
"There is no 'precision' field specified in the model_config.yaml file."
)

return dtype


def load_distributed_model_weights(
weights_directory: Union[Path, TarPath], mcore_scales_format: bool
) -> Dict[str, Any]:
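
A usage sketch for the new helper (checkpoint path hypothetical; behavior as implemented above):

```python
from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import get_weights_dtype

# NeMo 2.0 checkpoints expose params_dtype via the config "_target_" entry;
# NeMo 1.0 checkpoints fall back to the "precision" field. Returns a string
# such as "bfloat16", or None when neither source is present.
dtype = get_weights_dtype("/models/llama-8b.nemo")
```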
7 changes: 5 additions & 2 deletions nemo/export/utils/__init__.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.export.utils.utils import is_nemo2_checkpoint
from nemo.export.utils.utils import is_nemo2_checkpoint, torch_dtype_from_precision

__all__ = ["is_nemo2_checkpoint"]
__all__ = [
"is_nemo2_checkpoint",
"torch_dtype_from_precision",
]
28 changes: 28 additions & 0 deletions nemo/export/utils/utils.py
@@ -13,6 +13,9 @@
# limitations under the License.

from pathlib import Path
from typing import Union

import torch


def is_nemo2_checkpoint(checkpoint_path: str) -> bool:
@@ -26,3 +29,28 @@ def is_nemo2_checkpoint(checkpoint_path: str) -> bool:

ckpt_path = Path(checkpoint_path)
return (ckpt_path / 'context').is_dir()


# Copied from nemo.collections.nlp.parts.utils_funcs to avoid introducing extra NeMo dependencies:
def torch_dtype_from_precision(precision: Union[int, str], megatron_amp_O2: bool = True) -> torch.dtype:
"""
Mapping from PyTorch Lightning (PTL) precision types to the corresponding PyTorch parameter data type.
Args:
precision (Union[int, str]): The PTL precision type used.
megatron_amp_O2 (bool): A flag indicating if Megatron AMP O2 is enabled.
Returns:
torch.dtype: The corresponding PyTorch data type based on the provided precision.
"""
if not megatron_amp_O2:
return torch.float32

if precision in ['bf16', 'bf16-mixed']:
return torch.bfloat16
elif precision in [16, '16', '16-mixed']:
return torch.float16
elif precision in [32, '32', '32-true']:
return torch.float32
else:
raise ValueError(f"Could not parse the precision of '{precision}' to a valid torch.dtype")
