mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Add checkpoint helper (#93)
* [Feature] Add checkpoint helper * minor fix * add comments and fix format * add hub directory * add missing docstring
This commit is contained in:
parent
ac666711ab
commit
8b0e74d327
1
MANIFEST.in
Normal file
1
MANIFEST.in
Normal file
@ -0,0 +1 @@
|
||||
include mmengine/hub/openmmlab.json mmengine/hub/deprecated.json mmengine/hub/mmcls.json
|
@ -7,4 +7,5 @@ from .fileio import *
|
||||
from .hooks import *
|
||||
from .logging import *
|
||||
from .registry import *
|
||||
from .runner import *
|
||||
from .utils import *
|
||||
|
6
mmengine/hub/deprecated.json
Normal file
6
mmengine/hub/deprecated.json
Normal file
@ -0,0 +1,6 @@
|
||||
{
|
||||
"resnet50_caffe": "detectron/resnet50_caffe",
|
||||
"resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr",
|
||||
"resnet101_caffe": "detectron/resnet101_caffe",
|
||||
"resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr"
|
||||
}
|
59
mmengine/hub/mmcls.json
Normal file
59
mmengine/hub/mmcls.json
Normal file
@ -0,0 +1,59 @@
|
||||
{
|
||||
"vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth",
|
||||
"vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth",
|
||||
"vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth",
|
||||
"vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth",
|
||||
"vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth",
|
||||
"vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth",
|
||||
"vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth",
|
||||
"vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
|
||||
"resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth",
|
||||
"resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth",
|
||||
"resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth",
|
||||
"resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth",
|
||||
"resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth",
|
||||
"resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth",
|
||||
"resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth",
|
||||
"resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth",
|
||||
"resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth",
|
||||
"resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth",
|
||||
"resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth",
|
||||
"resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth",
|
||||
"se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth",
|
||||
"se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth",
|
||||
"resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth",
|
||||
"resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth",
|
||||
"resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth",
|
||||
"resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth",
|
||||
"shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth",
|
||||
"shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth",
|
||||
"mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth",
|
||||
"mobilenet_v3_small": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_small-8427ecf0.pth",
|
||||
"mobilenet_v3_large": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth",
|
||||
"repvgg_A0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_3rdparty_4xb64-coslr-120e_in1k_20210909-883ab98c.pth",
|
||||
"repvgg_A1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_3rdparty_4xb64-coslr-120e_in1k_20210909-24003a24.pth",
|
||||
"repvgg_A2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_3rdparty_4xb64-coslr-120e_in1k_20210909-97d7695a.pth",
|
||||
"repvgg_B0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_3rdparty_4xb64-coslr-120e_in1k_20210909-446375f4.pth",
|
||||
"repvgg_B1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_3rdparty_4xb64-coslr-120e_in1k_20210909-750cdf67.pth",
|
||||
"repvgg_B1g2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_3rdparty_4xb64-coslr-120e_in1k_20210909-344f6422.pth",
|
||||
"repvgg_B1g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_3rdparty_4xb64-coslr-120e_in1k_20210909-d4c1a642.pth",
|
||||
"repvgg_B2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_3rdparty_4xb64-coslr-120e_in1k_20210909-bd6b937c.pth",
|
||||
"repvgg_B2g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-7b7955f0.pth",
|
||||
"repvgg_B3": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-dda968bf.pth",
|
||||
"repvgg_B3g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-4e54846a.pth",
|
||||
"repvgg_D2se": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth",
|
||||
"res2net101_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth",
|
||||
"res2net50_w14": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth",
|
||||
"res2net50_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth",
|
||||
"swin_tiny": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth",
|
||||
"swin_small": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth",
|
||||
"swin_base": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth",
|
||||
"swin_large": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth",
|
||||
"t2t_vit_t_14": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth",
|
||||
"t2t_vit_t_19": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth",
|
||||
"t2t_vit_t_24": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth",
|
||||
"tnt_small": "https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth",
|
||||
"vit_base_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth",
|
||||
"vit_base_p32": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth",
|
||||
"vit_large_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-large-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-b20ba619.pth"
|
||||
}
|
50
mmengine/hub/openmmlab.json
Normal file
50
mmengine/hub/openmmlab.json
Normal file
@ -0,0 +1,50 @@
|
||||
{
|
||||
"vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
|
||||
"detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
|
||||
"detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth",
|
||||
"detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
|
||||
"detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
|
||||
"detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth",
|
||||
"resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
|
||||
"resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
|
||||
"resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
|
||||
"contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
|
||||
"detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
|
||||
"detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
|
||||
"jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
|
||||
"jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
|
||||
"jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
|
||||
"jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
|
||||
"jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
|
||||
"jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
|
||||
"msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth",
|
||||
"msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
|
||||
"msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
|
||||
"msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
|
||||
"msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth",
|
||||
"bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
|
||||
"kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
|
||||
"kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
|
||||
"res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
|
||||
"regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth",
|
||||
"regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
|
||||
"regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
|
||||
"regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
|
||||
"regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
|
||||
"regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
|
||||
"regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
|
||||
"regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth",
|
||||
"resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth",
|
||||
"resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth",
|
||||
"resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth",
|
||||
"mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth",
|
||||
"mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth",
|
||||
"mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth",
|
||||
"contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth",
|
||||
"contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth",
|
||||
"resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth",
|
||||
"resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth",
|
||||
"resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth",
|
||||
"darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth",
|
||||
"mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth"
|
||||
}
|
@ -1 +1,11 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .checkpoint import (CheckpointLoader, get_deprecated_model_names,
|
||||
get_external_models, get_mmcls_models, get_state_dict,
|
||||
get_torchvision_models, load_checkpoint,
|
||||
load_state_dict, save_checkpoint, weights_to_cpu)
|
||||
|
||||
__all__ = [
|
||||
'load_state_dict', 'get_torchvision_models', 'get_external_models',
|
||||
'get_mmcls_models', 'get_deprecated_model_names', 'CheckpointLoader',
|
||||
'load_checkpoint', 'weights_to_cpu', 'get_state_dict', 'save_checkpoint'
|
||||
]
|
||||
|
697
mmengine/runner/checkpoint.py
Normal file
697
mmengine/runner/checkpoint.py
Normal file
@ -0,0 +1,697 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import io
|
||||
import os
|
||||
import os.path as osp
|
||||
import pkgutil
|
||||
import re
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from importlib import import_module
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Callable, Dict
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
import mmengine
|
||||
from mmengine.dist import get_dist_info
|
||||
from mmengine.fileio import FileClient
|
||||
from mmengine.fileio import load as load_file
|
||||
from mmengine.model import is_model_wrapper
|
||||
from mmengine.utils import load_url, mkdir_or_exist
|
||||
|
||||
# `MMENGINE_HOME` is the highest priority directory to save checkpoints
|
||||
# downloaded from Internet. If it is not set, as a workaround, using
|
||||
# `XDG_CACHE_HOME`` or `~/.cache` instead.
|
||||
# Note that `XDG_CACHE_HOME` defines the base directory relative to which
|
||||
# user-specific non-essential data files should be stored. If `XDG_CACHE_HOME`
|
||||
# is either not set or empty, a default equal to `~/.cache` should be used.
|
||||
ENV_MMENGINE_HOME = 'MMENGINE_HOME'
|
||||
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
|
||||
DEFAULT_CACHE_DIR = '~/.cache'
|
||||
|
||||
|
||||
def _get_mmengine_home():
    """Return (and create) the root directory for cached checkpoints.

    Resolution order: the ``MMENGINE_HOME`` environment variable, then
    ``$XDG_CACHE_HOME/mmengine``, finally falling back to
    ``~/.cache/mmengine``.
    """
    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    default_home = os.path.join(cache_root, 'mmengine')
    mmengine_home = os.path.expanduser(
        os.getenv(ENV_MMENGINE_HOME, default_home))
    # Ensure the directory exists before any download writes into it.
    mkdir_or_exist(mmengine_home)
    return mmengine_home
|
||||
|
||||
|
||||
def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    # Accumulators shared (via closure) by the recursive ``load`` below.
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    # Preserve ``_metadata`` across the shallow copy: torch attaches it to
    # state dicts for per-module version information.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_model_wrapper(module):
            module = module.module
        # ``prefix[:-1]`` drops the trailing '.' to look up this module's
        # own metadata entry.
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    # Only rank 0 reports, to avoid duplicated messages in distributed runs.
    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)
