mirror of https://github.com/open-mmlab/mmyolo.git
[Fix] Enable switch to deploy when create pytorch model in deployment. (#324)
* support switch_to_deploy when deploy * fix docformatter * fix README for yolox shapepull/754/head
parent
68c9fd4745
commit
4a8699d6fe
|
@ -16,7 +16,7 @@ In this report, we present some experienced improvements to YOLO series, forming
|
|||
|
||||
| Backbone | size | Mem (GB) | box AP | Config | Download |
|
||||
| :--------: | :--: | :------: | :----: | :---------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| YOLOX-tiny | 640 | 2.8 | 32.7 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_tiny_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908.log.json) |
|
||||
| YOLOX-tiny | 416 | 2.8 | 32.7 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_tiny_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908.log.json) |
|
||||
| YOLOX-s | 640 | 5.6 | 40.8 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_s_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738-d7e60cb2.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738.log.json) |
|
||||
|
||||
**Note**:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Callable
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
from mmdeploy.codebase.base import CODEBASE, MMCodebase
|
||||
from mmdeploy.codebase.mmdet.deploy import ObjectDetection
|
||||
from mmdeploy.utils import Codebase, Task
|
||||
|
@ -82,3 +83,40 @@ class YOLOObjectDetection(ObjectDetection):
|
|||
if metainfo is not None:
|
||||
visualizer.dataset_meta = metainfo
|
||||
return visualizer
|
||||
|
||||
def build_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize torch model.
|
||||
|
||||
Args:
|
||||
model_checkpoint (str): The checkpoint file of torch model,
|
||||
defaults to `None`.
|
||||
cfg_options (dict): Optional config key-pair parameters.
|
||||
Returns:
|
||||
nn.Module: An initialized torch model generated by other OpenMMLab
|
||||
codebases.
|
||||
"""
|
||||
from copy import deepcopy
|
||||
|
||||
from mmengine.model import revert_sync_batchnorm
|
||||
from mmengine.registry import MODELS
|
||||
|
||||
from mmyolo.utils import switch_to_deploy
|
||||
|
||||
model = deepcopy(self.model_cfg.model)
|
||||
preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
|
||||
preprocess_cfg.update(
|
||||
deepcopy(self.model_cfg.get('data_preprocessor', {})))
|
||||
model.setdefault('data_preprocessor', preprocess_cfg)
|
||||
model = MODELS.build(model)
|
||||
if model_checkpoint is not None:
|
||||
from mmengine.runner.checkpoint import load_checkpoint
|
||||
load_checkpoint(model, model_checkpoint, map_location=self.device)
|
||||
|
||||
model = revert_sync_batchnorm(model)
|
||||
switch_to_deploy(model)
|
||||
model = model.to(self.device)
|
||||
model.eval()
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue