EasyCV/easycv/datasets/shared/pipelines/third_transforms_wrapper.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import inspect
from enum import EnumMeta

import torch
from torchvision import transforms as _transforms

from easycv.datasets.registry import PIPELINES


def is_child_of(obj, cls):
    """Recursively check whether the class ``obj`` derives from ``cls``."""
    try:
        for i in obj.__bases__:
            if i is cls or isinstance(i, cls):
                return True
        for i in obj.__bases__:
            if is_child_of(i, cls):
                return True
    except AttributeError:
        # ``obj`` is an instance rather than a class; check its class instead
        return is_child_of(obj.__class__, cls)
    return False


def get_args(obj):
    """Return the argument names of ``obj``, excluding ``self``/``cls``."""
    full_args_spec = inspect.getfullargspec(obj)
    args = [] if not full_args_spec.args else full_args_spec.args
    if args and args[0] in ['self', 'cls']:
        args.pop(0)
    return args


def _reset_forward(obj):
    """Rewrite ``forward`` so it consumes and returns the results dict."""
    original_forward = obj.forward

    def _new_forward(self, results):
        img = results['img']
        img = original_forward(self, img)
        results['img'] = img
        return results

    setattr(obj, 'forward', _new_forward)


def _reset_call(obj):
    """Rewrite ``__call__`` so it consumes and returns the results dict."""
    original_call = obj.__call__

    def _new_call(self, results):
        img = results['img']
        img = original_call(self, img)
        results['img'] = img
        return results

    setattr(obj, '__call__', _new_call)
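

# A minimal illustration of the adapted calling convention (the 'img' key is
# the one the EasyCV pipelines pass around; this is a sketch, not executed
# here):
#
#   plain torchvision:  img = transform(img)
#   after wrapping:     results = transform({'img': img, ...})  # dict in, dict out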


# TODO: find a prettier way to wrap third-party transforms, or import a
# fixed api to wrap them
def wrap_torchvision_transforms(transform_obj):
    """Patch a torchvision transform class so it operates on the results dict."""
    transform_obj = copy.deepcopy(transform_obj)
    # args_format = ['img', 'pic']
    if is_child_of(transform_obj, torch.nn.Module):
        args = get_args(transform_obj.forward)
        if len(args) == 1:  # and args[0] in args_format:
            _reset_forward(transform_obj)
    elif hasattr(transform_obj, '__call__'):
        args = get_args(transform_obj.__call__)
        if len(args) == 1:  # and args[0] in args_format:
            _reset_call(transform_obj)
    else:
        pass


skip_list = ['Compose', 'RandomApply']
# at module level ``locals()`` is the module namespace, so the wrapped
# classes below also become importable under their torchvision names
_transforms_names = locals()

# register all existing transforms in torchvision
for member in inspect.getmembers(_transforms, inspect.isclass):
    obj_name, obj = member[0], member[1]
    if obj_name in skip_list:
        continue
    # skip enums such as InterpolationMode, they are not transforms
    if isinstance(obj, EnumMeta):
        continue
    # subclass with the same name, patch it to accept the results dict,
    # then register it in the PIPELINES registry
    _transforms_names[obj_name] = type(obj_name, (obj, ), dict())
    wrap_torchvision_transforms(_transforms_names[obj_name])
    PIPELINES.register_module(_transforms_names[obj_name])
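

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the original module).
# Each wrapped class registered above takes the EasyCV-style ``results`` dict
# and transforms ``results['img']``; the same classes are also registered in
# PIPELINES under their torchvision names, e.g. 'Resize'. The dummy tensor
# below is an assumption chosen only for this example and requires a
# torchvision version that accepts tensor inputs.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    resize = _transforms_names['Resize'](size=(224, 224))
    results = {'img': torch.rand(3, 256, 256)}  # dummy image tensor
    results = resize(results)
    print(results['img'].shape)  # torch.Size([3, 224, 224])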