assert torch!=1.12.0 for DDP training (#8621)
* assert torch!=1.12.0 for DDP training * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/8624/head
parent
51fb467b63
commit
9cf5fd5ac3
|
@ -9,8 +9,8 @@ Pillow>=7.1.2
|
|||
PyYAML>=5.3.1
|
||||
requests>=2.23.0
|
||||
scipy>=1.4.1
|
||||
torch>=1.7.0,!=1.12.0 # https://github.com/ultralytics/yolov5/issues/8395
|
||||
torchvision>=0.8.1,!=0.13.0 # https://github.com/ultralytics/yolov5/issues/8395
|
||||
torch>=1.7.0
|
||||
torchvision>=0.8.1
|
||||
tqdm>=4.64.0
|
||||
protobuf<4.21.3 # https://github.com/ultralytics/yolov5/issues/8012
|
||||
|
||||
|
|
14
train.py
14
train.py
|
@ -27,7 +27,6 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import yaml
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import lr_scheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
|
@ -46,15 +45,15 @@ from utils.callbacks import Callbacks
|
|||
from utils.dataloaders import create_dataloader
|
||||
from utils.downloads import attempt_download
|
||||
from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
|
||||
check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run,
|
||||
increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
|
||||
labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer)
|
||||
check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
|
||||
init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
|
||||
one_cycle, print_args, print_mutation, strip_optimizer)
|
||||
from utils.loggers import Loggers
|
||||
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
||||
from utils.loss import ComputeLoss
|
||||
from utils.metrics import fitness
|
||||
from utils.plots import plot_evolve, plot_labels
|
||||
from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_optimizer,
|
||||
from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
|
||||
torch_distributed_zero_first)
|
||||
|
||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
|
@ -248,10 +247,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|||
|
||||
# DDP mode
|
||||
if cuda and RANK != -1:
|
||||
if check_version(torch.__version__, '1.11.0'):
|
||||
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
|
||||
else:
|
||||
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
||||
model = smart_DDP(model)
|
||||
|
||||
# Model attributes
|
||||
nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
|
||||
|
|
|
@ -17,8 +17,13 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from utils.general import LOGGER, colorstr, file_date, git_describe
|
||||
from utils.general import LOGGER, check_version, colorstr, file_date, git_describe
|
||||
|
||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
|
||||
|
||||
try:
|
||||
import thop # for FLOPs computation
|
||||
|
@ -29,6 +34,17 @@ except ImportError:
|
|||
warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
|
||||
|
||||
|
||||
def smart_DDP(model):
|
||||
# Model DDP creation with checks
|
||||
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
|
||||
'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
|
||||
'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
|
||||
if check_version(torch.__version__, '1.11.0'):
|
||||
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
|
||||
else:
|
||||
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def torch_distributed_zero_first(local_rank: int):
|
||||
# Decorator to make all processes in distributed training wait for each local_master to do something
|
||||
|
|
Loading…
Reference in New Issue