mirror of https://github.com/YifanXu74/MQ-Det.git
358 lines
13 KiB
Python
358 lines
13 KiB
Python
|
|
import time
|
|
import pickle
|
|
import logging
|
|
import os
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
from collections import OrderedDict
|
|
from yaml import safe_dump
|
|
from yacs.config import load_cfg, CfgNode#, _to_dict
|
|
from maskrcnn_benchmark.config import cfg
|
|
from maskrcnn_benchmark.engine.inference import _accumulate_predictions_from_multiple_gpus
|
|
from maskrcnn_benchmark.modeling.backbone.nas import get_layer_name
|
|
from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process, get_world_size, all_gather
|
|
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
|
|
from maskrcnn_benchmark.utils.flops import profile
|
|
|
|
|
|
def choice(x):
    """Return one randomly picked element of ``x``.

    A tuple is indexed directly with an index drawn from
    ``np.random.randint(len(x))``; any other iterable is first converted to
    a tuple and then sampled the same way.
    """
    if not isinstance(x, tuple):
        return choice(tuple(x))
    return x[np.random.randint(len(x))]
|
|
|
|
|
|
def gather_candidates(all_candidates):
    """Collect candidates through ``all_gather`` and drop duplicates.

    Args:
        all_candidates: iterable of hashable candidates held by the caller.

    Returns:
        A list containing every distinct candidate found in the gathered
        collections. The list order is whatever the intermediate ``set``
        yields.
    """
    # presumably all_gather returns one collection per process -- verify
    # against maskrcnn_benchmark.utils.comm
    unique = set()
    for per_process in all_gather(all_candidates):
        for cand in per_process:
            unique.add(cand)
    return list(unique)
|
|
|
|
|
|
def gather_stats(all_candidates):
    """Merge the statistics mappings returned by ``all_gather`` into one dict.

    Args:
        all_candidates: mapping of statistics held by the caller.

    Returns:
        A single dict holding the union of all gathered entries.
    """
    merged = {}
    gathered = all_gather(all_candidates)
    for per_process in gathered:
        # When the same key shows up more than once, the entry gathered
        # last overwrites the earlier ones.
        merged.update(per_process)
    return merged
|
|
|
|
|
|
def compute_on_dataset(model, rngs, data_loader, device=cfg.MODEL.DEVICE):
    """Run ``model`` in eval mode over ``data_loader`` and collect its outputs.

    Args:
        model: callable invoked as ``model(images, rngs=rngs)``.
        rngs: forwarded unchanged to the model as the ``rngs`` keyword.
        data_loader: iterable yielding ``(images, targets, image_ids)``.
        device: device the images are moved to before the forward pass.

    Returns:
        Dict mapping each image id to its model output moved to the CPU.
    """
    model.eval()
    cpu = torch.device("cpu")
    predictions = {}
    for images, targets, image_ids in data_loader:
        with torch.no_grad():
            outputs = model(images.to(device), rngs=rngs)
            outputs = [out.to(cpu) for out in outputs]
        # Pair every output with its image id; a repeated id keeps the
        # output seen last.
        for img_id, out in zip(image_ids, outputs):
            predictions[img_id] = out
    return predictions
|
|
|
|
|
|
def bn_statistic(model, rngs, data_loader, device=cfg.MODEL.DEVICE, max_iter=500):
    """Reset running statistics and refresh them with forward passes.

    Every buffer whose name contains ``running_mean`` is filled with 0 and
    every buffer whose name contains ``running_var`` is filled with 1. The
    model is then switched to train mode and fed batches from
    ``data_loader`` under ``torch.no_grad()``.

    Args:
        model: module invoked as ``model(images, targets, rngs)``.
        rngs: forwarded unchanged as the third positional model argument.
        data_loader: iterable yielding ``(images, targets, _)``.
        device: device the images and targets are moved to.
        max_iter: iteration count after which the loop stops.

    Returns:
        The same ``model`` object, left in train mode.
    """
    for buf_name, buf in model.named_buffers():
        if 'running_mean' in buf_name:
            nn.init.constant_(buf, 0)
        if 'running_var' in buf_name:
            nn.init.constant_(buf, 1)

    model.train()
    for step, batch in enumerate(data_loader, 1):
        images, targets, _ = batch
        images = images.to(device)
        targets = [t.to(device) for t in targets]
        with torch.no_grad():
            # The returned value is not used; the forward pass is run for
            # its effect on the model only.
            model(images, targets, rngs)
        # Checked after the forward pass, so at least one batch is always run.
        if step >= max_iter:
            break

    return model