|
||||
|
||||
|
||||
def get_torchvision_models():
    """Collect the ``model_urls`` dicts from all torchvision model modules.

    Returns:
        dict: Mapping from torchvision model names to checkpoint URLs.
    """
    model_urls = dict()
    for _, module_name, is_pkg in pkgutil.walk_packages(
            torchvision.models.__path__):
        # Sub-packages hold no ``model_urls``; only leaf modules do.
        if is_pkg:
            continue
        zoo_module = import_module(f'torchvision.models.{module_name}')
        urls = getattr(zoo_module, 'model_urls', None)
        if urls is not None:
            model_urls.update(urls)
    return model_urls
|
||||
|
||||
|
||||
def get_external_models():
    """Return the mapping from external model names to checkpoint URLs.

    Starts from the bundled ``hub/openmmlab.json`` and overlays any user
    entries found in ``open_mmlab.json`` under the mmengine home directory,
    so user-provided URLs take precedence.
    """
    default_json_path = osp.join(mmengine.__path__[0], 'hub/openmmlab.json')
    default_urls = load_file(default_json_path)
    assert isinstance(default_urls, dict)

    user_json_path = osp.join(_get_mmengine_home(), 'open_mmlab.json')
    if osp.exists(user_json_path):
        user_urls = load_file(user_json_path)
        assert isinstance(user_urls, dict)
        default_urls.update(user_urls)
    return default_urls
