# Copyright (c) Alibaba, Inc. and its affiliates.
"""Run the ImageNet linear-classification benchmark for a pretrained model.

Steps (see ``main``): copy the checkpoint into a work dir, save its
``backbone.*`` weights as ``backbone.pth``, extract features with
``tools/dist_extract.sh`` and train a linear classifier on them with
``tools/dist_train.sh``.
"""
import argparse
import functools
import os
import shutil
import subprocess
import sys
import time

import torch

parser = argparse.ArgumentParser(
    description='Linear-classification evaluation of a pretrained model.')
parser.add_argument(
    'model_path',
    type=str,
    help='linear eval model path',
    nargs='?',
    default='')
parser.add_argument(
    'work_dirname',
    type=str,
    help='evaluation work dir name',
    nargs='?',
    default='tmp_evaluation')
parser.add_argument(
    'work_dirroot',
    type=str,
    help='evaluation work dir root',
    nargs='?',
    default='work_dirs/benchmarks/linear_classification/imagenet')
parser.add_argument(
    'eval_config',
    type=str,
    help='linear evaluation config file',
    nargs='?',
    default='configs/benchmarks/linear_classification/imagenet/tmp_feature.py')

# Timestamped entries appended by ``timelog``; printed at the end of ``main``.
TIME_LOG = []


def _time_log_entry(label):
    """Format one ``TIME_LOG`` entry for ``label`` stamped with local time."""
    return 'time_log %s : %s' % (time.asctime(time.localtime(time.time())),
                                 label)


def timelog(func):
    """Record timestamped entries in ``TIME_LOG``.

    Used as a decorator: each call of the wrapped function first appends
    (and prints) an entry labelled with the function's name, then runs it.
    Called with a string: appends one entry with that label (not printed)
    and returns None.
    """
    if isinstance(func, str):
        TIME_LOG.append(_time_log_entry(func))
        return None

    @functools.wraps(func)  # keep __name__/__doc__ of the wrapped function
    def wrapper(*args, **kwargs):
        time_log = _time_log_entry(func.__name__)
        TIME_LOG.append(time_log)
        print(time_log)
        return func(*args, **kwargs)

    return wrapper


@timelog
def move_file(model_path, work_dir):
    """Copy the checkpoint ``model_path`` into ``work_dir``.

    Args:
        model_path: path of a ``.pth`` checkpoint; paths starting with
            ``oss`` are fetched with the ``ossutil64`` CLI, anything else
            is copied with ``shutil.copy``.
        work_dir: destination directory, created if missing.

    Returns:
        Path of the copy inside ``work_dir``. An existing file of the same
        name is reused without copying.

    Raises:
        SystemExit: with status 1 if ``model_path`` does not end in ``.pth``.
        subprocess.CalledProcessError: if the ``ossutil64`` download fails.
    """
    if not model_path.endswith('.pth'):
        print('model path is Invalid!')
        # Non-zero status so callers see the failure (bare exit() reported 0).
        sys.exit(1)
    os.makedirs(work_dir, exist_ok=True)
    filename = os.path.join(work_dir, model_path.split('/')[-1])
    if os.path.exists(filename):
        return filename
    if model_path.startswith('oss'):
        # Argument list: paths are not re-parsed by a shell; check=True makes
        # a failed download stop the pipeline instead of failing later.
        subprocess.run(['ossutil64', 'cp', '-f', model_path, work_dir],
                       check=True)
    else:
        shutil.copy(model_path, work_dir)
    return filename


