# Copyright (c) Alibaba, Inc. and its affiliates.
import os

import torch
from mmcv.parallel import is_module_wrapper
from mmcv.runner import load_checkpoint as mmcv_load_checkpoint
from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu
from torch.optim import Optimizer

from easycv.file import io
from easycv.utils.constant import CACHE_DIR


def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Filenames starting with ``oss://`` are first copied into ``CACHE_DIR``
    (unless a file with the same base name already exists there) and then
    loaded from that local copy. Every other filename is passed straight to
    mmcv's ``load_checkpoint``.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``, ``oss://xxx``.
            Please refer to ``docs/model_zoo.md`` for details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Forwarded unchanged to mmcv's ``load_checkpoint``.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    local_path = filename

    if filename.startswith('oss://'):
        # The cache entry is keyed by the base name only, so two remote
        # files with the same name share one cache slot.
        local_path = os.path.join(CACHE_DIR, os.path.basename(filename))

        if not os.path.exists(local_path):
            print(f'download checkpoint from {filename} to {local_path}')
            io.copy(filename, local_path)

            # NOTE(review): this barrier is only reached by processes that
            # did not find the cached file. If ranks can disagree on that
            # check, some would wait here while others skip it -- confirm
            # that callers guarantee a consistent view of CACHE_DIR.
            if (torch.distributed.is_available()
                    and torch.distributed.is_initialized()):
                torch.distributed.barrier()

    return mmcv_load_checkpoint(
        model,
        local_path,
        map_location=map_location,
        strict=strict,
        logger=logger)


def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint always has the fields ``meta`` and ``state_dict``. An
    ``optimizer`` field is added only when ``optimizer`` is an
    :obj:`Optimizer` or a dict of optimizers. ``meta`` is stored exactly as
    given (an empty dict when omitted).

    Args:
        model (Module): Module whose params are to be saved. If it is a
            module wrapper, the wrapped ``model.module`` is saved instead.
        filename (str): Checkpoint filename. Its parent directory is
            created when missing; a bare filename without a directory part
            is written relative to the current location.
        optimizer (:obj:`Optimizer` or dict, optional): Optimizer to be
            saved. For a dict, each value's ``state_dict()`` is stored under
            the same key. Any other value is ignored.
        meta (dict, optional): Metadata to be saved in checkpoint.

    Raises:
        TypeError: If ``meta`` is neither ``None`` nor a dict.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')

    out_dir = os.path.dirname(filename)
    # ``dirname`` is '' for a bare filename such as 'ckpt.pth'; indexing
    # ``out_dir[-1]`` would raise IndexError, and there is no directory to
    # create in that case anyway.
    if out_dir:
        if not out_dir.endswith('/'):
            out_dir += '/'
        if not io.isdir(out_dir):
            io.makedirs(out_dir)

    # Save the underlying module, not the parallel wrapper around it.
    if is_module_wrapper(model):
        model = model.module

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }

    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {
            name: optim.state_dict()
            for name, optim in optimizer.items()
        }

    with io.open(filename, 'wb') as ofile:
        torch.save(checkpoint, ofile)