53 lines
1.6 KiB
Python
53 lines
1.6 KiB
Python
# ------------------------------------------------------------------------
|
|
# Copyright (c) 2022 megvii-model. All Rights Reserved.
|
|
# ------------------------------------------------------------------------
|
|
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
|
|
# Copyright 2018-2020 BasicSR Authors
|
|
# ------------------------------------------------------------------------
|
|
import importlib
|
|
from os import path as osp
|
|
|
|
from basicsr.utils import scandir
|
|
|
|
# Discover architecture modules automatically: every file found under this
# 'archs' folder whose path ends with '_arch.py' is collected by its bare
# module name (directory and '.py' extension stripped) ...
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [
    osp.splitext(osp.basename(arch_path))[0]
    for arch_path in scandir(arch_folder)
    if arch_path.endswith('_arch.py')
]

# ... and then imported, in discovery order, so that ``define_network`` can
# look up architecture classes by name across all of them.
_arch_modules = [
    importlib.import_module(f'basicsr.models.archs.{arch_name}')
    for arch_name in arch_filenames
]
|
|
|
|
|
|
def dynamic_instantiation(modules, cls_type, opt):
    """Dynamically instantiate a class looked up by name in ``modules``.

    The modules are searched in order. The first one exposing a non-None
    attribute named ``cls_type`` provides the class, which is then called
    with ``opt`` as keyword arguments.

    Args:
        modules (list[importlib modules]): List of modules from importlib
            files.
        cls_type (str): Name of the class to look up.
        opt (dict): Class initialization kwargs.

    Returns:
        object: The instance created by calling the found class with
            ``**opt``.

    Raises:
        ValueError: If none of ``modules`` defines ``cls_type``. This also
            covers an empty ``modules`` sequence.
    """
    # Bind before the loop: with an empty ``modules`` the loop body never
    # runs, and the check below would otherwise hit an unbound local.
    cls_ = None
    for module in modules:
        cls_ = getattr(module, cls_type, None)
        if cls_ is not None:
            break
    if cls_ is None:
        raise ValueError(f'{cls_type} is not found.')
    return cls_(**opt)
|
|
|
|
|
|
def define_network(opt):
    """Build a network from an options dict.

    Args:
        opt (dict): Network options. ``opt['type']`` names the architecture
            class to look up in the imported arch modules. Every other
            key is passed to that class's constructor as a keyword argument.
            The caller's dict is not modified.

    Returns:
        The instantiated network.

    Raises:
        KeyError: If ``opt`` has no ``'type'`` key.
        ValueError: If no imported arch module defines ``opt['type']``.
    """
    # Pop from a shallow copy so the caller's options keep their 'type' key.
    # An in-place pop made the same dict fail with KeyError on reuse.
    opt = dict(opt)
    network_type = opt.pop('type')
    net = dynamic_instantiation(_arch_modules, network_type, opt)
    return net
|