From 10cbaab155b065d0c823f36c21e7c78f8d2bcf47 Mon Sep 17 00:00:00 2001 From: liaoxingyu Date: Mon, 28 Sep 2020 17:10:10 +0800 Subject: [PATCH] support finetuning from trained models Summary: add a flag to support finetuning a model from the trained weights; it's very useful when performing cross-domain re-id --- fastreid/data/datasets/__init__.py | 2 ++ fastreid/engine/defaults.py | 8 +++++--- tools/train_net.py | 2 ++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/fastreid/data/datasets/__init__.py b/fastreid/data/datasets/__init__.py index f7b8ccd..2977c51 100644 --- a/fastreid/data/datasets/__init__.py +++ b/fastreid/data/datasets/__init__.py @@ -29,6 +29,8 @@ from .caviara import CAVIARa from .viper import VIPeR from .lpw import LPW from .shinpuhkan import Shinpuhkan +from .wildtracker import WildTrackCrop +from .cuhk_sysu import cuhkSYSU # Vehicle re-id datasets from .veri import VeRi diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py index 3d53c83..775b2d9 100644 --- a/fastreid/engine/defaults.py +++ b/fastreid/engine/defaults.py @@ -44,6 +44,11 @@ def default_argument_parser(): """ parser = argparse.ArgumentParser(description="fastreid Training") parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") + parser.add_argument( + "--finetune", + action="store_true", + help="whether to attempt to finetune from the trained model", + ) parser.add_argument( "--resume", action="store_true", @@ -248,9 +253,6 @@ class DefaultTrainer(SimpleTrainer): # at the next iteration (or iter zero if there's no checkpoint). checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume) - # Reinitialize dataloader iter because when we update dataset person identity dict - # to resume training, DataLoader won't update this dictionary when using multiprocess - # because of the function scope. 
if resume and self.checkpointer.has_checkpoint(): self.start_iter = checkpoint.get("iteration", -1) + 1 # The checkpoint stores the training iteration that just finished, thus we start diff --git a/tools/train_net.py b/tools/train_net.py index 4f80116..bbf9dd2 100644 --- a/tools/train_net.py +++ b/tools/train_net.py @@ -40,6 +40,8 @@ def main(args): return res trainer = DefaultTrainer(cfg) + if args.finetune: Checkpointer(trainer.model).load(cfg.MODEL.WEIGHTS) # load trained model to finetune + trainer.resume_or_load(resume=args.resume) return trainer.train()