support finetuning from trained models

Summary: add a flag to support finetuning a model from trained weights; this is especially useful for cross-domain ReID
liaoxingyu 2020-09-28 17:10:10 +08:00
parent 3024cea3a3
commit 10cbaab155
3 changed files with 9 additions and 3 deletions
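
The flag distinguishes two ways of starting from existing weights: resuming (restore model, optimizer, scheduler, and the saved iteration) versus finetuning (load model weights only and start a fresh schedule, typically on a new target-domain dataset). A minimal sketch of that distinction, assuming fast-reid's fvcore-style Checkpointer at fastreid.utils.checkpoint (illustrative, not the exact trainer code):

# Minimal sketch, not fast-reid's exact code.
from fastreid.utils.checkpoint import Checkpointer  # assumed import path

def load_weights(trainer, cfg, args):
    if args.finetune:
        # Finetune: copy only the trained model weights; optimizer/scheduler
        # state and the saved iteration are ignored, so training restarts
        # from iteration 0 on the new dataset.
        Checkpointer(trainer.model).load(cfg.MODEL.WEIGHTS)
    # Resume: with resume=True and an existing checkpoint, resume_or_load
    # also restores optimizer, scheduler, and the start iteration.
    trainer.resume_or_load(resume=args.resume)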


@@ -29,6 +29,8 @@ from .caviara import CAVIARa
 from .viper import VIPeR
 from .lpw import LPW
 from .shinpuhkan import Shinpuhkan
+from .wildtracker import WildTrackCrop
+from .cuhk_sysu import cuhkSYSU
 # Vehicle re-id datasets
 from .veri import VeRi

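For context, these bare imports are load-bearing: in fast-reid a dataset module registers its class with the dataset registry when it is imported, which is why the new datasets must appear in this package __init__. A hypothetical sketch of that pattern (registry and base-class names follow fast-reid's convention; the real cuhk_sysu module will differ in detail):

# Hypothetical sketch of the import-time registration behind the lines above.
from fastreid.data.datasets import DATASET_REGISTRY        # assumed location
from fastreid.data.datasets.bases import ImageDataset      # assumed base class

@DATASET_REGISTRY.register()   # side effect on import: maps name -> class
class cuhkSYSU(ImageDataset):
    """Once this module is imported, 'cuhkSYSU' resolves from config strings."""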

@@ -44,6 +44,11 @@ def default_argument_parser():
     """
     parser = argparse.ArgumentParser(description="fastreid Training")
     parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
+    parser.add_argument(
+        "--finetune",
+        action="store_true",
+        help="whether to attempt to finetune from the trained model",
+    )
     parser.add_argument(
         "--resume",
         action="store_true",
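
A quick sanity check of the new flag (illustrative: assumes default_argument_parser is re-exported from fastreid.engine, and the config path is a placeholder):

# Illustrative check of the store_true flag added above.
from fastreid.engine import default_argument_parser  # assumed re-export

args = default_argument_parser().parse_args(
    ["--config-file", "configs/Market1501/bagtricks_R50.yml", "--finetune"]
)
assert args.finetune is True   # flag present -> True
assert args.resume is False    # store_true default is False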
@@ -248,9 +253,6 @@ class DefaultTrainer(SimpleTrainer):
         # at the next iteration (or iter zero if there's no checkpoint).
         checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
-        # Reinitialize dataloader iter because when we update dataset person identity dict
-        # to resume training, DataLoader won't update this dictionary when using multiprocess
-        # because of the function scope.
         if resume and self.checkpointer.has_checkpoint():
             self.start_iter = checkpoint.get("iteration", -1) + 1
             # The checkpoint stores the training iteration that just finished, thus we start

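The deleted lines were an explanatory comment about dataloader reinitialization; what remains is the start-iteration rule: a checkpoint stores the iteration that just finished, so a resumed run continues at the next one, while a finetune run (no --resume) starts from scratch. A self-contained sketch of that rule:

# Self-contained sketch of the start-iteration logic kept above.
def compute_start_iter(checkpoint, resume, has_checkpoint):
    if resume and has_checkpoint:
        # checkpoint holds the last finished iteration; continue at the next
        return checkpoint.get("iteration", -1) + 1
    return 0  # fresh run (including --finetune): start from iteration 0

print(compute_start_iter({"iteration": 119999}, True, True))   # -> 120000
print(compute_start_iter({}, False, False))                    # -> 0
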

@@ -40,6 +40,8 @@ def main(args):
         return res

     trainer = DefaultTrainer(cfg)
+    if args.finetune:
+        Checkpointer(trainer.model).load(cfg.MODEL.WEIGHTS)  # load trained model to finetune
     trainer.resume_or_load(resume=args.resume)
     return trainer.train()
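
With all three pieces in place, a cross-domain finetune run would look roughly like: python tools/train_net.py --config-file <target-domain-config> --finetune MODEL.WEIGHTS /path/to/source_model.pth (assuming, as in detectron2-style parsers, that trailing key/value opts override config entries, and that tools/train_net.py is the script patched above); omitting --finetune and passing --resume preserves the old resume behavior.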