mirror of https://github.com/YifanXu74/MQ-Det.git
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import time

import torch
import torch.distributed as dist

from maskrcnn_benchmark.utils.comm import get_world_size
from maskrcnn_benchmark.utils.metric_logger import MetricLogger


def reduce_loss_dict(all_loss_dict):
    """
    Reduce the loss dictionaries from all processes so that the process with
    rank 0 has the averaged results. Returns a dict with the same fields as
    the per-task loss dicts, after reduction.
    """
    world_size = get_world_size()
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for loss_dict in all_loss_dict:
            for k in sorted(loss_dict.keys()):
                loss_names.append(k)
                all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        if world_size > 1:
            dist.reduce(all_losses, dst=0)
            if dist.get_rank() == 0:
                # only main process gets accumulated, so only divide by
                # world_size in this case
                all_losses /= world_size

        reduced_losses = {}
        for k, v in zip(loss_names, all_losses):
            # average entries that share a loss name across tasks
            if k not in reduced_losses:
                reduced_losses[k] = v / len(all_loss_dict)
            else:
                reduced_losses[k] += v / len(all_loss_dict)

    return reduced_losses
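# Usage sketch for reduce_loss_dict (hypothetical values, single process): with
# world_size == 1 the distributed reduce is skipped and the result is just the
# per-key average over the task loss dicts, e.g.
#
#   reduce_loss_dict([{"loss_cls": torch.tensor(1.0)},
#                     {"loss_cls": torch.tensor(3.0)}])
#   # -> {"loss_cls": tensor(2.)}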


def do_train(
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
):
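    """
    Multi-task training loop.

    `data_loader` is expected to be a sequence of per-task data loaders that
    are iterated in lockstep; every iteration runs one forward pass per task,
    sums all task losses for a single backward/optimizer step, logs the
    reduced (averaged) losses, and periodically saves checkpoints.
    `arguments["iteration"]` holds the starting iteration and is updated in
    place so it is stored alongside each checkpoint.
    """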
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = min(len(task_loader) for task_loader in data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()

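    # the per-task loaders are zipped and consumed in lockstep, so training
    # runs for max_iter iterations, i.e. the length of the shortest loader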
for iteration, task_loader in enumerate(zip(*data_loader), start_iter):
|
|
data_time = time.time() - end
|
|
iteration = iteration + 1
|
|
arguments["iteration"] = iteration
|
|
|
|
all_task_loss_dict = []
|
|
for task, (images, targets, _) in enumerate(task_loader, 1):
|
|
if all(len(target) < 1 for target in targets):
|
|
logger.warning('Sampled all negative batches, skip')
|
|
continue
|
|
|
|
images = images.to(device)
|
|
targets = [target.to(device) for target in targets]
|
|
|
|
loss_dict = model(images, targets, task)
|
|
all_task_loss_dict.append(loss_dict)
|
|
|
|
losses = sum(loss for loss_dict in all_task_loss_dict for loss in loss_dict.values())
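        # `losses` is the plain (un-reduced) sum of every task's losses on this
        # process and is what backward() is called on; the reduced dict built
        # next is only used for the meters/log output.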
        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(all_task_loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        scheduler.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

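        # ETA: average per-iteration time so far times the iterations remaining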
        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / max_iter
        )
    )
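
# Hypothetical call site for do_train (names such as `cfg`, `build_model`,
# `make_multi_task_data_loaders` and the exact checkpointer constructor are
# illustrative placeholders, not a verified API of this repository):
#
#   model = build_model(cfg).to(cfg.MODEL.DEVICE)
#   data_loader = make_multi_task_data_loaders(cfg)   # one loader per task
#   optimizer = make_optimizer(cfg, model)
#   scheduler = make_lr_scheduler(cfg, optimizer)
#   checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler)
#   arguments = {"iteration": 0}
#   do_train(model, data_loader, optimizer, scheduler, checkpointer,
#            cfg.MODEL.DEVICE, cfg.SOLVER.CHECKPOINT_PERIOD, arguments)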