bugfix for `plain_train_net.py` and lr scheduler step ()

pull/504/head
Sherlock Liao 2021-05-11 15:46:17 +08:00 committed by GitHub
parent 46b0681313
commit ff8a958fff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 22 deletions

View File

@ -24,7 +24,8 @@ Support many tasks beyond reid, such image retrieval and face recognition. See [
- Can be used as a library to support [different projects](projects) on top of it. We'll open source more research projects in this way.
- Remove [ignite](https://github.com/pytorch/ignite)(a high-level library) dependency and powered by [PyTorch](https://pytorch.org/).
We write a [chinese blog](https://l1aoxingyu.github.io/blogpages/reid/2020/05/29/fastreid.html) about this toolbox.
We write a [fastreid intro](https://l1aoxingyu.github.io/blogpages/reid/fastreid/2020/05/29/fastreid.html)
and [fastreid v1.0](https://l1aoxingyu.github.io/blogpages/reid/fastreid/2021/04/28/fastreid-v1.html) about this toolbox.
## Changelog

View File

@ -257,7 +257,7 @@ class LRScheduler(HookBase):
def after_epoch(self):
next_iter = self.trainer.iter + 1
next_epoch = self.trainer.epoch + 1
if next_iter > self.trainer.warmup_iters and next_epoch >= self.trainer.delay_epochs:
if next_iter > self.trainer.warmup_iters and next_epoch > self.trainer.delay_epochs:
self._scheduler["lr_sched"].step()

View File

@ -16,6 +16,7 @@ sys.path.append('.')
from fastreid.config import get_cfg
from fastreid.data import build_reid_test_loader, build_reid_train_loader
from fastreid.evaluation.testing import flatten_results_dict
from fastreid.engine import default_argument_parser, default_setup, launch
from fastreid.modeling import build_model
from fastreid.solver import build_lr_scheduler, build_optimizer
@ -33,7 +34,7 @@ logger = logging.getLogger("fastreid")
def get_evaluator(cfg, dataset_name, output_dir=None):
data_loader, num_query = build_reid_test_loader(cfg, dataset_name)
data_loader, num_query = build_reid_test_loader(cfg, dataset_name=dataset_name)
return data_loader, ReidEvaluator(cfg, num_query, output_dir)
@ -49,24 +50,28 @@ def do_test(cfg, model):
)
results[dataset_name] = {}
continue
results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP_ENABLED)
results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP.ENABLED)
results[dataset_name] = results_i
if comm.is_main_process():
assert isinstance(
results, dict
), "Evaluator must return a dict on the main process. Got {} instead.".format(
results
)
print_csv_format(results)
if comm.is_main_process():
assert isinstance(
results, dict
), "Evaluator must return a dict on the main process. Got {} instead.".format(
results
)
logger.info("Evaluation results for {} in csv format:".format(dataset_name))
results_i['dataset'] = dataset_name
print_csv_format(results_i)
if len(results) == 1: results = list(results.values())[0]
if len(results) == 1:
results = list(results.values())[0]
return results
def do_train(cfg, model, resume=False):
data_loader = build_reid_train_loader(cfg)
data_loader_iter = iter(data_loader)
model.train()
optimizer = build_optimizer(cfg, model)
@ -78,7 +83,7 @@ def do_train(cfg, model, resume=False):
model,
cfg.OUTPUT_DIR,
save_to_disk=comm.is_main_process(),
optimizer=optimizer
optimizer=optimizer,
**scheduler
)
@ -93,6 +98,10 @@ def do_train(cfg, model, resume=False):
delay_epochs = cfg.SOLVER.DELAY_EPOCHS
periodic_checkpointer = PeriodicCheckpointer(checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_epoch)
if len(cfg.DATASETS.TESTS) == 1:
metric_name = "metric"
else:
metric_name = cfg.DATASETS.TESTS[0] + "/metric"
writers = (
[
@ -111,7 +120,8 @@ def do_train(cfg, model, resume=False):
with EventStorage(start_iter) as storage:
for epoch in range(start_epoch, max_epoch):
storage.epoch = epoch
for data, _ in zip(data_loader, range(iters_per_epoch)):
for _ in range(iters_per_epoch):
data = next(data_loader_iter)
storage.iter = iteration
loss_dict = model(data)
@ -128,9 +138,9 @@ def do_train(cfg, model, resume=False):
optimizer.step()
storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
if iteration - start_iter > 5 and (
(iteration + 1) % 200 == 0 or iteration == max_iter - 1
):
if iteration - start_iter > 5 and \
((iteration + 1) % 200 == 0 or iteration == max_iter - 1) and \
((iteration + 1) % iters_per_epoch != 0):
for writer in writers:
writer.write()
@ -143,18 +153,22 @@ def do_train(cfg, model, resume=False):
for writer in writers:
writer.write()
if iteration > warmup_iters and (epoch + 1) >= delay_epochs:
if iteration > warmup_iters and (epoch + 1) > delay_epochs:
scheduler["lr_sched"].step()
if (
cfg.TEST.EVAL_PERIOD > 0
and (epoch + 1) % cfg.TEST.EVAL_PERIOD == 0
and epoch != max_iter - 1
and iteration != max_iter - 1
):
do_test(cfg, model)
results = do_test(cfg, model)
# Compared to "train_net.py", the test results are not dumped to EventStorage
else:
results = {}
flatten_results = flatten_results_dict(results)
periodic_checkpointer.step(epoch)
metric_dict = dict(metric=flatten_results[metric_name] if metric_name in flatten_results else -1)
periodic_checkpointer.step(epoch, **metric_dict)
def setup(args):
@ -184,7 +198,9 @@ def main(args):
distributed = comm.get_world_size() > 1
if distributed:
model = DistributedDataParallel(model, delay_allreduce=True)
model = DistributedDataParallel(
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
)
do_train(cfg, model, resume=args.resume)
return do_test(cfg, model)