mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
Merge branch 'master' into pytorch-1.0
This commit is contained in:
commit
afe97d53d6
@ -1,12 +1,23 @@
|
||||
import os.path as osp
|
||||
import pkgutil
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from importlib import import_module
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
from torch.utils import model_zoo
|
||||
|
||||
|
||||
open_mmlab_model_urls = {
|
||||
'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth', # noqa: E501
|
||||
'resnet50_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_caffe-788b5fa3.pth', # noqa: E501
|
||||
'resnet101_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_caffe-3ad79236.pth', # noqa: E501
|
||||
'resnext101_32x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d-a5af3160.pth', # noqa: E501
|
||||
'resnext101_64x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth' # noqa: E501
|
||||
}
|
||||
|
||||
|
||||
def load_state_dict(module, state_dict, strict=False, logger=None):
|
||||
"""Load state_dict to a module.
|
||||
|
||||
@ -69,7 +80,7 @@ def load_checkpoint(model,
|
||||
|
||||
Args:
|
||||
model (Module): Module to load checkpoint.
|
||||
filename (str): Either a filepath or URL or modelzoll://xxxxxxx.
|
||||
filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
|
||||
map_location (str): Same as :func:`torch.load`.
|
||||
strict (bool): Whether to allow different params for the model and
|
||||
checkpoint.
|
||||
@ -80,9 +91,19 @@ def load_checkpoint(model,
|
||||
"""
|
||||
# load checkpoint from modelzoo or file or url
|
||||
if filename.startswith('modelzoo://'):
|
||||
from torchvision.models.resnet import model_urls
|
||||
import torchvision
|
||||
model_urls = dict()
|
||||
for _, name, ispkg in pkgutil.walk_packages(
|
||||
torchvision.models.__path__):
|
||||
if not ispkg:
|
||||
_zoo = import_module('torchvision.models.{}'.format(name))
|
||||
_urls = getattr(_zoo, 'model_urls')
|
||||
model_urls.update(_urls)
|
||||
model_name = filename[11:]
|
||||
checkpoint = model_zoo.load_url(model_urls[model_name])
|
||||
elif filename.startswith('open-mmlab://'):
|
||||
model_name = filename[13:]
|
||||
checkpoint = model_zoo.load_url(open_mmlab_model_urls[model_name])
|
||||
elif filename.startswith(('http://', 'https://')):
|
||||
checkpoint = model_zoo.load_url(filename)
|
||||
else:
|
||||
|
@ -1,8 +1,14 @@
|
||||
import datetime
|
||||
|
||||
from .base import LoggerHook
|
||||
|
||||
|
||||
class TextLoggerHook(LoggerHook):
|
||||
|
||||
def __init__(self, interval=10, ignore_last=True, reset_flag=False):
|
||||
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag)
|
||||
self.time_sec_tot = 0
|
||||
|
||||
def log(self, runner):
|
||||
if runner.mode == 'train':
|
||||
lr_str = ', '.join(
|
||||
@ -14,6 +20,12 @@ class TextLoggerHook(LoggerHook):
|
||||
log_str = 'Epoch({}) [{}][{}]\t'.format(runner.mode, runner.epoch,
|
||||
runner.inner_iter + 1)
|
||||
if 'time' in runner.log_buffer.output:
|
||||
self.time_sec_tot += (runner.log_buffer.output['time'] *
|
||||
self.interval)
|
||||
time_sec_avg = self.time_sec_tot / (runner.iter + 1)
|
||||
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
|
||||
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
|
||||
log_str += 'eta: {}, '.format(eta_str)
|
||||
log_str += (
|
||||
'time: {log[time]:.3f}, data_time: {log[data_time]:.3f}, '.
|
||||
format(log=runner.log_buffer.output))
|
||||
|
@ -1 +1 @@
|
||||
__version__ = '0.2.3'
|
||||
__version__ = '0.2.3'
|
2
setup.py
2
setup.py
@ -32,7 +32,7 @@ setup(
|
||||
packages=find_packages(),
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
'License :: OSI Approved :: GNU General Public License v3 (GPLv3)',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Operating System :: OS Independent',
|
||||
'Programming Language :: Python :: 2',
|
||||
'Programming Language :: Python :: 2.7',
|
||||
|
Loading…
x
Reference in New Issue
Block a user