mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
19 lines
466 B
Python
19 lines
466 B
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
import torch.nn as nn
|
||
|
from mmengine.model import is_model_wrapper
|
||
|
|
||
|
|
||
|
def get_ori_model(model: nn.Module) -> nn.Module:
    """Return the underlying model, stripping a model wrapper if present.

    Whether ``model`` counts as a wrapper is decided by mmengine's
    ``is_model_wrapper``. When it does, the wrapped network is taken from
    its ``module`` attribute. Only a single level of wrapping is removed.

    Args:
        model (nn.Module): A model that may or may not be wrapped by a
            model wrapper.

    Returns:
        nn.Module: ``model.module`` if ``model`` is a model wrapper,
        otherwise ``model`` itself, unchanged.
    """
    return model.module if is_model_wrapper(model) else model
|