|
||||
|
||||
|
||||
def get_mmcls_models():
    """Return the mapping from mmcls model names to checkpoint URLs,
    loaded from the bundled ``hub/mmcls.json``."""
    json_path = osp.join(mmengine.__path__[0], 'hub/mmcls.json')
    return load_file(json_path)
|
||||
|
||||
|
||||
def get_deprecated_model_names():
    """Return the mapping from deprecated model names to their replacements,
    loaded from the bundled ``hub/deprecated.json``."""
    json_path = osp.join(mmengine.__path__[0], 'hub/deprecated.json')
    deprecated_urls = load_file(json_path)
    assert isinstance(deprecated_urls, dict)
    return deprecated_urls
|
||||
|
||||
|
||||
def _process_mmcls_checkpoint(checkpoint):
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
# Some checkpoints converted from 3rd-party repo don't
|
||||
# have the "state_dict" key.
|
||||
state_dict = checkpoint
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith('backbone.'):
|
||||
new_state_dict[k[9:]] = v
|
||||
new_checkpoint = dict(state_dict=new_state_dict)
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
class CheckpointLoader:
    """A general checkpoint loader to manage all schemes."""

    # Maps a path prefix (matched as a regex against the filename) to its
    # loader function; kept sorted so longer prefixes are tried first.
    _schemes: Dict[str, Callable] = {}

    @classmethod
    def _register_scheme(cls, prefixes, loader, force=False):
        """Insert ``loader`` under each prefix, re-sorting the registry.

        Raises:
            KeyError: If a prefix is already registered and ``force`` is
                False.
        """
        if isinstance(prefixes, str):
            prefixes = [prefixes]
        else:
            assert isinstance(prefixes, (list, tuple))
        for prefix in prefixes:
            if (prefix not in cls._schemes) or force:
                cls._schemes[prefix] = loader
            else:
                raise KeyError(
                    f'{prefix} is already registered as a loader backend, '
                    'add "force=True" if you want to override it')
        # sort, longer prefixes take priority
        cls._schemes = OrderedDict(
            sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))

    @classmethod
    def register_scheme(cls, prefixes, loader=None, force=False):
        """Register a loader to CheckpointLoader.

        This method can be used as a normal class method or a decorator.

        Args:
            prefixes (str or list[str] or tuple[str]):
                The prefix of the registered loader.
            loader (function, optional): The loader function to be registered.
                When this method is used as a decorator, loader is None.
                Defaults to None.
            force (bool, optional): Whether to override the loader
                if the prefix has already been registered. Defaults to False.
        """
        if loader is not None:
            cls._register_scheme(prefixes, loader, force=force)
            return

        # Decorator form: defer registration until the function is defined.
        def _register(loader_cls):
            cls._register_scheme(prefixes, loader_cls, force=force)
            return loader_cls

        return _register

    @classmethod
    def _get_checkpoint_loader(cls, path):
        """Finds a loader that supports the given path. Falls back to the
        local loader if no other loader is found.

        Args:
            path (str): checkpoint path

        Returns:
            callable: checkpoint loader
        """
        for p in cls._schemes:
            # use regular match to handle some cases that where the prefix of
            # loader has a prefix. For example, both 's3://path' and
            # 'open-mmlab:s3://path' should return `load_from_ceph`
            if re.match(p, path) is not None:
                return cls._schemes[p]

    @classmethod
    def load_checkpoint(cls, filename, map_location=None, logger=None):
        """load checkpoint through URL scheme path.

        Args:
            filename (str): checkpoint file name with given prefix
            map_location (str, optional): Same as :func:`torch.load`.
                Default: None
            logger (:mod:`logging.Logger`, optional): The logger for message.
                Default: None

        Returns:
            dict or OrderedDict: The loaded checkpoint.
        """
        checkpoint_loader = cls._get_checkpoint_loader(filename)
        # ``[10:]`` strips the 'load_from_' prefix of the loader's name.
        class_name = checkpoint_loader.__name__
        # Fix: the log message f-string was missing the path placeholder.
        mmengine.print_log(
            f'{class_name[10:]} loads checkpoint from path: {filename}',
            logger)
        return checkpoint_loader(filename, map_location)
