22 lines
567 B
Python
22 lines
567 B
Python
|
# Registry of model entrypoints, keyed by model name.
# Populated by register_model(); queried by model_entrypoints() and is_model().
_model_entrypoints = dict()


def build_model(config, **kwargs):
    """Build a model by dispatching to its registered entrypoint.

    Args:
        config: Nested mapping; the model name is read from
            ``config['MODEL']['NAME']``. The whole ``config`` is also passed
            through to the entrypoint.
        **kwargs: Extra keyword arguments forwarded unchanged to the
            entrypoint.

    Returns:
        Whatever the registered entrypoint returns when called as
        ``entrypoint(config, **kwargs)``.

    Raises:
        ValueError: If no entrypoint is registered under the configured name.
    """
    model_name = config['MODEL']['NAME']
    if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')
    return model_entrypoints(model_name)(config, **kwargs)


def register_model(fn):
    """Register ``fn`` as a model entrypoint and return it unchanged.

    The registry key is the last dotted component of ``fn.__module__``
    (a function defined in ``pkg.models.foo`` is stored under ``"foo"``).
    Registering another function from a module with the same final
    component replaces the earlier entry without warning.

    Because ``fn`` is returned as-is, this can be applied as
    ``@register_model``.
    """
    short_name = fn.__module__.rsplit('.', 1)[-1]
    _model_entrypoints[short_name] = fn
    return fn


def model_entrypoints(model_name):
    """Return the entrypoint callable registered under ``model_name``.

    Raises:
        KeyError: If nothing is registered under ``model_name``
            (check with ``is_model`` first to avoid this).
    """
    entrypoint = _model_entrypoints[model_name]
    return entrypoint


def is_model(model_name):
    """Return True if an entrypoint is registered under ``model_name``."""
    registered = model_name in _model_entrypoints
    return registered