Merge pull request #425 from JDAI-CV/multi-node

Summary: Add a getting-started guide for training on multiple machines.
Change the multi-dataset evaluation logging mode so that the test results for each dataset are printed immediately after that dataset is evaluated.

Reviewed by: l1aoxingyu
commit 0cc9fb95a6
Authored by Xingyu Liao on 2021-03-09 20:13:29 +08:00, committed by GitHub.
6 changed files with 88 additions and 57 deletions

View File

@@ -32,6 +32,26 @@ If you want to train a model with 4 GPUs, you can run:
 python3 tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --num-gpus 4
 ```
 
+If you want to train a model with multiple machines, you can run:
+
+```
+# machine 1
+export GLOO_SOCKET_IFNAME=eth0
+export NCCL_SOCKET_IFNAME=eth0
+
+python3 tools/train_net.py --config-file configs/Market1501/bagtricks_R50.yml \
+--num-gpus 4 --num-machines 2 --machine-rank 0 --dist-url tcp://ip:port
+
+# machine 2
+export GLOO_SOCKET_IFNAME=eth0
+export NCCL_SOCKET_IFNAME=eth0
+
+python3 tools/train_net.py --config-file configs/Market1501/bagtricks_R50.yml \
+--num-gpus 4 --num-machines 2 --machine-rank 1 --dist-url tcp://ip:port
+```
+
+Make sure the dataset path and the code are identical on all machines, and that the machines can communicate with each other.
+
 To evaluate a model's performance, use
 ```bash
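For context, the sketch below shows roughly how a detectron2-style launcher turns `--num-gpus`, `--num-machines`, `--machine-rank` and `--dist-url` into a `torch.distributed` process group. It is only an illustration with made-up values (the `tcp://ip:port` placeholder must be replaced by an address of machine 0 that every machine can reach); it is not code from this PR.

```python
import torch

# Illustrative mapping of the launcher flags onto torch.distributed;
# all values below are hypothetical.
num_gpus_per_machine = 4
num_machines = 2
machine_rank = 0                       # 0 on machine 1, 1 on machine 2
dist_url = "tcp://192.168.1.10:29500"  # hypothetical address of machine 0

world_size = num_machines * num_gpus_per_machine


def init_worker(local_rank: int):
    # Each GPU process gets a unique global rank derived from its machine rank.
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    torch.distributed.init_process_group(
        backend="nccl",        # NCCL over the interface named in NCCL_SOCKET_IFNAME
        init_method=dist_url,
        world_size=world_size,
        rank=global_rank,
    )
    torch.cuda.set_device(local_rank)
```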

View File

@@ -467,15 +467,18 @@ class DefaultTrainer(TrainerBase):
             results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP_ENABLED)
             results[dataset_name] = results_i
 
             if comm.is_main_process():
                 assert isinstance(
                     results, dict
                 ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                     results
                 )
-                print_csv_format(results)
+                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
+                results_i['dataset'] = dataset_name
+                print_csv_format(results_i)
 
-        if len(results) == 1: results = list(results.values())[0]
+        if len(results) == 1:
+            results = list(results.values())[0]
         return results
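As a plain-Python illustration of the new control flow in `DefaultTrainer.test` (hypothetical metric values, not output from this PR): each dataset's result dict is tagged with its name and printed as soon as it is evaluated, and the outer dict is unwrapped only when a single test set is configured.

```python
from collections import OrderedDict


def fake_test(dataset_names):
    # Stand-in for DefaultTrainer.test with fabricated numbers.
    results = OrderedDict()
    for dataset_name in dataset_names:
        results_i = OrderedDict([("Rank-1", 94.5), ("mAP", 85.1)])  # hypothetical metrics
        results[dataset_name] = results_i
        results_i["dataset"] = dataset_name  # tag consumed by print_csv_format
        print("Evaluation results for {}: {}".format(dataset_name, dict(results_i)))
    if len(results) == 1:
        results = list(results.values())[0]  # unwrap when only one test set is used
    return results


fake_test(["Market1501"])
```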

View File

@@ -360,19 +360,20 @@ class EvalHook(HookBase):
                 )
             self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
 
-    def after_epoch(self):
-        next_epoch = self.trainer.epoch + 1
-        is_final = next_epoch == self.trainer.max_epoch
-        if is_final or (self._period > 0 and next_epoch % self._period == 0):
-            self._do_eval()
+        # Remove extra memory cache of main process due to evaluation
+        torch.cuda.empty_cache()
+
         # Evaluation may take different time among workers.
         # A barrier make them start the next iteration together.
         comm.synchronize()
 
+    def after_epoch(self):
+        next_epoch = self.trainer.epoch + 1
+        if self._period > 0 and next_epoch % self._period == 0:
+            self._do_eval()
+
     def after_train(self):
+        next_epoch = self.trainer.epoch + 1
+        # This condition is to prevent the eval from running after a failed training
+        if next_epoch % self._period != 0 and next_epoch >= self.trainer.max_epoch:
+            self._do_eval()
         # func is likely a closure that holds reference to the trainer
         # therefore we clean it to avoid circular reference in the end
         del self._func
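A minimal sketch (not part of the diff) of which epochs now trigger evaluation: `after_epoch` handles the periodic case, and `after_train` catches the final epoch when it is not a multiple of the period.

```python
def eval_epochs(max_epoch: int, period: int):
    """Epochs after which the new EvalHook logic would run evaluation."""
    triggered = []
    for epoch in range(max_epoch):
        next_epoch = epoch + 1
        # after_epoch: periodic evaluation only
        if period > 0 and next_epoch % period == 0:
            triggered.append(next_epoch)
    # after_train: evaluate at the end if the periodic check missed the last epoch
    next_epoch = max_epoch
    if next_epoch % period != 0 and next_epoch >= max_epoch:
        triggered.append(next_epoch)
    return triggered


print(eval_epochs(max_epoch=120, period=30))  # [30, 60, 90, 120]
print(eval_epochs(max_epoch=100, period=30))  # [30, 60, 90, 100] -- final eval added by after_train
```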

View File

@@ -6,6 +6,7 @@ from contextlib import contextmanager
 import torch
 
+from fastreid.utils import comm
 from fastreid.utils.logger import log_every_n_seconds
@@ -96,6 +97,7 @@ def inference_on_dataset(model, data_loader, evaluator, flip_test=False):
     Returns:
         The return value of `evaluator.evaluate()`
     """
+    num_devices = comm.get_world_size()
     logger = logging.getLogger(__name__)
     logger.info("Start inference on {} images".format(len(data_loader.dataset)))
@@ -118,10 +120,11 @@ def inference_on_dataset(model, data_loader, evaluator, flip_test=False):
                 inputs["images"] = inputs["images"].flip(dims=[3])
                 flip_outputs = model(inputs)
                 outputs = (outputs + flip_outputs) / 2
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
             total_compute_time += time.perf_counter() - start_compute_time
             evaluator.process(inputs, outputs)
 
-            idx += 1
             iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
             seconds_per_batch = total_compute_time / iters_after_start
             if idx >= num_warmup * 2 or seconds_per_batch > 30:
@@ -140,17 +143,18 @@ def inference_on_dataset(model, data_loader, evaluator, flip_test=False):
     total_time_str = str(datetime.timedelta(seconds=total_time))
     # NOTE this format is parsed by grep
     logger.info(
-        "Total inference time: {} ({:.6f} s / batch per device)".format(
-            total_time_str, total_time / (total - num_warmup)
+        "Total inference time: {} ({:.6f} s / batch per device, on {} devices)".format(
+            total_time_str, total_time / (total - num_warmup), num_devices
         )
     )
     total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
     logger.info(
-        "Total inference pure compute time: {} ({:.6f} s / batch per device)".format(
-            total_compute_time_str, total_compute_time / (total - num_warmup)
+        "Total inference pure compute time: {} ({:.6f} s / batch per device, on {} devices)".format(
+            total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
        )
     )
 
     results = evaluator.evaluate()
     # An evaluator may return None when not in main process.
     # Replace it by an empty dict instead to make it easier for downstream code to handle
     if results is None:
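To make the new log line concrete, here is a standalone rendering of the timing summary with hypothetical numbers (in the real function the values come from `time.perf_counter()` and `comm.get_world_size()`):

```python
import datetime

# Hypothetical measurements, for illustration only.
num_devices = 4      # comm.get_world_size()
total = 1000         # number of batches in the data loader
num_warmup = 5
total_time = 84.2    # seconds of wall-clock time after warmup

total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(
    "Total inference time: {} ({:.6f} s / batch per device, on {} devices)".format(
        total_time_str, total_time / (total - num_warmup), num_devices
    )
)
# Total inference time: 0:01:24 (0.084623 s / batch per device, on 4 devices)
```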

View File

@@ -5,6 +5,7 @@
 """
 import copy
 import logging
+import itertools
 from collections import OrderedDict
 
 import numpy as np
@@ -28,50 +29,54 @@ class ReidEvaluator(DatasetEvaluator):
         self._num_query = num_query
         self._output_dir = output_dir
 
-        self.features = []
-        self.pids = []
-        self.camids = []
+        self._cpu_device = torch.device('cpu')
+
+        self._predictions = []
 
     def reset(self):
-        self.features = []
-        self.pids = []
-        self.camids = []
+        self._predictions = []
 
     def process(self, inputs, outputs):
-        self.pids.extend(inputs["targets"])
-        self.camids.extend(inputs["camids"])
-        self.features.append(outputs.cpu())
+        prediction = {
+            'feats': outputs.to(self._cpu_device, torch.float32),
+            'pids': inputs['targets'].to(self._cpu_device),
+            'camids': inputs['camids'].to(self._cpu_device)
+        }
+        self._predictions.append(prediction)
 
     def evaluate(self):
         if comm.get_world_size() > 1:
             comm.synchronize()
-            features = comm.gather(self.features)
-            features = sum(features, [])
-
-            pids = comm.gather(self.pids)
-            pids = sum(pids, [])
-
-            camids = comm.gather(self.camids)
-            camids = sum(camids, [])
-
-            # fmt: off
-            if not comm.is_main_process(): return {}
-            # fmt: on
+            predictions = comm.gather(self._predictions, dst=0)
+            predictions = list(itertools.chain(*predictions))
+
+            if not comm.is_main_process():
+                return {}
         else:
-            features = self.features
-            pids = self.pids
-            camids = self.camids
+            predictions = self._predictions
+
+        features = []
+        pids = []
+        camids = []
+        for prediction in predictions:
+            features.append(prediction['feats'])
+            pids.append(prediction['pids'])
+            camids.append(prediction['camids'])
 
         features = torch.cat(features, dim=0)
+        pids = torch.cat(pids, dim=0).numpy()
+        camids = torch.cat(camids, dim=0).numpy()
         # query feature, person ids and camera ids
         query_features = features[:self._num_query]
-        query_pids = np.asarray(pids[:self._num_query])
-        query_camids = np.asarray(camids[:self._num_query])
+        query_pids = pids[:self._num_query]
+        query_camids = camids[:self._num_query]
         # gallery features, person ids and camera ids
         gallery_features = features[self._num_query:]
-        gallery_pids = np.asarray(pids[self._num_query:])
-        gallery_camids = np.asarray(camids[self._num_query:])
+        gallery_pids = pids[self._num_query:]
+        gallery_camids = camids[self._num_query:]
 
         self._results = OrderedDict()
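The gathering pattern can be illustrated without a distributed run: each worker accumulates a list of prediction dicts, `comm.gather(..., dst=0)` returns a list of those per-rank lists on the main process, and `itertools.chain` flattens it before the tensors are concatenated. A toy sketch with fabricated tensors (not fast-reid code):

```python
import itertools

import torch

# Pretend these per-rank lists are what comm.gather(self._predictions, dst=0)
# would return on the main process of a 2-GPU job.
rank0_preds = [{'feats': torch.randn(2, 4), 'pids': torch.tensor([1, 2]), 'camids': torch.tensor([0, 0])}]
rank1_preds = [{'feats': torch.randn(2, 4), 'pids': torch.tensor([3, 4]), 'camids': torch.tensor([1, 1])}]
gathered = [rank0_preds, rank1_preds]

predictions = list(itertools.chain(*gathered))  # flatten the per-rank lists

features = torch.cat([p['feats'] for p in predictions], dim=0)
pids = torch.cat([p['pids'] for p in predictions], dim=0).numpy()
camids = torch.cat([p['camids'] for p in predictions], dim=0).numpy()

print(features.shape, pids, camids)  # torch.Size([4, 4]) [1 2 3 4] [0 0 1 1]
```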

View File

@@ -8,23 +8,21 @@ import numpy as np
 from tabulate import tabulate
 from termcolor import colored
 
-logger = logging.getLogger(__name__)
-
 
 def print_csv_format(results):
     """
-    Print main metrics in a format similar to Detectron,
+    Print main metrics in a format similar to Detectron2,
     so that they are easy to copypaste into a spreadsheet.
     Args:
-        results (OrderedDict[dict]): task_name -> {metric -> score}
+        results (OrderedDict): {metric -> score}
     """
-    assert isinstance(results, OrderedDict), results  # unordered results cannot be properly printed
-    task = list(results.keys())[0]
-    metrics = ["Datasets"] + [k for k in results[task]]
-    csv_results = []
-    for task, res in results.items():
-        csv_results.append((task, *list(res.values())))
+    # unordered results cannot be properly printed
+    assert isinstance(results, OrderedDict) or not len(results), results
+    logger = logging.getLogger(__name__)
+
+    dataset_name = results.pop('dataset')
+    metrics = ["Dataset"] + [k for k in results]
+    csv_results = [(dataset_name, *list(results.values()))]
 
     # tabulate it
     table = tabulate(
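For reference, a hypothetical rendering of the per-dataset table that the rewritten `print_csv_format` produces; the `tabulate` arguments are assumed here since the hunk above is cut off before them, and the metric values are fabricated.

```python
from collections import OrderedDict

from tabulate import tabulate

# Fabricated per-dataset metrics, shaped like the dict DefaultTrainer now passes in.
results = OrderedDict([("Rank-1", 94.5), ("Rank-5", 98.2), ("mAP", 85.1), ("dataset", "Market1501")])

dataset_name = results.pop("dataset")
metrics = ["Dataset"] + [k for k in results]
csv_results = [(dataset_name, *list(results.values()))]

# Assumed tabulate call: pipe (Markdown-style) table with one row per dataset.
table = tabulate(csv_results, tablefmt="pipe", floatfmt=".2f", headers=metrics)
print(table)
```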