diff --git a/mmdeploy/codebase/base/task.py b/mmdeploy/codebase/base/task.py index 048433e070..c52565dbbb 100644 --- a/mmdeploy/codebase/base/task.py +++ b/mmdeploy/codebase/base/task.py @@ -126,6 +126,10 @@ def build_pytorch_model(self, if hasattr(model, 'backbone') and hasattr(model.backbone, 'switch_to_deploy'): model.backbone.switch_to_deploy() + + if hasattr(model, 'switch_to_deploy') and callable(model.switch_to_deploy): + model.switch_to_deploy() + model = model.to(self.device) model.eval() return model diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection.py b/mmdeploy/codebase/mmpose/deploy/pose_detection.py index 9621b92f38..5e6b0c5c6f 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy -import inspect import os from collections import defaultdict from typing import Callable, Dict, Optional, Sequence, Tuple, Union @@ -180,41 +179,6 @@ def build_backend_model( **kwargs) return model.eval().to(self.device) - def build_pytorch_model(self, - model_checkpoint: Optional[str] = None, - cfg_options: Optional[Dict] = None, - **kwargs) -> torch.nn.Module: - """Initialize torch model and switch to deploy mode. - - Args: - model_checkpoint (str): The checkpoint file of torch model, - defaults to `None`. - cfg_options (dict): Optional config key-pair parameters. - - Returns: - nn.Module: An initialized torch model generated by other OpenMMLab - codebases. - """ - # Initialize the PyTorch model using parent class method - torch_model = super().build_pytorch_model(model_checkpoint, - cfg_options, **kwargs) - - # Check if called from torch2onnx within 'apis/pytorch2onnx.py' - callers = inspect.stack() - is_torch2onnx_call = ( - len(callers) > 1 and callers[1].function == 'torch2onnx' - and callers[1].filename.endswith( - os.path.join('apis', 'pytorch2onnx.py'))) - - # If model has a 'switch_to_deploy' method and is called from - # torch2onnx, activate this method - if is_torch2onnx_call and hasattr(torch_model, - 'switch_to_deploy') and callable( - torch_model.switch_to_deploy): - torch_model.switch_to_deploy() - - return torch_model - def create_input(self, imgs: Union[str, np.ndarray, Sequence], input_shape: Sequence[int] = None,