|
|
|
|
|
|
def inference(
        model,
        rngs,
        data_loader,
        iou_types=("bbox",),
        box_only=False,
        device="cuda",
        expected_results=(),
        expected_results_sigma_tol=4,
        output_folder=None,
):
    """Compute predictions for ``data_loader`` and evaluate them.

    Args:
        model: model forwarded to ``compute_on_dataset``.
        rngs: forwarded to ``compute_on_dataset``.
        data_loader: loader whose ``dataset`` attribute is handed to
            ``evaluate``.
        iou_types, box_only, expected_results, expected_results_sigma_tol,
        output_folder: passed through to ``evaluate`` as keyword arguments.
        device: anything accepted by ``torch.device``.

    Returns:
        The value returned by ``evaluate`` on the main process; ``None`` on
        every other process.
    """
    # Build the torch.device once up front.
    device = torch.device(device)
    dataset = data_loader.dataset

    predictions = compute_on_dataset(model, rngs, data_loader, device)
    # Let every process finish before predictions are accumulated.
    synchronize()
    predictions = _accumulate_predictions_from_multiple_gpus(predictions)

    if not is_main_process():
        return None

    return evaluate(
        dataset=dataset,
        predictions=predictions,
        output_folder=output_folder,
        box_only=box_only,
        iou_types=iou_types,
        expected_results=expected_results,
        expected_results_sigma_tol=expected_results_sigma_tol,
    )
|
|
|
|
|
|
def fitness(cfg, model, rngs, val_loaders):
    """Evaluate ``model`` with the paths ``rngs`` on every validation loader.

    Args:
        cfg: config providing ``MODEL.MASK_ON``, ``MODEL.DEVICE``,
            ``TEST.EXPECTED_RESULTS`` and ``TEST.EXPECTED_RESULTS_SIGMA_TOL``.
        model: model forwarded to ``inference``.
        rngs: forwarded to ``inference``.
        val_loaders: iterable of validation data loaders.

    Returns:
        The value returned by ``inference`` for the last loader, or ``None``
        when ``val_loaders`` is empty (previously this case raised
        ``UnboundLocalError``).
    """
    iou_types = ("bbox", "segm") if cfg.MODEL.MASK_ON else ("bbox",)

    # Initialised so that an empty ``val_loaders`` yields None instead of
    # hitting an unbound local at the return statement.
    results = None
    for data_loader_val in val_loaders:
        results = inference(
            model,
            rngs,
            data_loader_val,
            iou_types=iou_types,
            box_only=False,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
        )
        synchronize()

    # Only the result of the last loader is kept.
    return results
