mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Consolidate init_seeds()
(#4849)
This commit is contained in:
parent
302a1b0bb0
commit
84bfa89236
@ -29,7 +29,6 @@ import yaml
|
|||||||
|
|
||||||
from utils.downloads import gsutil_getsize
|
from utils.downloads import gsutil_getsize
|
||||||
from utils.metrics import box_iou, fitness
|
from utils.metrics import box_iou, fitness
|
||||||
from utils.torch_utils import init_torch_seeds
|
|
||||||
|
|
||||||
# Settings
|
# Settings
|
||||||
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
||||||
@ -91,10 +90,13 @@ def set_logging(rank=-1, verbose=True):
|
|||||||
|
|
||||||
|
|
||||||
def init_seeds(seed=0):
|
def init_seeds(seed=0):
|
||||||
# Initialize random number generator (RNG) seeds
|
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
|
||||||
|
# cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
init_torch_seeds(seed)
|
torch.manual_seed(seed)
|
||||||
|
cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
|
||||||
|
|
||||||
|
|
||||||
def get_latest_run(search_dir='.'):
|
def get_latest_run(search_dir='.'):
|
||||||
|
@ -15,7 +15,6 @@ from copy import deepcopy
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -41,15 +40,6 @@ def torch_distributed_zero_first(local_rank: int):
|
|||||||
dist.barrier(device_ids=[0])
|
dist.barrier(device_ids=[0])
|
||||||
|
|
||||||
|
|
||||||
def init_torch_seeds(seed=0):
|
|
||||||
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
if seed == 0: # slower, more reproducible
|
|
||||||
cudnn.benchmark, cudnn.deterministic = False, True
|
|
||||||
else: # faster, less reproducible
|
|
||||||
cudnn.benchmark, cudnn.deterministic = True, False
|
|
||||||
|
|
||||||
|
|
||||||
def date_modified(path=__file__):
|
def date_modified(path=__file__):
|
||||||
# return human-readable file modification date, i.e. '2021-3-26'
|
# return human-readable file modification date, i.e. '2021-3-26'
|
||||||
t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
|
t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user