65 lines
1.9 KiB
Python
65 lines
1.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import Any, Type
|
|
|
|
from mmpretrain.registry import MODELS
|
|
|
|
|
|
class ExtendModule:
    """Combine a base language model with an adapter.

    Instantiating this class does not return an ``ExtendModule``. It mixes
    the adapter class into ``base``'s class in place and then returns
    whatever the adapter's ``extend_init`` returns.

    Args:
        base (object): Base module. Either an already-built instance of a
            language model, or a dict that ``MODELS.build`` uses to build
            the base module.
        adapter (dict): Dict to build the adapter. Its ``'type'`` key is
            looked up in ``MODELS`` to get the adapter class; all remaining
            keys are forwarded as keyword arguments to ``extend_init``.
            The dict passed in is not modified.
    """

    def __new__(cls, base: object, adapter: dict):
        if isinstance(base, dict):
            base = MODELS.build(base)

        # Pop 'type' from a shallow copy rather than the caller's dict, so a
        # config reused for several builds still carries its 'type' key.
        adapter = dict(adapter)
        adapter_module = MODELS.get(adapter.pop('type'))
        # NOTE(review): if 'type' is not registered, adapter_module may be
        # None depending on the registry's ``get``; type() below would then
        # fail with a TypeError. Consider raising a clearer error here.
        cls.extend_instance(base, adapter_module)
        return adapter_module.extend_init(base, **adapter)

    @classmethod
    def extend_instance(cls, base: object, mixin: Type[Any]):
        """Apply a mixin to a class instance after creation.

        The instance's class is swapped for a freshly created subclass of
        ``(mixin, original_class)`` that keeps the original class name.

        Args:
            base (object): Base module instance; modified in place.
            mixin (Type[Any]): Adapter class to mix in.
        """
        base_cls = base.__class__
        # The mixin must come first in the bases so its methods (e.g. our
        # forward() logic) take precedence over the base class in the MRO.
        base.__class__ = type(base_cls.__name__, (mixin, base_cls), {})
|
|
|
|
|
|
def getattr_recursive(obj, att):
    """Resolve a dot-separated attribute path starting from ``obj``.

    ``getattr_recursive(obj, 'a.b.c')`` behaves like ``obj.a.b.c``; an empty
    path returns ``obj`` unchanged.

    Args:
        obj: Object the lookup starts from.
        att (str): Dot-separated attribute path.

    Returns:
        The object reached at the end of the path.

    Raises:
        AttributeError: If any segment of the path does not exist.
    """
    current = obj
    remaining = att
    while remaining != '':
        # Split off the leading segment; ``remaining`` becomes the text after
        # the first dot, or '' when no dot is left.
        head, _, remaining = remaining.partition('.')
        current = getattr(current, head)
    return current
|
|
|
|
|
|
def setattr_recursive(obj, att, val):
    """Assign ``val`` to a dot-separated attribute path on ``obj``.

    ``setattr_recursive(obj, 'a.b.c', val)`` behaves like
    ``obj.a.b.c = val``. Every segment before the last must already exist;
    it is resolved with ``getattr_recursive``.

    Args:
        obj: Object the path starts from.
        att (str): Dot-separated attribute path.
        val: Value to assign to the final attribute.
    """
    parent_path, dot, leaf = att.rpartition('.')
    # Only walk the path when there is a parent part (i.e. a dot was found).
    target = getattr_recursive(obj, parent_path) if dot else obj
    setattr(target, leaf, val)
|