mirror of https://github.com/JDAI-CV/fast-reid.git
bugfix for `plain_train_net.py` and lr scheduler step (#484)
parent 46b0681313
commit ff8a958fff
```diff
@@ -24,7 +24,8 @@ Support many tasks beyond reid, such image retrieval and face recognition. See [
 - Can be used as a library to support [different projects](projects) on top of it. We'll open source more research projects in this way.
 - Remove [ignite](https://github.com/pytorch/ignite)(a high-level library) dependency and powered by [PyTorch](https://pytorch.org/).
 
-We write a [chinese blog](https://l1aoxingyu.github.io/blogpages/reid/2020/05/29/fastreid.html) about this toolbox.
+We write a [fastreid intro](https://l1aoxingyu.github.io/blogpages/reid/fastreid/2020/05/29/fastreid.html)
+and [fastreid v1.0](https://l1aoxingyu.github.io/blogpages/reid/fastreid/2021/04/28/fastreid-v1.html) about this toolbox.
 
 ## Changelog
 
```
```diff
@@ -257,7 +257,7 @@ class LRScheduler(HookBase):
     def after_epoch(self):
         next_iter = self.trainer.iter + 1
         next_epoch = self.trainer.epoch + 1
-        if next_iter > self.trainer.warmup_iters and next_epoch >= self.trainer.delay_epochs:
+        if next_iter > self.trainer.warmup_iters and next_epoch > self.trainer.delay_epochs:
             self._scheduler["lr_sched"].step()
```
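For reference, a minimal, self-contained sketch of the per-epoch stepping pattern this hook implements, with the trainer and fastreid's warmup/scheduler wiring simplified away; the `warmup_iters`, `delay_epochs`, and `iters_per_epoch` values below are made up:

```python
# Standalone sketch of delayed per-epoch LR stepping (not fastreid's Trainer).
import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=90)

warmup_iters, delay_epochs, iters_per_epoch = 100, 30, 50  # made-up values
iteration = 0
for epoch in range(120):
    for _ in range(iters_per_epoch):
        iteration += 1
        # ... forward / backward / optimizer.step() would go here ...
    # Step the main scheduler only once warmup is over and strictly more than
    # `delay_epochs` epochs have finished ('>' rather than '>=').
    if iteration > warmup_iters and (epoch + 1) > delay_epochs:
        lr_sched.step()
```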
```diff
@@ -16,6 +16,7 @@ sys.path.append('.')
 
 from fastreid.config import get_cfg
 from fastreid.data import build_reid_test_loader, build_reid_train_loader
+from fastreid.evaluation.testing import flatten_results_dict
 from fastreid.engine import default_argument_parser, default_setup, launch
 from fastreid.modeling import build_model
 from fastreid.solver import build_lr_scheduler, build_optimizer
```
```diff
@@ -33,7 +34,7 @@ logger = logging.getLogger("fastreid")
 
 
 def get_evaluator(cfg, dataset_name, output_dir=None):
-    data_loader, num_query = build_reid_test_loader(cfg, dataset_name)
+    data_loader, num_query = build_reid_test_loader(cfg, dataset_name=dataset_name)
     return data_loader, ReidEvaluator(cfg, num_query, output_dir)
```
```diff
@@ -49,24 +50,28 @@ def do_test(cfg, model):
             )
             results[dataset_name] = {}
             continue
-        results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP_ENABLED)
+        results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP.ENABLED)
         results[dataset_name] = results_i
 
-    if comm.is_main_process():
-        assert isinstance(
-            results, dict
-        ), "Evaluator must return a dict on the main process. Got {} instead.".format(
-            results
-        )
-        print_csv_format(results)
+        if comm.is_main_process():
+            assert isinstance(
+                results, dict
+            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
+                results
+            )
+            logger.info("Evaluation results for {} in csv format:".format(dataset_name))
+            results_i['dataset'] = dataset_name
+            print_csv_format(results_i)
 
-    if len(results) == 1: results = list(results.values())[0]
+    if len(results) == 1:
+        results = list(results.values())[0]
 
     return results
 
 
 def do_train(cfg, model, resume=False):
     data_loader = build_reid_train_loader(cfg)
+    data_loader_iter = iter(data_loader)
 
     model.train()
     optimizer = build_optimizer(cfg, model)
```
```diff
@@ -78,7 +83,7 @@ def do_train(cfg, model, resume=False):
         model,
         cfg.OUTPUT_DIR,
         save_to_disk=comm.is_main_process(),
-        optimizer=optimizer
+        optimizer=optimizer,
+        **scheduler
     )
```
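Spreading `**scheduler` into the checkpointer means the scheduler objects (per the `"lr_sched"` accesses elsewhere in this commit, `scheduler` is a dict of scheduler objects) are saved and restored together with the model and optimizer. A toy illustration of that extra-checkpointables pattern, assuming a fvcore-style `Checkpointer` where every additional keyword argument exposes `state_dict()`; this is a sketch, not fastreid's implementation:

```python
# Toy checkpointer: every extra keyword argument is checkpointed via state_dict().
import torch


class TinyCheckpointer:
    def __init__(self, model, save_dir, **checkpointables):
        self.model = model
        self.save_dir = save_dir
        self.checkpointables = checkpointables  # e.g. optimizer=..., lr_sched=...

    def save(self, name):
        data = {"model": self.model.state_dict()}
        for key, obj in self.checkpointables.items():
            data[key] = obj.state_dict()  # scheduler/optimizer state rides along
        torch.save(data, f"{self.save_dir}/{name}.pth")

# Usage mirroring the call above:
#   TinyCheckpointer(model, cfg.OUTPUT_DIR, optimizer=optimizer, **scheduler)
```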
```diff
@@ -93,6 +98,10 @@ def do_train(cfg, model, resume=False):
     delay_epochs = cfg.SOLVER.DELAY_EPOCHS
 
     periodic_checkpointer = PeriodicCheckpointer(checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_epoch)
+    if len(cfg.DATASETS.TESTS) == 1:
+        metric_name = "metric"
+    else:
+        metric_name = cfg.DATASETS.TESTS[0] + "/metric"
 
     writers = (
         [
```
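The chosen key has to line up with what `flatten_results_dict` produces from the `do_test` output: with several test sets the results stay nested as `{dataset: {name: value}}`, while with a single test set `do_test` already unwraps the outer dict (see the `do_test` hunk above). A small sketch of the expected key shapes, assuming `flatten_results_dict` joins nested keys with `/` and that the evaluator reports a `"metric"` entry; the dataset names and numbers below are illustrative:

```python
# Sketch of the flattened key shapes the metric_name logic relies on.
def flatten(d, prefix=""):
    out = {}
    for k, v in d.items():
        key = f"{prefix}/{k}" if prefix else k
        out.update(flatten(v, key) if isinstance(v, dict) else {key: v})
    return out

multi = {"Market1501": {"metric": 90.5}, "DukeMTMC": {"metric": 81.3}}  # >1 test set
single = {"metric": 90.5}  # a single test set is already unwrapped by do_test()

print(flatten(multi))   # {'Market1501/metric': 90.5, 'DukeMTMC/metric': 81.3}
print(flatten(single))  # {'metric': 90.5}
```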
```diff
@@ -111,7 +120,8 @@ def do_train(cfg, model, resume=False):
     with EventStorage(start_iter) as storage:
         for epoch in range(start_epoch, max_epoch):
             storage.epoch = epoch
-            for data, _ in zip(data_loader, range(iters_per_epoch)):
+            for _ in range(iters_per_epoch):
+                data = next(data_loader_iter)
                 storage.iter = iteration
 
                 loss_dict = model(data)
```
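A toy version of the single-iterator pattern used here: one iterator is created from the training loader (treated as an endless stream of batches) and shared across epochs, rather than building a fresh iterator every epoch the way `zip(data_loader, range(iters_per_epoch))` does. The stand-in loader below replaces fastreid's real train loader:

```python
import itertools

# Stand-in for the training loader: an endless stream of batches.
data_loader = itertools.cycle([{"images": i} for i in range(4)])
data_loader_iter = iter(data_loader)

start_epoch, max_epoch, iters_per_epoch = 0, 2, 3  # made-up values
for epoch in range(start_epoch, max_epoch):
    for _ in range(iters_per_epoch):
        data = next(data_loader_iter)  # the same iterator survives epoch boundaries
        print(epoch, data)
```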
```diff
@@ -128,9 +138,9 @@ def do_train(cfg, model, resume=False):
                 optimizer.step()
                 storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
 
-                if iteration - start_iter > 5 and (
-                    (iteration + 1) % 200 == 0 or iteration == max_iter - 1
-                ):
+                if iteration - start_iter > 5 and \
+                        ((iteration + 1) % 200 == 0 or iteration == max_iter - 1) and \
+                        ((iteration + 1) % iters_per_epoch != 0):
                     for writer in writers:
                         writer.write()
 
```
```diff
@@ -143,18 +153,22 @@ def do_train(cfg, model, resume=False):
             for writer in writers:
                 writer.write()
 
-            if iteration > warmup_iters and (epoch + 1) >= delay_epochs:
+            if iteration > warmup_iters and (epoch + 1) > delay_epochs:
                 scheduler["lr_sched"].step()
 
             if (
                 cfg.TEST.EVAL_PERIOD > 0
                 and (epoch + 1) % cfg.TEST.EVAL_PERIOD == 0
-                and epoch != max_iter - 1
+                and iteration != max_iter - 1
             ):
-                do_test(cfg, model)
+                results = do_test(cfg, model)
                 # Compared to "train_net.py", the test results are not dumped to EventStorage
+            else:
+                results = {}
+            flatten_results = flatten_results_dict(results)
 
-            periodic_checkpointer.step(epoch)
+            metric_dict = dict(metric=flatten_results[metric_name] if metric_name in flatten_results else -1)
+            periodic_checkpointer.step(epoch, **metric_dict)
 
 
 def setup(args):
```
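The `-1` fallback covers epochs where no evaluation ran (`results = {}`), and the resulting `metric` value is forwarded to `periodic_checkpointer.step` as a keyword argument. A toy, deliberately-not-fastreid version of a periodic checkpointer that consumes such a metric to keep a separate best snapshot; fastreid's actual `PeriodicCheckpointer` may use the value differently:

```python
# Toy metric-aware periodic checkpointer (illustration only, not fastreid's class).
class ToyPeriodicCheckpointer:
    def __init__(self, checkpointer, period, max_epoch):
        self.checkpointer = checkpointer
        self.period = period
        self.max_epoch = max_epoch
        self.best_metric = float("-inf")

    def step(self, epoch, **kwargs):
        metric = kwargs.get("metric", -1)
        if (epoch + 1) % self.period == 0 or epoch == self.max_epoch - 1:
            self.checkpointer.save(f"model_{epoch:04d}")
        if metric > self.best_metric:  # keep a separate "best so far" snapshot
            self.best_metric = metric
            self.checkpointer.save("model_best")
```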
```diff
@@ -184,7 +198,9 @@ def main(args):
 
     distributed = comm.get_world_size() > 1
     if distributed:
-        model = DistributedDataParallel(model, delay_allreduce=True)
+        model = DistributedDataParallel(
+            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
+        )
 
     do_train(cfg, model, resume=args.resume)
     return do_test(cfg, model)
```
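The replacement call matches the signature of PyTorch's native `torch.nn.parallel.DistributedDataParallel` (the removed one-liner, with its `delay_allreduce` keyword, looks like apex's DDP interface instead). A minimal sketch of the torch-native wrap, assuming one GPU per process and a process group that has already been initialized, as `launch` would normally arrange:

```python
# Minimal torch-native DDP wrap; requires torch.distributed to be initialized
# (e.g. by fastreid's `launch`) before this point.
import torch
from torch.nn.parallel import DistributedDataParallel

local_rank = 0  # stands in for comm.get_local_rank()
model = torch.nn.Linear(8, 2).cuda(local_rank)
model = DistributedDataParallel(
    model,
    device_ids=[local_rank],   # one device per process
    broadcast_buffers=False,   # skip broadcasting buffers on every forward pass
)
```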