@timelog
def extract_model(model_path):
    """Save the backbone weights of a checkpoint as ``backbone.pth``.

    Keeps only the entries of ``ck['state_dict']`` whose key starts with
    ``backbone.``, strips that prefix, and writes
    ``{'state_dict': {...}}`` to ``backbone.pth`` in the same directory as
    ``model_path``.

    Returns:
        Path of the written ``backbone.pth``.

    Raises:
        Exception: if the checkpoint has no ``backbone.`` entries.
    """
    prefix = 'backbone.'
    # dirname keeps relative paths relative (the old join('/', ...) forced
    # them under the filesystem root).
    backbone_file = os.path.join(os.path.dirname(model_path), 'backbone.pth')
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ck = torch.load(model_path, map_location=torch.device('cpu'))
    state_dict = {
        key[len(prefix):]: value
        for key, value in ck['state_dict'].items()
        if key.startswith(prefix)
    }
    if not state_dict:
        raise Exception('Cannot find a backbone module in the checkpoint.')
    torch.save(dict(state_dict=state_dict), backbone_file)
    return backbone_file


@timelog
def extract_feature(project_path, work_dir, backbone_file):
    """Run ``tools/dist_extract.sh`` for ``backbone_file`` into ``work_dir``.

    Changes the process working directory to ``project_path`` (the tool and
    config paths are relative to it) and runs the extraction with
    ``PORT=29513``, ``--layer-ind=4`` and the ImageNet extract configs. The
    ``8`` argument is presumably the GPU count -- confirm against the script.

    Raises:
        subprocess.CalledProcessError: if the extraction script fails.
    """
    os.chdir(project_path)
    subprocess.run(
        [
            'bash', 'tools/dist_extract.sh',
            'configs/classification/imagenet/r50_extract.py', '8', work_dir,
            '--pretrained=%s' % backbone_file, '--layer-ind=4',
            '--dataset-config', 'benchmarks/extract_info/imagenet.py'
        ],
        env=dict(os.environ, PORT='29513'),
        check=True)


def modify_config_file(config_file, keywords):
    """Rewrite ``name = value`` assignments of ``config_file`` in place.

    For each line containing ``<k>=`` or ``<k> =`` (for a key ``k`` of
    ``keywords``) and no ``#``, the text from the key onward is replaced by
    ``<k>=<quoted value>``; anything before the key (e.g. indentation) is
    kept. Values are written as Python string literals. Other lines are
    copied unchanged.
    """
    with open(config_file, encoding='utf-8') as f:
        lines = f.readlines()
    new_lines = []
    for line in lines:
        for k, v in keywords.items():
            if ('%s=' % k in line or '%s =' % k in line) and '#' not in line:
                idx = max(line.find('%s =' % k), line.find('%s=' % k))
                # repr() yields a valid literal even for quotes/backslashes.
                line = line[:idx] + '%s=%s\n' % (k, repr(str(v)))
        new_lines.append(line)
    with open(config_file, 'w', encoding='utf-8') as f:
        f.write(''.join(new_lines))


@timelog
def linear_eval(project_path, feature_path, config_file):
    """Train the linear classifier on the features in ``feature_path``.

    Changes the working directory to ``project_path``, edits ``config_file``
    in place so ``data_root_path`` points at ``feature_path``, then runs
    ``tools/dist_train.sh <config_file> 8``.

    Raises:
        subprocess.CalledProcessError: if training fails.
    """
    os.chdir(project_path)
    modify_config_file(config_file, {'data_root_path': feature_path})
    subprocess.run(['sh', 'tools/dist_train.sh', config_file, '8'],
                   check=True)


def main():
    """Parse the CLI arguments and run the whole evaluation pipeline."""
    project_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    opts = parser.parse_args()
    model_path = opts.model_path
    work_dirname = opts.work_dirname
    work_dirroot = os.path.join(project_path, opts.work_dirroot)
    work_dir = os.path.join(work_dirroot, work_dirname)
    print('model_path : %s' % model_path)
    print('work_dirname : %s' % work_dirname)
    print('work_dir : %s' % work_dir)
    try:
        model_path = move_file(model_path, work_dir)
        backbone_file = extract_model(model_path)
        extract_feature(project_path, work_dir, backbone_file)
        linear_eval(project_path, os.path.join(work_dir, 'features'),
                    opts.eval_config)
    finally:
        # Print the timing log even when a step fails.
        timelog('end')
        for entry in TIME_LOG:
            print(entry)


if __name__ == '__main__':
    main()