mirror of https://github.com/alibaba/EasyCV.git
88 lines
2.4 KiB
Python
88 lines
2.4 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import copy
|
|
import inspect
|
|
from enum import EnumMeta
|
|
|
|
import torch
|
|
from torchvision import transforms as _transforms
|
|
|
|
from easycv.datasets.registry import PIPELINES
|
|
|
|
|
|
def is_child_of(obj, cls):
    """Return True if ``obj`` (a class or an instance) strictly inherits from ``cls``.

    The ``__bases__`` of ``obj`` are walked recursively. A base counts as a
    match when it *is* ``cls`` or is an *instance* of ``cls`` (relevant when
    ``cls`` is a metaclass such as ``type``). Objects that have no
    ``__bases__`` attribute (non-class instances) are checked through their
    ``__class__`` instead.

    The check is strict: a class is not a child of itself, and a direct
    instance of ``cls`` is not a child of ``cls`` either (only the bases of
    its class are inspected).
    """
    try:
        bases = obj.__bases__
        # Direct parents first, then recurse into each parent's ancestry.
        if any(base is cls or isinstance(base, cls) for base in bases):
            return True
        return any(is_child_of(base, cls) for base in bases)
    except AttributeError:
        # No ``__bases__``: ``obj`` is not a class, so inspect its class.
        return is_child_of(obj.__class__, cls)
|
|
|
|
|
|
def get_args(obj):
    """Return the positional parameter names of the callable ``obj``.

    The names come from ``inspect.getfullargspec(obj).args``. A leading
    ``self`` or ``cls`` is dropped, so a plain function defined with ``self``
    and a bound method report the same names. ``*args``, ``**kwargs`` and
    keyword-only parameters are not part of the result.
    """
    arg_names = inspect.getfullargspec(obj).args or []
    # getfullargspec keeps the already-bound first parameter of methods.
    if arg_names and arg_names[0] in ('self', 'cls'):
        del arg_names[0]
    return arg_names
|
|
|
|
|
|
def _reset_forward(obj):
|
|
original_forward = obj.forward
|
|
|
|
def _new_forward(self, results):
|
|
img = results['img']
|
|
img = original_forward(self, img)
|
|
results['img'] = img
|
|
return results
|
|
|
|
setattr(obj, 'forward', _new_forward)
|
|
|
|
|
|
def _reset_call(obj):
|
|
original_call = obj.__call__
|
|
|
|
def _new_call(self, results):
|
|
img = results['img']
|
|
img = original_call(self, img)
|
|
results['img'] = img
|
|
|
|
return results
|
|
|
|
setattr(obj, '__call__', _new_call)
|
|
|
|
|
|
# TODO: find a cleaner way to wrap third-party transforms, or import a fixed API to wrap.
|
|
def wrap_torchvision_transforms(transform_obj):
    """Patch a torchvision transform so it operates on EasyCV result dicts.

    The entry point that receives the image is chosen as follows:

    * subclasses of ``torch.nn.Module`` have ``forward`` patched;
    * anything else exposing ``__call__`` has ``__call__`` patched.

    The patch is applied only when that entry point takes exactly one
    argument besides ``self``/``cls``. The argument's name is not checked;
    restricting it to e.g. ``img``/``pic`` was considered but is disabled.

    Note: ``copy.deepcopy`` returns classes unchanged, so when a class is
    passed in, the patch is applied to that class in place. Nothing is
    returned, so patching has no lasting effect on objects that deepcopy
    really copies.

    Args:
        transform_obj: the transform class to wrap.
    """
    target = copy.deepcopy(transform_obj)

    if is_child_of(target, torch.nn.Module):
        if len(get_args(target.forward)) == 1:
            _reset_forward(target)
    elif hasattr(target, '__call__'):
        if len(get_args(target.__call__)) == 1:
            _reset_call(target)
|
|
|
|
|
|
# torchvision transforms that are NOT auto-registered.
# NOTE(review): presumably skipped because they compose other transforms,
# which would then receive result dicts instead of images; confirm.
skip_list = ['Compose', 'RandomApply']

# Namespace that receives the generated wrapper classes. At module scope
# ``locals()`` is this same dict, but the docs say the dict returned by
# ``locals()`` should not be modified, so name ``globals()`` explicitly.
_transforms_names = globals()

# Register every class exported by torchvision.transforms. For each one, the
# loop creates a same-named subclass in this module, patches that subclass to
# work on result dicts, and adds it to the PIPELINES registry.
for member in inspect.getmembers(_transforms, inspect.isclass):
    obj_name, obj = member
    # Enum classes are skipped: an Enum that defines members cannot be
    # subclassed.
    if obj_name in skip_list or isinstance(obj, EnumMeta):
        continue
    # Subclass instead of patching torchvision's class directly: wrapping
    # mutates the class it is given, and torchvision's own class must stay
    # untouched.
    _wrapped_cls = type(obj_name, (obj, ), dict())
    _transforms_names[obj_name] = _wrapped_cls
    wrap_torchvision_transforms(_wrapped_cls)
    PIPELINES.register_module(_wrapped_cls)
|