|
||||
|
||||
|
||||
@CheckpointLoader.register_scheme(prefixes='')
def load_from_local(filename, map_location):
    """load checkpoint by local file path.

    Args:
        filename (str): local checkpoint file path
        map_location (str, optional): Same as :func:`torch.load`.

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        FileNotFoundError: If ``filename`` does not point to a regular file.
    """
    filename = osp.expanduser(filename)
    if not osp.isfile(filename):
        # Fix: include the offending path in the message; the original
        # f-string had no placeholder.
        raise FileNotFoundError(f'{filename} can not be found.')
    checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint
|
||||
|
||||
|
||||
@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
def load_from_http(filename, map_location=None, model_dir=None):
    """load checkpoint through HTTP or HTTPS scheme path. In distributed
    setting, this function only download checkpoint at local rank 0.

    Args:
        filename (str): checkpoint file path with modelzoo or
            torchvision prefix
        map_location (str, optional): Same as :func:`torch.load`.
        model_dir (string, optional): directory in which to save the object,
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    rank, world_size = get_dist_info()
    if rank == 0:
        # Rank 0 downloads first so the checkpoint cache is populated
        # before any other rank tries to read it.
        checkpoint = load_url(
            filename, model_dir=model_dir, map_location=map_location)
    if world_size > 1:
        # Wait until rank 0 has finished downloading, then the remaining
        # ranks load from the (now warm) cache.
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = load_url(
                filename, model_dir=model_dir, map_location=map_location)
    return checkpoint
|
||||
|
||||
|
||||
@CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(filename, map_location=None):
    """load checkpoint through the file path prefixed with pavi. In
    distributed setting, this function download ckpt at all ranks to
    different temporary directories.

    Args:
        filename (str): checkpoint file path with pavi prefix
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        ImportError: If the optional ``pavi`` package is not installed.
    """
    # Fix: the assertion message f-string was missing the path placeholder.
    assert filename.startswith('pavi://'), \
        f'Expected filename startswith `pavi://`, but get {filename}'
    # Strip the 'pavi://' scheme (7 characters) to get the model path.
    model_path = filename[7:]

    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')

    model = modelcloud.get(model_path)
    with TemporaryDirectory() as tmp_dir:
        # A per-process temporary directory keeps concurrent ranks from
        # clobbering each other's downloads.
        downloaded_file = osp.join(tmp_dir, model.name)
        model.download(downloaded_file)
        checkpoint = torch.load(downloaded_file, map_location=map_location)
    return checkpoint
|
||||
|
||||
|
||||
@CheckpointLoader.register_scheme(prefixes=r'(\S+\:)?s3://')
def load_from_ceph(filename, map_location=None, backend='petrel'):
    """load checkpoint through the file path prefixed with s3. In distributed
    setting, this function download ckpt at all ranks to different temporary
    directories.

    Args:
        filename (str): checkpoint file path with s3 prefix
        map_location (str, optional): Same as :func:`torch.load`.
        backend (str, optional): The storage backend type. Options are 'ceph',
            'petrel'. Default: 'petrel'.

    .. warning::
        :class:`mmengine.fileio.file_client.CephBackend` will be deprecated,
        please use :class:`mmengine.fileio.file_client.PetrelBackend` instead.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    allowed_backends = ['ceph', 'petrel']
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')

    if backend == 'ceph':
        warnings.warn(
            'CephBackend will be deprecated, please use PetrelBackend instead',
            DeprecationWarning)

    # Both backends claim the 's3://' prefix. Try the requested one first;
    # if it cannot be instantiated (missing dependency), fall back to the
    # other.
    fallbacks = [name for name in allowed_backends if name != backend]
    try:
        file_client = FileClient(backend=backend)
    except ImportError:
        file_client = FileClient(backend=fallbacks[0])

    with io.BytesIO(file_client.get(filename)) as buffer:
        checkpoint = torch.load(buffer, map_location=map_location)
    return checkpoint