|
|
|
|
|
|
class EvolutionTrainer(object):
    """Evolutionary search over per-layer path choices of a supernet.

    A candidate is a tuple with one integer per entry of ``self.states``;
    entry ``i`` is drawn with ``np.random.randint(self.states[i])``.
    Candidates are scored by ``eval_candidates`` (bbox AP), the best ones are
    kept in ``self.keep_top_k`` and new candidates are produced by mutation,
    crossover and random sampling.
    """

    def __init__(self, cfg, model, flops_limit=None, is_distributed=True):
        """Store the search settings and snapshot the supernet weights.

        Args:
            cfg: config providing ``OUTPUT_DIR`` and the ``SEARCH.*`` options.
            model: supernet; ``mix_nums`` is read from ``model.module`` when
                ``is_distributed`` is true, otherwise from ``model`` itself.
            flops_limit: optional upper bound (in units of 1e6 FLOPs as
                reported by ``profile``) enforced by ``legal``.
            is_distributed: whether ``model`` is wrapped (has ``.module``).
        """
        self.log_dir = cfg.OUTPUT_DIR
        self.checkpoint_name = os.path.join(self.log_dir, 'evolution.pth')
        self.is_distributed = is_distributed

        self.states = model.module.mix_nums if is_distributed else model.mix_nums
        # Pickle round trip gives an independent copy of the weights that is
        # reloaded before every candidate evaluation.
        self.supernet_state_dict = pickle.loads(pickle.dumps(model.state_dict()))
        self.flops_limit = flops_limit
        self.model = model

        self.candidates = []
        self.vis_dict = {}

        world_size = get_world_size()
        self.max_epochs = cfg.SEARCH.MAX_EPOCH
        self.select_num = cfg.SEARCH.SELECT_NUM
        # True division: these per-process budgets may be fractional. They
        # are only used in ``<`` comparisons and arithmetic, so that works.
        self.population_num = cfg.SEARCH.POPULATION_NUM / world_size
        self.mutation_num = cfg.SEARCH.MUTATION_NUM / world_size
        self.crossover_num = cfg.SEARCH.CROSSOVER_NUM / world_size
        # NOTE(review): the mutation probability is also divided by the world
        # size, which looks unintended for a probability -- confirm. Kept
        # as-is to preserve existing search behaviour.
        self.mutation_prob = cfg.SEARCH.MUTATION_PROB / world_size

        self.keep_top_k = {self.select_num: [], 50: []}
        self.epoch = 0
        self.cfg = cfg

    @staticmethod
    def _to_plain_dict(node):
        """Recursively convert nested mappings into built-in ``dict`` objects."""
        if isinstance(node, dict):
            return {key: EvolutionTrainer._to_plain_dict(val) for key, val in node.items()}
        return node

    def save_checkpoint(self):
        """Write the search state to ``self.checkpoint_name`` (main process only)."""
        if not is_main_process():
            return
        os.makedirs(self.log_dir, exist_ok=True)
        info = {
            'candidates': self.candidates,
            'vis_dict': self.vis_dict,
            'keep_top_k': self.keep_top_k,
            'epoch': self.epoch,
        }
        torch.save(info, self.checkpoint_name)
        print('Save checkpoint to', self.checkpoint_name)

    def load_checkpoint(self):
        """Restore the search state if a checkpoint exists.

        Returns:
            ``True`` when a checkpoint was loaded, ``False`` otherwise.
        """
        if not os.path.exists(self.checkpoint_name):
            return False
        info = torch.load(self.checkpoint_name)
        self.candidates = info['candidates']
        self.vis_dict = info['vis_dict']
        self.keep_top_k = info['keep_top_k']
        self.epoch = info['epoch']
        print('Load checkpoint from', self.checkpoint_name)
        return True

    def legal(self, cand):
        """Return whether ``cand`` is new and within the FLOPs budget.

        A candidate already present in ``self.vis_dict`` is rejected. When
        ``self.flops_limit`` is set, the backbone is profiled with the
        candidate's paths and rejected if it exceeds the limit.
        """
        assert isinstance(cand, tuple) and len(cand) == len(self.states)
        if cand in self.vis_dict:
            return False

        if self.flops_limit is not None:
            net = self.model.module.backbone if self.is_distributed else self.model.backbone
            inp = (1, 3, 224, 224)
            flops, _ = profile(net, inp, extra_args={'paths': list(cand)})
            flops = flops / 1e6
            print('flops:', flops)
            if flops > self.flops_limit:
                return False

        return True

    def update_top_k(self, candidates, *, k, key, reverse=False):
        """Merge ``candidates`` into ``self.keep_top_k[k]`` and keep the best ``k``.

        Args:
            candidates: iterable of hashable candidates to merge in.
            k: which top-k list to update; must be a key of ``keep_top_k``.
            key: sort key; entries with the smallest key are kept.
            reverse: forwarded to ``list.sort``.
        """
        assert k in self.keep_top_k
        # Drop duplicates while keeping first-seen order. Without this, a
        # ``select_num`` of 50 makes both calls in ``train`` target the same
        # list and every candidate is inserted twice.
        pool = list(dict.fromkeys(self.keep_top_k[k] + list(candidates)))
        pool.sort(key=key, reverse=reverse)
        self.keep_top_k[k] = pool[:k]

    def eval_candidates(self, train_loader, val_loader):
        """Score every candidate in ``self.candidates``.

        For each candidate the supernet weights are restored, the running
        statistics are recomputed with ``bn_statistic`` and the bbox AP from
        ``fitness`` is stored in ``self.vis_dict`` on the main process.
        """
        for cand in self.candidates:
            t0 = time.time()

            # load back supernet state dict
            self.model.load_state_dict(self.supernet_state_dict)
            # bn_statistic
            model = bn_statistic(self.model, list(cand), train_loader)
            # fitness -- uses the config this trainer was built with rather
            # than the module-level global.
            evals = fitness(self.cfg, model, list(cand), val_loader)

            if is_main_process():
                acc = evals[0].results['bbox']['AP']
                self.vis_dict[cand] = acc
                print('candiate ', cand)
                print('time: {}s'.format(time.time() - t0))
                print('acc ', acc)

    def stack_random_cand(self, random_func, *, batchsize=10):
        """Yield candidates forever, generating them ``batchsize`` at a time."""
        while True:
            cands = [random_func() for _ in range(batchsize)]
            for cand in cands:
                yield cand

    def random_can(self, num):
        """Return legal random candidates until ``len(result) >= num``.

        NOTE(review): this loop has no attempt cap; it does not terminate if
        no legal candidate can be drawn any more.
        """
        candidates = []
        cand_iter = self.stack_random_cand(
            lambda: tuple(np.random.randint(i) for i in self.states))
        while len(candidates) < num:
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            candidates.append(cand)
        return candidates

    def get_mutation(self, k, mutation_num, m_prob):
        """Create candidates by mutating members of ``self.keep_top_k[k]``.

        Each position of a randomly chosen parent is resampled with
        probability ``m_prob``. At most ``10 * mutation_num`` candidates are
        drawn, so fewer than ``mutation_num`` results may be returned.
        """
        assert k in self.keep_top_k
        res = []
        max_iters = mutation_num * 10

        def random_func():
            cand = list(choice(self.keep_top_k[k]))
            for i, num_choices in enumerate(self.states):
                if np.random.random_sample() < m_prob:
                    cand[i] = np.random.randint(num_choices)
            return tuple(cand)

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < mutation_num and max_iters > 0:
            # Count every attempt, including rejected ones, so the loop is
            # guaranteed to stop even when no legal candidate is left.
            max_iters -= 1
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            res.append(cand)

        return res

    def get_crossover(self, k, crossover_num):
        """Create candidates by crossing two members of ``self.keep_top_k[k]``.

        Every position is taken at random from one of the two parents. At
        most ``10 * crossover_num`` candidates are drawn, so fewer than
        ``crossover_num`` results may be returned.
        """
        assert k in self.keep_top_k
        res = []
        max_iters = 10 * crossover_num

        def random_func():
            p1 = choice(self.keep_top_k[k])
            p2 = choice(self.keep_top_k[k])
            return tuple(choice([i, j]) for i, j in zip(p1, p2))

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < crossover_num and max_iters > 0:
            # Count every attempt, including rejected ones (see get_mutation).
            max_iters -= 1
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            res.append(cand)

        return res

    def train(self, train_loader, val_loader):
        """Run the evolutionary search until ``self.max_epochs`` is reached.

        Resumes from a checkpoint when one exists; otherwise starts from a
        random population. A checkpoint is saved after every epoch.
        """
        logger = logging.getLogger("maskrcnn_benchmark.evolution")

        if not self.load_checkpoint():
            self.candidates = gather_candidates(self.random_can(self.population_num))

        while self.epoch < self.max_epochs:
            self.eval_candidates(train_loader, val_loader)
            self.vis_dict = gather_stats(self.vis_dict)

            # Sorting by 1 - score puts the highest score first.
            self.update_top_k(self.candidates, k=self.select_num, key=lambda x: 1 - self.vis_dict[x])
            self.update_top_k(self.candidates, k=50, key=lambda x: 1 - self.vis_dict[x])

            if is_main_process():
                logger.info('Epoch {} : top {} result'.format(self.epoch + 1, len(self.keep_top_k[self.select_num])))
                for i, cand in enumerate(self.keep_top_k[self.select_num]):
                    logger.info(' No.{} {} perf = {}'.format(i + 1, cand, self.vis_dict[cand]))

            mutation = gather_candidates(self.get_mutation(self.select_num, self.mutation_num, self.mutation_prob))
            crossover = gather_candidates(self.get_crossover(self.select_num, self.crossover_num))
            rand = gather_candidates(self.random_can(self.population_num - len(mutation) - len(crossover)))

            self.candidates = mutation + crossover + rand

            self.epoch += 1
            self.save_checkpoint()

    def save_candidates(self, cand, template):
        """Export the ``cand``-th best candidate as a config and a weight file.

        Args:
            cand: 1-based rank into ``self.keep_top_k[self.select_num]``.
            template: path of the supernet YAML config used as a template.

        Writes ``<template name>_cand<cand>.yaml`` and ``init_cand<cand>.pth``
        into ``self.cfg.OUTPUT_DIR``.
        """
        paths = self.keep_top_k[self.select_num][cand - 1]

        with open(template, "r") as f:
            super_cfg = load_cfg(f)

        layer_search = super_cfg.MODEL.BACKBONE.LAYER_SEARCH
        search_spaces = {mix_ops: layer_search[mix_ops] for mix_ops in layer_search}
        search_layers = super_cfg.MODEL.BACKBONE.LAYER_SETUP

        layer_setup = []
        for i, layer in enumerate(search_layers):
            name, setup = get_layer_name(layer, search_spaces)
            if not isinstance(name, list):
                name = [name]
            name = name[paths[i]]

            layer_setup.append("('{}', {})".format(name, str(setup)[1:-1]))
        super_cfg.MODEL.BACKBONE.LAYER_SETUP = layer_setup

        # The previously used ``_to_dict`` is not imported in this module
        # (NameError); convert the config with a local helper instead.
        cand_cfg = self._to_plain_dict(super_cfg)
        del cand_cfg['MODEL']['BACKBONE']['LAYER_SEARCH']
        # Only the file name is rewritten, never the output directory part.
        cfg_name = os.path.basename(template).replace('.yaml', '_cand{}.yaml'.format(cand))
        with open(os.path.join(self.cfg.OUTPUT_DIR, cfg_name), 'w') as f:
            f.write(safe_dump(cand_cfg))

        super_weight = self.supernet_state_dict
        cand_weight = OrderedDict()
        cand_keys = ['layers.{}.ops.{}'.format(i, c) for i, c in enumerate(paths)]

        for key, val in super_weight.items():
            if 'ops' in key:
                # NOTE(review): plain substring matching -- 'layers.1.ops.1'
                # also matches 'layers.1.ops.10...'; confirm no layer has
                # ten or more ops.
                for ck in cand_keys:
                    if ck in key:
                        cand_weight[key.replace(ck, ck.split('.ops.')[0])] = val
            else:
                cand_weight[key] = val

        torch.save({'model': cand_weight}, os.path.join(self.cfg.OUTPUT_DIR, 'init_cand{}.pth'.format(cand)))
|