support finetuning from trained models

Summary: add a flag for supporting finetuning model from the trained weights, and it's very useful when performing across domain reid
This commit is contained in:
liaoxingyu 2020-09-28 17:10:10 +08:00
parent 3024cea3a3
commit 10cbaab155
3 changed files with 9 additions and 3 deletions

View File

@ -29,6 +29,8 @@ from .caviara import CAVIARa
from .viper import VIPeR from .viper import VIPeR
from .lpw import LPW from .lpw import LPW
from .shinpuhkan import Shinpuhkan from .shinpuhkan import Shinpuhkan
from .wildtracker import WildTrackCrop
from .cuhk_sysu import cuhkSYSU
# Vehicle re-id datasets # Vehicle re-id datasets
from .veri import VeRi from .veri import VeRi

View File

@ -44,6 +44,11 @@ def default_argument_parser():
""" """
parser = argparse.ArgumentParser(description="fastreid Training") parser = argparse.ArgumentParser(description="fastreid Training")
parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
parser.add_argument(
"--finetune",
action="store_true",
help="whether to attempt to finetune from the trained model",
)
parser.add_argument( parser.add_argument(
"--resume", "--resume",
action="store_true", action="store_true",
@ -248,9 +253,6 @@ class DefaultTrainer(SimpleTrainer):
# at the next iteration (or iter zero if there's no checkpoint). # at the next iteration (or iter zero if there's no checkpoint).
checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume) checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
# Reinitialize dataloader iter because when we update dataset person identity dict
# to resume training, DataLoader won't update this dictionary when using multiprocess
# because of the function scope.
if resume and self.checkpointer.has_checkpoint(): if resume and self.checkpointer.has_checkpoint():
self.start_iter = checkpoint.get("iteration", -1) + 1 self.start_iter = checkpoint.get("iteration", -1) + 1
# The checkpoint stores the training iteration that just finished, thus we start # The checkpoint stores the training iteration that just finished, thus we start

View File

@ -40,6 +40,8 @@ def main(args):
return res return res
trainer = DefaultTrainer(cfg) trainer = DefaultTrainer(cfg)
if args.finetune: Checkpointer(trainer.model).load(cfg.MODEL.WEIGHTS) # load trained model to funetune
trainer.resume_or_load(resume=args.resume) trainer.resume_or_load(resume=args.resume)
return trainer.train() return trainer.train()