EasyCV/easycv/datasets/shared/pipelines/third_transforms_wrapper.py
yhq 04ca0d53c7
Doc/add mae benchmark (#24)
* add mae large to mae benchmark
* fix dataset pipeline bug when timm uses the original torchvision pipeline in mae
* add mae finetune unittest

Co-authored-by: yanhaiqiang.yhq <yanhaiqiang.yhq@gitlab.alibaba-inc.com>
2022-04-24 14:53:35 +08:00

84 lines
2.2 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import inspect
import torch
from torchvision import transforms as _transforms
from easycv.datasets.registry import PIPELINES
def is_child_of(obj, cls):
    """Tell whether ``cls`` appears among the ancestors of ``obj``.

    ``obj`` may be a class or an instance. For a class, its direct bases are
    checked first and then each base is searched recursively. ``obj`` itself
    is never compared with ``cls``, so ``is_child_of(X, X)`` is ``False``.
    When ``obj`` has no ``__bases__`` attribute (e.g. it is an instance), the
    search restarts from ``obj.__class__``.

    Args:
        obj: Class or instance to inspect.
        cls: Candidate ancestor class.

    Returns:
        bool: ``True`` if a base matches, ``False`` otherwise.
    """
    try:
        parents = obj.__bases__
        # A direct base matches when it is ``cls`` itself or when
        # ``isinstance(base, cls)`` holds for the base object.
        if any(base is cls or isinstance(base, cls) for base in parents):
            return True
        # Otherwise walk further up the hierarchy through every base.
        return any(is_child_of(base, cls) for base in parents)
    except AttributeError:
        # No ``__bases__`` available: retry with the object's class.
        return is_child_of(obj.__class__, cls)
def get_args(obj):
    """Return the positional argument names of a callable.

    The names come from ``inspect.getfullargspec(obj).args``. A leading
    ``self`` or ``cls`` entry is dropped from the result.

    Args:
        obj: Callable to inspect.

    Returns:
        list: Positional argument names without a leading ``self``/``cls``.
    """
    names = inspect.getfullargspec(obj).args or []
    has_receiver = bool(names) and names[0] in ('self', 'cls')
    return names[1:] if has_receiver else names
def _reset_forward(obj):
original_forward = obj.forward
def _new_forward(self, results):
img = results['img']
img = original_forward(self, img)
results['img'] = img
return results
setattr(obj, 'forward', _new_forward)
def _reset_call(obj):
original_call = obj.__call__
def _new_call(self, results):
img = results['img']
img = original_call(self, img)
results['img'] = img
return results
setattr(obj, '__call__', _new_call)
# TODO: find a cleaner way to wrap third-party transforms, or import a fixed API to wrap
def wrap_torchvision_transforms(transform_obj):
    """Adapt a transform so that it consumes and returns a results dict.

    If the transform descends from ``torch.nn.Module`` its ``forward`` is
    patched; otherwise, if it has ``__call__``, that method is patched.
    Patching only happens when the chosen method takes exactly one argument
    besides ``self``/``cls``. Nothing is returned.

    NOTE(review): the patch is applied to ``copy.deepcopy(transform_obj)``.
    Per the ``copy`` module documentation, classes are returned unchanged by
    ``deepcopy``, so a class passed in is modified in place. For any input
    that is actually copied, the patched copy is discarded because this
    function has no return value.

    Args:
        transform_obj: Transform to adapt.
    """
    target = copy.deepcopy(transform_obj)

    # Pick the entry point to inspect together with the matching patcher.
    if is_child_of(target, torch.nn.Module):
        entry_point, patch = target.forward, _reset_forward
    elif hasattr(target, '__call__'):
        entry_point, patch = target.__call__, _reset_call
    else:
        return

    # Only single-input entry points are rewired; the argument name is not
    # checked.
    if len(get_args(entry_point)) == 1:
        patch(target)
# Class names from torchvision.transforms that are not registered below.
skip_list = ['Compose', 'RandomApply']

# Register every remaining class exposed by torchvision.transforms: each one
# is subclassed under its own name, adapted to the results-dict interface and
# added to PIPELINES.
for member in inspect.getmembers(_transforms, inspect.isclass):
    obj_name, obj = member
    if obj_name not in skip_list:
        obj_copy = type(obj_name, (obj, ), {})
        wrap_torchvision_transforms(obj_copy)
        PIPELINES.register_module(obj_copy)