Mirror of https://github.com/open-mmlab/mmclassification.git, synced 2025-06-03 21:53:55 +08:00
* Network construction completed; inference runs correctly
* Network construction completed; inference runs correctly
* Network construction completed; inference runs correctly
* Added model conversion (unverified) and config files, but they cannot run yet
* Model conversion and structure verification completed; inference produces correct results
* Inference accuracy matches the original paper; conversion completed
* Refactored three methods into classes (work in progress)
* Completed inference accuracy alignment, error 0.04
* Temporarily using levit2mmcls
* Training runs end to end, but training-related parameters are not yet aligned
* Aligned training-related parameters
* Fixed the issue where validation during training changed the model structure and it could not be restored
* Fixed the issue where validation during training changed the model structure and it could not be restored
* Added mixup and label smoothing
* Completed the config files
* Added model conversion
* Added meta file
* Added meta file
* Removed the demo.py test file
* Added the model README file
* Reverted the docs files
* Removed the trailing whitespace on the last line of model-index
* Updated the model metafile
* Updated the metafile
* Updated the metafile
* Updated the README and metafile
* Updated the model README
* Updated the model metafile
* Delete the model class and get_LeViT_model methods in the mmcls.models.backone.levit file
* Change the class name to Google Code Style
* Use arch to provide default architectures
* Use nn.Conv2d
* Use mmcv.cnn.fuse_conv_bn
* Modify some details
* Remove down_ops from the architectures
* Remove init_weight function
* Modify ambiguous variable names
* Change the drop_path in config to drop_path_rate
* Add unit test
* Remove train function
* Add unit test
* Modify nn.norm1d to build_norm_layer
* Update metafile and readme
* Update configs and LeViT implementations
* Update README
* Add docstring and update unit tests
* Revert irrelevant modifications
* Fix unit tests
* Minor fix

Co-authored-by: mzr1996 <mzr1996@163.com>
115 lines
4.3 KiB
Python
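The script below validates `metafile.yml` entries against the `modelindex` schema. For orientation, here is a minimal sketch of the metafile structure those checks expect; it is an illustration only, and the collection name, paths, numbers, and URLs are hypothetical placeholders, not values from the repository.

# Hypothetical metafile content illustrating the fields checked by the
# script below. All names, paths, and numbers are placeholders.
import yaml

example_metafile = yaml.safe_load("""
Collections:
  - Name: ExampleNet                      # required by check_collection
    README: configs/examplenet/README.md  # must exist under the repo root
    Paper:
      Title: ExampleNet, A Placeholder Paper Title
      URL: https://example.com/paper      # URL is optional, Title is not
Models:
  - Name: examplenet-small_3rdparty_in1k  # '3rdparty' expected for converted weights
    In Collection: ExampleNet
    Config: configs/examplenet/examplenet-small.py  # must exist under the repo root
    Metadata:
      Parameters: 10000000                # placeholder
      FLOPs: 2000000000                   # placeholder
    Results:
      - Dataset: ImageNet-1k              # a string, not a dict
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 0.0             # placeholder
    Weights: https://example.com/examplenet-small_3rdparty_in1k_20220101.pth
""")
print(example_metafile['Collections'][0]['Name'])  # -> ExampleNet

The checker script itself follows.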
import argparse
from pathlib import Path

import yaml
from modelindex.load_model_index import load
from modelindex.models.Collection import Collection
from modelindex.models.Model import Model
from modelindex.models.ModelIndex import ModelIndex

prog_description = """\
Check the format of metafile.
"""

MMCLS_ROOT = Path(__file__).absolute().parents[1]


def parse_args():
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument(
        'metafile', type=Path, nargs='+', help='The path of the metafile.')
    parser.add_argument(
        '--Wall',
        '-w',
        action='store_true',
        help='Whether to enable all warnings.')
    args = parser.parse_args()
    return args


def check_collection(modelindex: ModelIndex):
    if len(modelindex.collections) != 1:
        return 'One metafile should have only one collection.'
    collection: Collection = modelindex.collections[0]
    if collection.name is None:
        return 'The collection should have `Name` field.'
    if collection.readme is None:
        return 'The collection should have `README` field.'
    if not (MMCLS_ROOT / collection.readme).exists():
        return f'The README {collection.readme} is not found.'
    if not isinstance(collection.paper, dict):
        return ('The collection should have `Paper` field with '
                '`Title` and `URL`.')
    if 'Title' not in collection.paper:
        # URL is not necessary.
        return "The collection's `Paper` field should have `Title`."


def check_model(model: Model, wall=True):
    if model.name is None:
        return "A model doesn't have `Name` field."
    if model.metadata is None:
        return f'{model.name}: No `Metadata` field.'
    if model.metadata.parameters is None or model.metadata.flops is None:
        return (
            f'{model.name}: Metadata should have `Parameters` and '
            '`FLOPs` fields. You can use `tools/analysis_tools/get_flops.py` '
            'to calculate them.')
    if model.results is not None:
        result = model.results[0]
        if not isinstance(result.dataset, str):
            return (
                f'{model.name}: Dataset field of Results should be a string. '
                'If you want to specify the training dataset, please use '
                '`Metadata.Training Data` field.')
    if model.config is None:
        return f'{model.name}: No `Config` field.'
    if not (MMCLS_ROOT / model.config).exists():
        return f'{model.name}: The config {model.config} is not found.'
    if model.in_collection is None:
        return f'{model.name}: No `In Collection` field.'

    if wall and model.data.get(
            'Converted From') is not None and '3rdparty' not in model.name:
        print(f'WARN: The model name {model.name} should include '
              "'3rdparty' since it's converted from another repository.")
    if wall and model.weights is not None and model.weights.endswith('.pth'):
        basename = model.weights.rsplit('/', 1)[-1]
        if not basename.startswith(model.name):
            print(f'WARN: The checkpoint name {basename} is not the '
                  f'same as the model name {model.name}.')


def main(metafile: Path, args):
    if metafile.name != 'metafile.yml':
        # Avoid checking other yaml files.
        return
    elif metafile.samefile(MMCLS_ROOT / 'model-index.yml'):
        return

    with open(MMCLS_ROOT / 'model-index.yml', 'r') as f:
        metafile_list = yaml.load(f, yaml.Loader)['Import']
        if not any(
                metafile.samefile(MMCLS_ROOT / file)
                for file in metafile_list):
            raise ValueError(f'The metafile {metafile} is not imported in '
                             'the `model-index.yml`.')

    modelindex = load(str(metafile))
    modelindex.build_models_with_collections()
    collection_err = check_collection(modelindex)
    if collection_err is not None:
        raise ValueError(f'The `Collections` in the {metafile} is wrong:'
                         f'\n\t{collection_err}')
    for model in modelindex.models:
        model_err = check_model(model, args.Wall)
        if model_err is not None:
            raise ValueError(
                f'The `Models` in the {metafile} is wrong:\n\t{model_err}')


if __name__ == '__main__':
    args = parse_args()
    for metafile in args.metafile:
        main(metafile, args)
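A hedged usage sketch: the script is normally invoked from the command line with one or more metafile paths, but it can also be driven programmatically. The module name `check_metafile`, the script location, and the `configs/levit/metafile.yml` path are assumptions; adjust them to where this file actually lives in the repository, and run from the repository root so the relative path resolves.

# Hypothetical programmatic use. Assumes this script is importable as
# `check_metafile` and sits one directory below the repository root, so
# that MMCLS_ROOT resolves correctly; both assumptions, not facts from
# this page.
import argparse
from pathlib import Path

from check_metafile import main  # module name is an assumption

# parse_args() reads sys.argv, so build an equivalent namespace by hand.
args = argparse.Namespace(
    metafile=[Path('configs/levit/metafile.yml')],  # path is an assumption
    Wall=True,  # also print the '3rdparty' and checkpoint-name warnings
)
for metafile in args.metafile:
    main(metafile, args)  # raises ValueError on the first malformed entry

The equivalent command-line form would be `python <path-to-this-script> configs/levit/metafile.yml --Wall`, again run from the repository root.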