generator seed fix for DDP mAP drop (#9545)
* Try to fix DDP mAP drop by setting generator's seed to RANK * Fix default activation bug * Update dataloaders.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update dataloaders.py Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/9573/head
parent
c8e52304cf
commit
f11a8a62d2
|
@ -40,13 +40,13 @@ def autopad(k, p=None, d=1): # kernel, padding, dilation
|
|||
|
||||
class Conv(nn.Module):
|
||||
# Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
|
||||
act = nn.SiLU() # default activation
|
||||
default_act = nn.SiLU() # default activation
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
|
||||
self.bn = nn.BatchNorm2d(c2)
|
||||
self.act = self.act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
||||
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.bn(self.conv(x)))
|
||||
|
|
|
@ -301,7 +301,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
|
|||
LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
|
||||
anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
|
||||
if act:
|
||||
Conv.act = eval(act) # redefine default activation, i.e. Conv.act = nn.SiLU()
|
||||
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
|
||||
LOGGER.info(f"{colorstr('activation:')} {act}") # print
|
||||
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
|
||||
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
|
||||
|
|
|
@ -40,6 +40,7 @@ IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp',
|
|||
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
|
||||
BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
|
||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
|
||||
|
||||
# Get orientation exif tag
|
||||
|
@ -139,7 +140,7 @@ def create_dataloader(path,
|
|||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(0)
|
||||
generator.manual_seed(6148914691236517205 + RANK)
|
||||
return loader(dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle and sampler is None,
|
||||
|
@ -1169,7 +1170,7 @@ def create_classification_dataloader(path,
|
|||
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
|
||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(0)
|
||||
generator.manual_seed(6148914691236517205 + RANK)
|
||||
return InfiniteDataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle and sampler is None,
|
||||
|
|
|
@ -17,6 +17,8 @@ from ..general import LOGGER, xyn2xy, xywhn2xyxy, xyxy2xywhn
|
|||
from ..torch_utils import torch_distributed_zero_first
|
||||
from .augmentations import mixup, random_perspective
|
||||
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
|
||||
|
||||
def create_dataloader(path,
|
||||
imgsz,
|
||||
|
@ -61,8 +63,8 @@ def create_dataloader(path,
|
|||
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
||||
# generator = torch.Generator()
|
||||
# generator.manual_seed(0)
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(6148914691236517205 + RANK)
|
||||
return loader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
|
@ -72,7 +74,7 @@ def create_dataloader(path,
|
|||
pin_memory=True,
|
||||
collate_fn=LoadImagesAndLabelsAndMasks.collate_fn4 if quad else LoadImagesAndLabelsAndMasks.collate_fn,
|
||||
worker_init_fn=seed_worker,
|
||||
# generator=generator,
|
||||
generator=generator,
|
||||
), dataset
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue