mirror of https://github.com/JDAI-CV/fast-reid.git
fix bug in evaluator (#284)
Summary: Change `Trainer` to `DefaultTrainer` in `tools/train.py` for adapting to new version of `build_evaluator`.pull/299/head
parent
a25d8a6bc1
commit
fae128a4db
|
@ -5,8 +5,6 @@
|
||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append('.')
|
sys.path.append('.')
|
||||||
|
@ -14,15 +12,6 @@ sys.path.append('.')
|
||||||
from fastreid.config import get_cfg
|
from fastreid.config import get_cfg
|
||||||
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup, launch
|
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup, launch
|
||||||
from fastreid.utils.checkpoint import Checkpointer
|
from fastreid.utils.checkpoint import Checkpointer
|
||||||
from fastreid.evaluation import ReidEvaluator
|
|
||||||
|
|
||||||
|
|
||||||
class Trainer(DefaultTrainer):
|
|
||||||
@classmethod
|
|
||||||
def build_evaluator(cls, cfg, num_query, output_folder=None):
|
|
||||||
if output_folder is None:
|
|
||||||
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
|
|
||||||
return ReidEvaluator(cfg, num_query)
|
|
||||||
|
|
||||||
|
|
||||||
def setup(args):
|
def setup(args):
|
||||||
|
@ -43,14 +32,14 @@ def main(args):
|
||||||
if args.eval_only:
|
if args.eval_only:
|
||||||
cfg.defrost()
|
cfg.defrost()
|
||||||
cfg.MODEL.BACKBONE.PRETRAIN = False
|
cfg.MODEL.BACKBONE.PRETRAIN = False
|
||||||
model = Trainer.build_model(cfg)
|
model = DefaultTrainer.build_model(cfg)
|
||||||
|
|
||||||
Checkpointer(model).load(cfg.MODEL.WEIGHTS) # load trained model
|
Checkpointer(model).load(cfg.MODEL.WEIGHTS) # load trained model
|
||||||
|
|
||||||
res = Trainer.test(cfg, model)
|
res = DefaultTrainer.test(cfg, model)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
trainer = Trainer(cfg)
|
trainer = DefaultTrainer(cfg)
|
||||||
trainer.resume_or_load(resume=args.resume)
|
trainer.resume_or_load(resume=args.resume)
|
||||||
return trainer.train()
|
return trainer.train()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue