diff --git a/fastreid/utils/checkpoint.py b/fastreid/utils/checkpoint.py index b31c4c0..d40d545 100644 --- a/fastreid/utils/checkpoint.py +++ b/fastreid/utils/checkpoint.py @@ -11,12 +11,17 @@ from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable import numpy as np import torch import torch.nn as nn -from apex.parallel import DistributedDataParallel from termcolor import colored from torch.nn.parallel import DataParallel from fastreid.utils.file_io import PathManager +try: + from apex.parallel import DistributedDataParallel +except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example if you want to" + "train with DDP") + class _IncompatibleKeys( NamedTuple(