diff --git a/docs/en/training.md b/docs/en/training.md index 027ab8a8..6e290944 100644 --- a/docs/en/training.md +++ b/docs/en/training.md @@ -29,6 +29,7 @@ CUDA_VISIBLE_DEVICES= python tools/train.py ${CONFIG_FILE} [ARGS] | `--gpu-ids` | int*N | **Deprecated, please use --gpu-id.** A list of GPU ids to use. Only applicable to non-distributed training. | | `--gpu-id` | int | The GPU id to use. Only applicable to non-distributed training. | | `--seed` | int | Random seed. | +| `--diff_seed` | bool | Whether or not set different seeds for different ranks. | | `--deterministic` | bool | Whether to set deterministic options for CUDNN backend. | | `--cfg-options` | str | Override some settings in the used config, the key-value pair in xxx=yyy format will be merged into the config file. If the value to be overwritten is a list, it should be of the form of either key="[a,b]" or key=a,b. The argument also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". Note that the quotation marks are necessary and that no white space is allowed. | | `--launcher` | 'none', 'pytorch', 'slurm', 'mpi' | Options for job launcher. | diff --git a/tools/train.py b/tools/train.py index c2ed4c6f..7def527f 100755 --- a/tools/train.py +++ b/tools/train.py @@ -9,6 +9,7 @@ import warnings import mmcv import torch +import torch.distributed as dist from mmcv import Config, DictAction from mmcv.runner import get_dist_info, init_dist, set_random_seed from mmcv.utils import get_git_hash @@ -52,6 +53,10 @@ def parse_args(): help='id of gpu to use ' '(only applicable to non-distributed training)') parser.add_argument('--seed', type=int, default=None, help='Random seed.') + parser.add_argument( + '--diff_seed', + action='store_true', + help='Whether or not set different seeds for different ranks') parser.add_argument( '--deterministic', action='store_true', @@ -170,6 +175,7 @@ def main(): # set random seeds seed = init_random_seed(args.seed) + seed = seed + dist.get_rank() if args.diff_seed else seed logger.info(f'Set random seed to {seed}, ' f'deterministic: {args.deterministic}') set_random_seed(seed, deterministic=args.deterministic)