mirror of https://github.com/YifanXu74/MQ-Det.git
63 lines
3.1 KiB
Python
63 lines
3.1 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
|
import os
|
|
import sys
|
|
|
|
from maskrcnn_benchmark.utils.custom_util import _download_url_to_file
|
|
try:
|
|
# from torch.hub import _download_url_to_file
|
|
from torch.hub import urlparse
|
|
from torch.hub import HASH_REGEX
|
|
except ImportError:
|
|
# from torch.utils.model_zoo import _download_url_to_file
|
|
from torch.utils.model_zoo import urlparse
|
|
from torch.utils.model_zoo import HASH_REGEX
|
|
|
|
from maskrcnn_benchmark.utils.comm import is_main_process
|
|
from maskrcnn_benchmark.utils.comm import synchronize
|
|
|
|
|
|
# very similar to https://github.com/pytorch/pytorch/blob/master/torch/utils/model_zoo.py
|
|
# but with a few improvements and modifications
|
|
def cache_url(url, model_dir='model', progress=True):
    r"""Download the file at ``url`` into a local cache directory and return its path.

    If a file with the derived name already exists in ``model_dir`` it is not
    downloaded again. Unlike ``torch.utils.model_zoo.load_url``, nothing is
    deserialized here: the return value is the path of the cached file.

    The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. A hash prefix of
    at least 6 characters extracted with ``HASH_REGEX`` is forwarded to
    ``_download_url_to_file`` as the expected hash; shorter (or absent)
    prefixes are ignored.

    The default ``model_dir`` is the relative directory ``'model'``. Passing
    ``model_dir=None`` instead selects ``$TORCH_MODEL_ZOO`` if it is set,
    otherwise ``$TORCH_HOME/models``, where ``$TORCH_HOME`` defaults to
    ``~/.torch``.

    NOTE: the download is not restricted to the main process; every caller
    that finds the file missing downloads it itself, and all callers then
    reach ``synchronize()`` before returning.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object.
            ``None`` selects the environment-based location described above;
            an empty string means the current working directory.
        progress (bool, optional): whether or not to display a progress bar to stderr

    Returns:
        str: path of the cached file, ``os.path.join(model_dir, <filename>)``.

    Example:
        >>> cached_file = maskrcnn_benchmark.utils.model_zoo.cache_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch"))
        model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models"))
    # An empty model_dir means the current directory; os.makedirs("") would
    # raise FileNotFoundError, so only create non-empty paths.
    if model_dir and not os.path.exists(model_dir):
        os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if filename == "model_final.pkl":
        # workaround as pre-trained Caffe2 models from Detectron all share the
        # same filename, so make the full URL path the filename by replacing
        # / with _
        filename = parts.path.replace("/", "_")
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename)
        if hash_prefix is not None:
            hash_prefix = hash_prefix.group(1)
            # workaround: Caffe2 models don't have a hash, but follow the R-50
            # convention, which matches the hash PyTorch uses. So we skip the
            # hash matching if the hash_prefix is less than 6 characters
            if len(hash_prefix) < 6:
                hash_prefix = None
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    # Called unconditionally so that callers which found the file already
    # cached still take part in the synchronization.
    synchronize()
    return cached_file
|