assert torch!=1.12.0 for DDP training (#8621)

* assert torch!=1.12.0 for DDP training

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher 2022-07-18 15:05:58 +02:00 committed by GitHub
parent 51fb467b63
commit 9cf5fd5ac3
3 changed files with 24 additions and 12 deletions

requirements.txt

@@ -9,8 +9,8 @@ Pillow>=7.1.2
 PyYAML>=5.3.1
 requests>=2.23.0
 scipy>=1.4.1
-torch>=1.7.0,!=1.12.0 # https://github.com/ultralytics/yolov5/issues/8395
-torchvision>=0.8.1,!=0.13.0 # https://github.com/ultralytics/yolov5/issues/8395
+torch>=1.7.0
+torchvision>=0.8.1
 tqdm>=4.64.0
 protobuf<4.21.3 # https://github.com/ultralytics/yolov5/issues/8012
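
Note on the requirements change: the !=1.12.0 / !=0.13.0 exclusions are dropped because the constraint is now enforced at runtime by the smart_DDP() assert added below, and only when DDP is actually used. As a rough illustration of the pinned-version comparison that assert relies on, here is a minimal sketch in the spirit of check_version (the name check_version_sketch is hypothetical; this is not the actual utils/general.py implementation):

# Illustrative check_version-style helper (a sketch, not the yolov5 utils/general.py implementation)
import pkg_resources as pkg

def check_version_sketch(current='0.0.0', minimum='0.0.0', pinned=False):
    # pinned=True requires an exact version match; otherwise require current >= minimum
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    return current == minimum if pinned else current >= minimum

print(check_version_sketch('1.12.0', '1.12.0', pinned=True))  # True -> the smart_DDP assert would fail
print(check_version_sketch('1.12.1', '1.12.0', pinned=True))  # False -> DDP allowed
print(check_version_sketch('1.11.0', '1.11.0'))               # True -> static_graph=True branch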

train.py

@@ -27,7 +27,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import yaml
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import lr_scheduler
 from tqdm import tqdm
@@ -46,15 +45,15 @@ from utils.callbacks import Callbacks
 from utils.dataloaders import create_dataloader
 from utils.downloads import attempt_download
 from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
-                           check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run,
-                           increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
-                           labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer)
+                           check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
+                           init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
+                           one_cycle, print_args, print_mutation, strip_optimizer)
 from utils.loggers import Loggers
 from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.loss import ComputeLoss
 from utils.metrics import fitness
 from utils.plots import plot_evolve, plot_labels
-from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_optimizer,
+from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
                                torch_distributed_zero_first)
 
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
@@ -248,10 +247,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
 
     # DDP mode
     if cuda and RANK != -1:
-        if check_version(torch.__version__, '1.11.0'):
-            model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
-        else:
-            model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
+        model = smart_DDP(model)
 
     # Model attributes
     nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
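
For context on the gate above (if cuda and RANK != -1): RANK and LOCAL_RANK are environment variables set by torchrun, and when they are absent the -1 defaults keep training in single-device mode so the model is never wrapped. A small sketch of how the values behave (the print line is illustrative):

# Sketch: the environment variables that gate DDP mode (set by torchrun, absent in single-device runs)
import os

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # GPU index of this process on its node
RANK = int(os.getenv('RANK', -1))              # global process index; -1 means not launched by torchrun
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))   # total number of processes

print(f'RANK={RANK} LOCAL_RANK={LOCAL_RANK} WORLD_SIZE={WORLD_SIZE}')
# RANK == -1  -> single-GPU/CPU training, model left unwrapped
# RANK != -1  -> DDP mode, model = smart_DDP(model) as in the hunk above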

utils/torch_utils.py

@@ -17,8 +17,13 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP
 
-from utils.general import LOGGER, colorstr, file_date, git_describe
+from utils.general import LOGGER, check_version, colorstr, file_date, git_describe
+
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv('RANK', -1))
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 
 try:
     import thop  # for FLOPs computation
@@ -29,6 +34,17 @@ except ImportError:
 warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
 
 
+def smart_DDP(model):
+    # Model DDP creation with checks
+    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
+        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
+        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
+    if check_version(torch.__version__, '1.11.0'):
+        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
+    else:
+        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
+
+
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
     # Decorator to make all processes in distributed training wait for each local_master to do something
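
Usage note: smart_DDP is now the single call site for DDP wrapping (see the train.py hunk above), so the torch 1.12.0 guard and the static_graph branching live in one place. Below is a minimal sketch of exercising it from a torchrun launch; the toy model, the nccl backend choice, and the script name are illustrative assumptions, not part of this commit, and the import assumes the yolov5 repo root is on PYTHONPATH:

# Sketch: calling smart_DDP under torchrun, e.g. torchrun --nproc_per_node=2 ddp_demo.py
# (ddp_demo.py is a hypothetical file name)
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from utils.torch_utils import smart_DDP  # helper added in this commit

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))

if torch.cuda.is_available() and LOCAL_RANK != -1:
    torch.cuda.set_device(LOCAL_RANK)
    dist.init_process_group(backend='nccl')  # nccl assumed for CUDA; adjust for your setup
    model = nn.Conv2d(3, 16, 3).to(LOCAL_RANK)  # toy model
    model = smart_DDP(model)  # AssertionError on torch==1.12.0; static_graph=True on torch>=1.11.0
    dist.destroy_process_group()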