Skip to content

Commit

Permalink
apply model.switch_to_deploy in BaseTask.build_pytorch_model
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis committed Dec 13, 2023
1 parent c3ec981 commit 2207861
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 36 deletions.
4 changes: 4 additions & 0 deletions mmdeploy/codebase/base/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def build_pytorch_model(self,
if hasattr(model, 'backbone') and hasattr(model.backbone,
'switch_to_deploy'):
model.backbone.switch_to_deploy()

if hasattr(model, 'switch_to_deploy') and callable(model.switch_to_deploy):
model.switch_to_deploy()

model = model.to(self.device)
model.eval()
return model
Expand Down
36 changes: 0 additions & 36 deletions mmdeploy/codebase/mmpose/deploy/pose_detection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.

import copy
import inspect
import os
from collections import defaultdict
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -180,41 +179,6 @@ def build_backend_model(
**kwargs)
return model.eval().to(self.device)

def build_pytorch_model(self,
model_checkpoint: Optional[str] = None,
cfg_options: Optional[Dict] = None,
**kwargs) -> torch.nn.Module:
"""Initialize torch model and switch to deploy mode.
Args:
model_checkpoint (str): The checkpoint file of torch model,
defaults to `None`.
cfg_options (dict): Optional config key-pair parameters.
Returns:
nn.Module: An initialized torch model generated by other OpenMMLab
codebases.
"""
# Initialize the PyTorch model using parent class method
torch_model = super().build_pytorch_model(model_checkpoint,
cfg_options, **kwargs)

# Check if called from torch2onnx within 'apis/pytorch2onnx.py'
callers = inspect.stack()
is_torch2onnx_call = (
len(callers) > 1 and callers[1].function == 'torch2onnx'
and callers[1].filename.endswith(
os.path.join('apis', 'pytorch2onnx.py')))

# If model has a 'switch_to_deploy' method and is called from
# torch2onnx, activate this method
if is_torch2onnx_call and hasattr(torch_model,
'switch_to_deploy') and callable(
torch_model.switch_to_deploy):
torch_model.switch_to_deploy()

return torch_model

def create_input(self,
imgs: Union[str, np.ndarray, Sequence],
input_shape: Sequence[int] = None,
Expand Down

0 comments on commit 2207861

Please sign in to comment.