mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Enhancement] Option for different seeds on different ranks (#820)
This commit is contained in:
parent
f1609b50e9
commit
c25404b358
@ -29,6 +29,7 @@ CUDA_VISIBLE_DEVICES= python tools/train.py ${CONFIG_FILE} [ARGS]
|
||||
| `--gpu-ids` | int*N | **Deprecated, please use --gpu-id.** A list of GPU ids to use. Only applicable to non-distributed training. |
|
||||
| `--gpu-id` | int | The GPU id to use. Only applicable to non-distributed training. |
|
||||
| `--seed` | int | Random seed. |
|
||||
| `--diff_seed` | bool | Whether or not set different seeds for different ranks. |
|
||||
| `--deterministic` | bool | Whether to set deterministic options for CUDNN backend. |
|
||||
| `--cfg-options` | str | Override some settings in the used config, the key-value pair in xxx=yyy format will be merged into the config file. If the value to be overwritten is a list, it should be of the form of either key="[a,b]" or key=a,b. The argument also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". Note that the quotation marks are necessary and that no white space is allowed. |
|
||||
| `--launcher` | 'none', 'pytorch', 'slurm', 'mpi' | Options for job launcher. |
|
||||
|
@ -9,6 +9,7 @@ import warnings
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv import Config, DictAction
|
||||
from mmcv.runner import get_dist_info, init_dist, set_random_seed
|
||||
from mmcv.utils import get_git_hash
|
||||
@ -52,6 +53,10 @@ def parse_args():
|
||||
help='id of gpu to use '
|
||||
'(only applicable to non-distributed training)')
|
||||
parser.add_argument('--seed', type=int, default=None, help='Random seed.')
|
||||
parser.add_argument(
|
||||
'--diff_seed',
|
||||
action='store_true',
|
||||
help='Whether or not set different seeds for different ranks')
|
||||
parser.add_argument(
|
||||
'--deterministic',
|
||||
action='store_true',
|
||||
@ -170,6 +175,7 @@ def main():
|
||||
|
||||
# set random seeds
|
||||
seed = init_random_seed(args.seed)
|
||||
seed = seed + dist.get_rank() if args.diff_seed else seed
|
||||
logger.info(f'Set random seed to {seed}, '
|
||||
f'deterministic: {args.deterministic}')
|
||||
set_random_seed(seed, deterministic=args.deterministic)
|
||||
|
Loading…
x
Reference in New Issue
Block a user