mirror of https://github.com/alibaba/EasyCV.git
61 lines
1.8 KiB
Python
61 lines
1.8 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import argparse
|
|
import logging
|
|
|
|
from easycv.models import build_model
|
|
from easycv.utils.config_tools import mmcv_config_fromfile
|
|
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the parameter-counting script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, which matches the previous no-argument call.

    Returns:
        argparse.Namespace with two attributes: ``config`` (the positional
        config file path) and ``backbone_key`` (name of the model attribute
        holding the backbone, ``'backbone'`` by default).
    """
    # The description used to read 'Train a model', which did not match what
    # this script does (it only builds the model and reports its size).
    parser = argparse.ArgumentParser(
        description='Count the parameters of a model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--backbone_key',
        type=str,
        default='backbone',
        help='backbone key name of model')
    return parser.parse_args(argv)
|
|
|
|
|
|
def main():
    """Build the model named by the config and print its parameter counts.

    All figures are printed in millions. When the model has an attribute
    named by ``--backbone_key``, the backbone's parameter count and its
    count of parameters with ``requires_grad`` set are printed first;
    otherwise a warning is logged. The same two figures for the whole model
    are printed last in either case.
    """
    million = 1e6

    def count_in_millions(module):
        # Returns (all parameters, parameters requiring grad), in millions.
        params = list(module.parameters())
        everything = sum(p.numel() for p in params) / million
        trainable = sum(
            p.numel() for p in params if p.requires_grad) / million
        return everything, trainable

    args = parse_args()
    cfg = mmcv_config_fromfile(args.config)

    # dynamically adapt mmdet models
    dynamic_adapt_for_mmlab(cfg)

    model = build_model(cfg.model)
    total, total_trainable = count_in_millions(model)

    if not hasattr(model, args.backbone_key):
        logging.warning(
            f'The backbone with the key ``{args.backbone_key}`` was not found in the model !'
        )
    else:
        backbone = getattr(model, args.backbone_key)
        backbone_total, backbone_trainable = count_in_millions(backbone)
        print('Number of backbone parameters: {:.5g} M'.format(
            backbone_total))
        print('Number of backbone parameters requiring grad: {:.5g} M'.format(
            backbone_trainable))

    print('Number of total parameters: {:.5g} M'.format(total))
    print('Number of total parameters requiring grad: {:.5g} M'.format(
        total_trainable))
|
|
|
|
|
|
# Run the parameter report only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|