mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
* add heads * add losses * fix * remove mim head * add modified backbones and target generators * fix lint * fix lint * add heads * add losses * fix * add data preprocessor from mmselfsup * add ut for data preprocessor * add GatherLayer * add ema * add batch shuffle * add misc * fix lint * update * update docstring
19 lines
466 B
Python
19 lines
466 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch.nn as nn
|
|
from mmengine.model import is_model_wrapper
|
|
|
|
|
|
def get_ori_model(model: nn.Module) -> nn.Module:
    """Unwrap ``model`` when it is held inside a model wrapper.

    Args:
        model (nn.Module): A model that may or may not be wrapped by a
            model wrapper (as detected by ``is_model_wrapper``).

    Returns:
        nn.Module: ``model.module`` if ``model`` is a model wrapper,
        otherwise ``model`` itself, returned unchanged.
    """
    # Only one level of wrapping is removed: the wrapped network is taken
    # from the wrapper's ``module`` attribute.
    return model.module if is_model_wrapper(model) else model
|