|
||||
|
||||
|
||||
@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
def load_from_torchvision(filename, map_location=None):
    """load checkpoint through the file path prefixed with modelzoo or
    torchvision.

    Args:
        filename (str): checkpoint file path with modelzoo or
            torchvision prefix
        map_location (str, optional): Same as :func:`torch.load`.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    model_urls = get_torchvision_models()
    prefix = 'torchvision://'
    if filename.startswith('modelzoo://'):
        warnings.warn(
            'The URL scheme of "modelzoo://" is deprecated, please '
            'use "torchvision://" instead', DeprecationWarning)
        prefix = 'modelzoo://'
    model_name = filename[len(prefix):]
    return load_from_http(model_urls[model_name], map_location=map_location)
|
||||
|
||||
|
||||
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
def load_from_openmmlab(filename, map_location=None):
    """load checkpoint through the file path prefixed with open-mmlab or
    openmmlab.

    Args:
        filename (str): checkpoint file path with open-mmlab or
            openmmlab prefix
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        FileNotFoundError: If the mapped local checkpoint does not exist.
    """
    model_urls = get_external_models()
    prefix_str = 'open-mmlab://'
    if filename.startswith(prefix_str):
        model_name = filename[len(prefix_str):]
    else:
        prefix_str = 'openmmlab://'
        model_name = filename[len(prefix_str):]

    # Deprecated names are transparently redirected to their successors,
    # with a warning so callers can update their configs.
    deprecated_urls = get_deprecated_model_names()
    if model_name in deprecated_urls:
        warnings.warn(
            f'{prefix_str}{model_name} is deprecated in favor '
            f'of {prefix_str}{deprecated_urls[model_name]}',
            DeprecationWarning)
        model_name = deprecated_urls[model_name]
    model_url = model_urls[model_name]
    # check if is url
    if model_url.startswith(('http://', 'https://')):
        checkpoint = load_from_http(model_url, map_location=map_location)
    else:
        # Non-URL entries are paths relative to the mmengine home dir.
        filename = osp.join(_get_mmengine_home(), model_url)
        if not osp.isfile(filename):
            # Fix: include the resolved path in the message; the original
            # f-string had no placeholder.
            raise FileNotFoundError(f'{filename} can not be found.')
        checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint
|
||||
|
||||
|
||||
@CheckpointLoader.register_scheme(prefixes='mmcls://')
def load_from_mmcls(filename, map_location=None):
    """load checkpoint through the file path prefixed with mmcls.

    Args:
        filename (str): checkpoint file path with mmcls prefix
        map_location (str, optional): Same as :func:`torch.load`.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    model_urls = get_mmcls_models()
    model_name = filename[len('mmcls://'):]
    raw_checkpoint = load_from_http(
        model_urls[model_name], map_location=map_location)
    # Strip the "backbone." prefix so the weights fit a bare backbone.
    return _process_mmcls_checkpoint(raw_checkpoint)
