mmrazor/tools/model_converters/convert_supernet2subnet.py

61 lines
1.7 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import time
import torch
from mmengine.config import Config
from mmengine.runner import Runner
from mmrazor.structures.subnet import load_fix_subnet
from mmrazor.utils import register_all_modules
def parse_args(argv=None):
    """Parse command-line arguments for the supernet-to-subnet converter.

    Args:
        argv (list[str], optional): Argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, matching
            the previous behaviour.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
        ``checkpoint``, ``yaml``, ``launcher`` and ``local_rank``.

    Side effects:
        Sets ``os.environ['LOCAL_RANK']`` from ``--local_rank`` if the
        variable is not already present in the environment.
    """
    parser = argparse.ArgumentParser(
        description='Process a NAS supernet checkpoint to be converted')
    parser.add_argument('config', help='NAS model config file path')
    parser.add_argument('checkpoint', help='supernet checkpoint file path')
    parser.add_argument('yaml', help='YAML with subnet settings file path')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # Newer torch.distributed.launch versions (PyTorch 2.0+) pass
    # ``--local-rank`` while older ones pass ``--local_rank``; accept both.
    # The first long option determines ``dest``, so it stays ``local_rank``.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    # Parse exactly once; re-parsing added nothing.
    args = parser.parse_args(argv)
    # Export the rank for code that reads it from the environment, without
    # overriding a value the launcher already set.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
def main():
    """Convert a NAS supernet checkpoint into a fixed-subnet checkpoint.

    Builds a runner from the config, loads the supernet weights from the
    checkpoint given on the command line, fixes the subnet described by the
    YAML file, and saves the resulting state dict as
    ``subnet_<YYYYmmdd_HHMM>.pth`` in the checkpoint's directory.
    """
    from mmengine.model import is_model_wrapper

    register_all_modules(False)
    args = parse_args()

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    cfg.load_from = args.checkpoint
    # Write the output next to the checkpoint. ``osp.dirname`` is portable
    # (unlike splitting on '/'), and ``abspath`` avoids an empty work_dir
    # when the checkpoint is given as a bare file name.
    cfg.work_dir = osp.dirname(osp.abspath(args.checkpoint))

    runner = Runner.from_cfg(cfg)
    # ``Runner.from_cfg`` only records ``load_from``; mmengine loads it
    # lazily from train/val/test, none of which runs here. Load the supernet
    # weights explicitly, before fixing the subnet, so the exported subnet
    # carries the trained weights rather than freshly initialised ones.
    runner.load_checkpoint(args.checkpoint, map_location='cpu')

    # Under a distributed launcher the runner wraps the model (e.g. DDP);
    # operate on the inner module so the saved keys have no 'module.' prefix.
    model = runner.model
    if is_model_wrapper(model):
        model = model.module

    load_fix_subnet(model, args.yaml)

    timestamp_subnet = time.strftime('%Y%m%d_%H%M', time.localtime())
    model_name = f'subnet_{timestamp_subnet}.pth'
    save_path = osp.join(runner.work_dir, model_name)
    torch.save({
        'state_dict': model.state_dict(),
        'meta': {}
    }, save_path)
    runner.logger.info(f'Successful converted. Saved in {save_path}.')
if __name__ == '__main__':
    # Run the conversion only when executed as a script, not when imported.
    main()