# encoding: utf-8 """ @author: sherlock @contact: sherlockliao01@gmail.com """ import argparse import sys from bisect import bisect_right from torch.backends import cudnn sys.path.append('.') from config import cfg from data import get_data_bunch from engine.trainer import do_train from layers import make_loss from modeling import build_model from utils.logger import Logger from fastai.vision import * def train(cfg, log_path): # prepare dataset data_bunch, test_labels, num_query = get_data_bunch(cfg) # prepare model model = build_model(cfg, data_bunch.c) opt_func = partial(torch.optim.Adam) def warmup_multistep(start: float, end: float, pct: float) -> float: warmup_factor = 1 gamma = cfg.SOLVER.GAMMA milestones = [1.0 * s / cfg.SOLVER.MAX_EPOCHS for s in cfg.SOLVER.STEPS] warmup_iter = 1.0 * cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_EPOCHS if pct < warmup_iter: alpha = pct / warmup_iter warmup_factor = cfg.SOLVER.WARMUP_FACTOR * (1 - alpha) + alpha return start * warmup_factor * gamma ** bisect_right(milestones, pct) lr_sched = Scheduler((cfg.SOLVER.BASE_LR, 0), cfg.SOLVER.MAX_EPOCHS, warmup_multistep) loss_func = make_loss(cfg) do_train( cfg, log_path, model, data_bunch, test_labels, opt_func, lr_sched, loss_func, num_query ) def main(): parser = argparse.ArgumentParser(description="ReID Baseline Training") parser.add_argument( "--config_file", default="", help="path to config file", type=str ) parser.add_argument("opts", help="Modify config options using the command-line", default=None, nargs=argparse.REMAINDER) args = parser.parse_args() num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 if args.config_file != "": cfg.merge_from_file(args.config_file) cfg.merge_from_list(args.opts) cfg.freeze() log_path = Path(os.path.join(cfg.OUTPUT_DIR, 'log.txt')) with log_path.open('w') as f: f.write('{}\n'.format(args)) print(args) if args.config_file != "": print("Loaded configuration file {}".format(args.config_file)) with open(args.config_file, 'r') as cf: config_str = "\n" + cf.read() print(config_str) print("Running with config:\n{}".format(cfg)) with log_path.open('a') as f: f.write('{}\n'.format(cfg)) cudnn.benchmark = True train(cfg, log_path) if __name__ == '__main__': main()