|
||||
|
||||
|
||||
def _load_checkpoint(filename, map_location=None, logger=None):
    """Load checkpoint from somewhere (modelzoo, file, url).

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None.
        logger (:mod:`logging.Logger`, optional): The logger for error message.
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint. It can be either an
        OrderedDict storing model weights or a dict containing other
        information, which depends on the checkpoint.
    """
    # All scheme dispatch (http, torchvision, open-mmlab, local path, ...)
    # lives in CheckpointLoader; this is just the module-private entry point.
    return CheckpointLoader.load_checkpoint(filename, map_location, logger)
|
||||
|
||||
|
||||
def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
    """Load partial pretrained model with specific prefix.

    Args:
        prefix (str): The prefix of sub-module.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str | None): Same as :func:`torch.load`. Default: None.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """

    checkpoint = _load_checkpoint(filename, map_location=map_location)

    # Checkpoints may either be a raw state dict or wrap one under
    # the 'state_dict' key.
    if 'state_dict' in checkpoint:
        full_state = checkpoint['state_dict']
    else:
        full_state = checkpoint

    # Normalize the prefix so 'backbone' and 'backbone.' behave the same.
    if not prefix.endswith('.'):
        prefix = prefix + '.'
    cut = len(prefix)

    # Keep only the sub-module's weights, with the prefix stripped.
    selected = {
        key[cut:]: weight
        for key, weight in full_state.items() if key.startswith(prefix)
    }

    assert selected, f'{prefix} is not in the pretrained model'
    return selected
|
||||
|
||||
|
||||
def _load_checkpoint_to_model(model,
                              checkpoint,
                              strict=False,
                              logger=None,
                              revise_keys=[(r'^module\.', '')]):
    """Apply a loaded checkpoint to ``model`` after key normalization.

    Args:
        model (nn.Module): Module to receive the weights.
        checkpoint (dict): A loaded checkpoint, either a raw state dict or a
            dict holding one under the 'state_dict' key.
        strict (bool): Whether to require an exact key match. Default: False.
        logger (:mod:`logging.Logger`, optional): Logger for error messages.
        revise_keys (list): (pattern, replacement) regex pairs applied to
            every key. Default strips a leading 'module.' (DataParallel).

    Returns:
        dict: The (unmodified) input ``checkpoint``.
    """

    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    # strip prefix of state_dict
    metadata = getattr(state_dict, '_metadata', OrderedDict())
    for pattern, replacement in revise_keys:
        state_dict = OrderedDict(
            (re.sub(pattern, replacement, key), value)
            for key, value in state_dict.items())
    # Keep metadata in state_dict: rebuilding the OrderedDict above drops
    # the attribute, and torch's loading machinery reads it for versioning.
    state_dict._metadata = metadata

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
|
||||
|
||||
|
||||
def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None,
                    revise_keys=[(r'^module\.', '')]):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.
        revise_keys (list): A list of customized keywords to modify the
            state_dict in checkpoint. Each item is a (pattern, replacement)
            pair of the regular expression operations. Default: strip
            the prefix 'module.' by [(r'^module\\.', '')].

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = _load_checkpoint(filename, map_location, logger)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        # BUG FIX: the f-string previously contained no placeholder, so the
        # offending file path never appeared in the error message.
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    return _load_checkpoint_to_model(model, checkpoint, strict, logger,
                                     revise_keys)
