mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
generator seed fix for DDP mAP drop (#9545)
* Try to fix DDP mAP drop by setting generator's seed to RANK
* Fix default activation bug
* Update dataloaders.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Update dataloaders.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
c8e52304cf
commit
f11a8a62d2
@ -40,13 +40,13 @@ def autopad(k, p=None, d=1): # kernel, padding, dilation
|
|||||||
|
|
||||||
class Conv(nn.Module):
|
class Conv(nn.Module):
|
||||||
# Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
|
# Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
|
||||||
act = nn.SiLU() # default activation
|
default_act = nn.SiLU() # default activation
|
||||||
|
|
||||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
|
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
|
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
|
||||||
self.bn = nn.BatchNorm2d(c2)
|
self.bn = nn.BatchNorm2d(c2)
|
||||||
self.act = self.act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.act(self.bn(self.conv(x)))
|
return self.act(self.bn(self.conv(x)))
|
||||||
|
@ -301,7 +301,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
|
|||||||
LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
|
LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
|
||||||
anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
|
anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
|
||||||
if act:
|
if act:
|
||||||
Conv.act = eval(act) # redefine default activation, i.e. Conv.act = nn.SiLU()
|
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
|
||||||
LOGGER.info(f"{colorstr('activation:')} {act}") # print
|
LOGGER.info(f"{colorstr('activation:')} {act}") # print
|
||||||
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
|
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
|
||||||
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
|
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
|
||||||
|
@ -40,6 +40,7 @@ IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp',
|
|||||||
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
|
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
|
||||||
BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
|
BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
|
||||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||||
|
RANK = int(os.getenv('RANK', -1))
|
||||||
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
|
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
|
||||||
|
|
||||||
# Get orientation exif tag
|
# Get orientation exif tag
|
||||||
@ -139,7 +140,7 @@ def create_dataloader(path,
|
|||||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||||
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
||||||
generator = torch.Generator()
|
generator = torch.Generator()
|
||||||
generator.manual_seed(0)
|
generator.manual_seed(6148914691236517205 + RANK)
|
||||||
return loader(dataset,
|
return loader(dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=shuffle and sampler is None,
|
shuffle=shuffle and sampler is None,
|
||||||
@ -1169,7 +1170,7 @@ def create_classification_dataloader(path,
|
|||||||
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
|
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
|
||||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||||
generator = torch.Generator()
|
generator = torch.Generator()
|
||||||
generator.manual_seed(0)
|
generator.manual_seed(6148914691236517205 + RANK)
|
||||||
return InfiniteDataLoader(dataset,
|
return InfiniteDataLoader(dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=shuffle and sampler is None,
|
shuffle=shuffle and sampler is None,
|
||||||
|
@ -17,6 +17,8 @@ from ..general import LOGGER, xyn2xy, xywhn2xyxy, xyxy2xywhn
|
|||||||
from ..torch_utils import torch_distributed_zero_first
|
from ..torch_utils import torch_distributed_zero_first
|
||||||
from .augmentations import mixup, random_perspective
|
from .augmentations import mixup, random_perspective
|
||||||
|
|
||||||
|
RANK = int(os.getenv('RANK', -1))
|
||||||
|
|
||||||
|
|
||||||
def create_dataloader(path,
|
def create_dataloader(path,
|
||||||
imgsz,
|
imgsz,
|
||||||
@ -61,8 +63,8 @@ def create_dataloader(path,
|
|||||||
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
|
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||||
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
||||||
# generator = torch.Generator()
|
generator = torch.Generator()
|
||||||
# generator.manual_seed(0)
|
generator.manual_seed(6148914691236517205 + RANK)
|
||||||
return loader(
|
return loader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
@ -72,7 +74,7 @@ def create_dataloader(path,
|
|||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
collate_fn=LoadImagesAndLabelsAndMasks.collate_fn4 if quad else LoadImagesAndLabelsAndMasks.collate_fn,
|
collate_fn=LoadImagesAndLabelsAndMasks.collate_fn4 if quad else LoadImagesAndLabelsAndMasks.collate_fn,
|
||||||
worker_init_fn=seed_worker,
|
worker_init_fn=seed_worker,
|
||||||
# generator=generator,
|
generator=generator,
|
||||||
), dataset
|
), dataset
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user