|
||||
|
||||
|
||||
def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on CPU.
    """
    # DOC FIX: the Returns section previously said "on GPU"; every value is
    # explicitly moved to CPU below.
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    # Keep metadata in state_dict: the fresh OrderedDict would otherwise lose
    # the versioning info torch attaches to state dicts.
    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
    return state_dict_cpu
|
||||
|
||||
|
||||
def _save_to_state_dict(module, destination, prefix, keep_vars):
|
||||
"""Saves module state to `destination` dictionary.
|
||||
|
||||
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
|
||||
|
||||
Args:
|
||||
module (nn.Module): The module to generate state_dict.
|
||||
destination (dict): A dict where state will be stored.
|
||||
prefix (str): The prefix for parameters and buffers used in this
|
||||
module.
|
||||
keep_vars (bool): Whether to keep the variable property of the
|
||||
parameters.
|
||||
"""
|
||||
for name, param in module._parameters.items():
|
||||
if param is not None:
|
||||
destination[prefix + name] = param if keep_vars else param.detach()
|
||||
for name, buf in module._buffers.items():
|
||||
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
|
||||
if buf is not None:
|
||||
destination[prefix + name] = buf if keep_vars else buf.detach()
|
||||
|
||||
|
||||
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.
    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_model_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    local_metadata = dict(version=module._version)
    destination._metadata[prefix[:-1]] = local_metadata
    _save_to_state_dict(module, destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is None:
            continue
        get_state_dict(
            child, destination, prefix + name + '.', keep_vars=keep_vars)
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination
|
||||
|
||||
|
||||
def save_checkpoint(checkpoint, filename, file_client_args=None):
    """Save checkpoint to file.

    Args:
        checkpoint (dict): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmengine.fileio.FileClient` for details.
            Defaults to None.
    """
    if filename.startswith('pavi://'):
        # pavi modelcloud manages its own storage; a FileClient would clash.
        if file_client_args is not None:
            raise ValueError(
                'file_client_args should be "None" if filename starts with'
                f'"pavi://", but got {file_client_args}')
        try:
            from pavi import exception, modelcloud
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        cloud_path = filename[7:]
        root_folder = modelcloud.Folder()
        folder_name, ckpt_name = osp.split(cloud_path)
        try:
            training_model = modelcloud.get(folder_name)
        except exception.NodeNotFoundError:
            # Folder does not exist yet on modelcloud; create it first.
            training_model = root_folder.create_training_model(folder_name)
        with TemporaryDirectory() as tmp_dir:
            local_ckpt = osp.join(tmp_dir, ckpt_name)
            with open(local_ckpt, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            training_model.create_file(local_ckpt, name=ckpt_name)
    else:
        # Serialize into memory first so the bytes can be handed to whatever
        # backend (disk, ceph, ...) the FileClient resolves for ``filename``.
        file_client = FileClient.infer_client(file_client_args, filename)
        with io.BytesIO() as buffer:
            torch.save(checkpoint, buffer)
            file_client.put(buffer.getvalue(), filename)
|
@ -45,6 +45,7 @@ def get_priority(priority: Union[int, str, Priority]) -> int:
|
||||
|
||||
Args:
|
||||
priority (int or str or :obj:`Priority`): Priority.
|
||||
|
||||
Returns:
|
||||
int: The priority value.
|
||||
"""
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .hub import load_url
|
||||
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
|
||||
has_method, import_modules_from_strings, is_list_of,
|
||||
is_method_overridden, is_seq_of, is_str, is_tuple_of,
|
||||
@ -19,5 +20,5 @@ __all__ = [
|
||||
'scandir', 'deprecated_api_warning', 'import_modules_from_strings',
|
||||
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
|
||||
'is_method_overridden', 'has_method', 'mmcv_full_available',
|
||||
'digit_version', 'get_git_hash', 'TORCH_VERSION'
|
||||
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url'
|
||||
]
|
||||
|
128
mmengine/utils/hub.py
Normal file
128
mmengine/utils/hub.py
Normal file
@ -0,0 +1,128 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# The 1.6 release of PyTorch switched torch.save to use a new zipfile-based
|
||||
# file format. It will cause RuntimeError when a checkpoint was saved in
|
||||
# torch >= 1.6.0 but loaded in torch < 1.7.0.
|
||||
# More details at https://github.com/open-mmlab/mmpose/issues/904
|
||||
|
||||
from .parrots_wrapper import TORCH_VERSION
|
||||
from .path import mkdir_or_exist
|
||||
from .version_utils import digit_version
|
||||
|
||||
if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
|
||||
'1.7.0'):
|
||||
# Modified from https://github.com/pytorch/pytorch/blob/master/torch/hub.py
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
import zipfile
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
from torch.hub import HASH_REGEX, _get_torch_home, download_url_to_file
|
||||
|
||||
# Hub used to support automatically extracts from zipfile manually
|
||||
# compressed by users. The legacy zip format expects only one file from
|
||||
# torch.save() < 1.6 in the zip. We should remove this support since
|
||||
# zipfile is now default zipfile format for torch.save().
|
||||
def _is_legacy_zip_format(filename):
|
||||
if zipfile.is_zipfile(filename):
|
||||
infolist = zipfile.ZipFile(filename).infolist()
|
||||
return len(infolist) == 1 and not infolist[0].is_dir()
|
||||
return False
|
||||
|
||||
def _legacy_zip_load(filename, model_dir, map_location):
|
||||
warnings.warn(
|
||||
'Falling back to the old format < 1.6. This support will'
|
||||
' be deprecated in favor of default zipfile format '
|
||||
'introduced in 1.6. Please redo torch.save() to save it '
|
||||
'in the new zipfile format.', DeprecationWarning)
|
||||
# Note: extractall() defaults to overwrite file if exists. No need to
|
||||
# clean up beforehand. We deliberately don't handle tarfile here
|
||||
# since our legacy serialization format was in tar.
|
||||
# E.g. resnet18-5c106cde.pth which is widely used.
|
||||
with zipfile.ZipFile(filename) as f:
|
||||
members = f.infolist()
|
||||
if len(members) != 1:
|
||||
raise RuntimeError(
|
||||
'Only one file(not dir) is allowed in the zipfile')
|
||||
f.extractall(model_dir)
|
||||
extraced_name = members[0].filename
|
||||
extracted_file = os.path.join(model_dir, extraced_name)
|
||||
return torch.load(extracted_file, map_location=map_location)
|
||||
|
||||
def load_url(url,
             model_dir=None,
             map_location=None,
             progress=True,
             check_hash=False,
             file_name=None):
    r"""Loads the Torch serialized object at the given URL.

    If downloaded file is a zip file, it will be automatically decompressed
    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
    ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        url (str): URL of the object to download
        model_dir (str, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to
            remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar
            to stderr. Default: True
        check_hash(bool, optional): If True, the filename part of the URL
            should follow the naming convention ``filename-<sha256>.ext``
            where ``<sha256>`` is the first eight or more digits of the
            SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False
        file_name (str, optional): name for the downloaded file. Filename
            from ``url`` will be used if not set. Default: None.

    Example:
        >>> url = ('https://s3.amazonaws.com/pytorch/models/resnet18-5c106'
        ...        'cde.pth')
        >>> state_dict = torch.hub.load_state_dict_from_url(url)
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn(
            'TORCH_MODEL_ZOO is deprecated, please use env '
            'TORCH_HOME instead', DeprecationWarning)

    if model_dir is None:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, 'checkpoints')

    mkdir_or_exist(model_dir)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    if _is_legacy_zip_format(cached_file):
        return _legacy_zip_load(cached_file, model_dir, map_location)

    try:
        return torch.load(cached_file, map_location=map_location)
    except RuntimeError as error:
        if digit_version(TORCH_VERSION) < digit_version('1.5.0'):
            warnings.warn(
                f'If the error is the same as "{cached_file} is a zip '
                'archive (did you mean to use torch.jit.load()?)", you can'
                ' upgrade your torch to 1.5.0 or higher (current torch '
                f'version is {TORCH_VERSION}). The error was raised '
                ' because the checkpoint was saved in torch>=1.6.0 but '
                'loaded in torch<1.5.')
        # BUG FIX: ``raise error`` must sit OUTSIDE the version check; if it
        # were inside the ``if`` branch, a RuntimeError raised under
        # torch>=1.5 would be silently swallowed and the function would
        # return None. This placement matches upstream torch.hub behavior:
        # warn on old torch, but always re-raise.
        raise error
|
||||
else:
|
||||
from torch.utils.model_zoo import load_url # type: ignore # noqa: F401
|
Loading…
x
Reference in New Issue
Block a user