Update CONTRIBUTING.md according to mmcv (#210)

* Update CONTRIBUTING.md according to mmcv

* Docstring formatting by docformatter

* Update openmmlab website.
mzr1996 2021-04-14 21:22:37 +08:00 committed by GitHub
parent 3f085026cf
commit b7b520881f
39 changed files with 163 additions and 119 deletions

View File

@@ -1,4 +1,4 @@
# Contributing to MMClassification
# Contributing to OpenMMLab
All kinds of contributions are welcome, including but not limited to the following.
@@ -7,24 +7,63 @@ All kinds of contributions are welcome, including but not limited to the following.
## Workflow
1. Fork and pull the latest mmclassification
2. Checkout a new branch with a meaningful name (do not use master branch for PRs)
3. Commit your changes
4. Create a PR
1. fork and pull the latest OpenMMLab repository (mmclassification)
2. checkout a new branch (do not use master branch for PRs)
3. commit your changes
4. create a PR
Note
- If you plan to add some new features that involve large changes, it is encouraged to open an issue for discussion first.
- If you are the author of some papers and would like to include your method to mmclassification,
please contact Lei Yang (jerryyanglei@gmail). We will much appreciate your contribution.
Note: If you plan to add some new features that involve large changes, it is encouraged to open an issue for discussion first.
## Code style
### Python
We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style.
We use [flake8](http://flake8.pycqa.org/en/latest/) as the linter and [yapf](https://github.com/google/yapf) as the formatter.
Please upgrade to the latest yapf (>=0.27.0) and refer to the [configuration](.style.yapf).
We use the following tools for linting and formatting:
- [flake8](http://flake8.pycqa.org/en/latest/): A wrapper around some linter tools.
- [yapf](https://github.com/google/yapf): A formatter for Python files.
- [isort](https://github.com/timothycrosley/isort): A Python utility to sort imports.
- [markdownlint](https://github.com/markdownlint/markdownlint): A linter to check markdown files and flag style issues.
- [docformatter](https://github.com/myint/docformatter): A formatter to format docstrings.
Style configurations of yapf and isort can be found in [setup.cfg](./setup.cfg).
We use a [pre-commit hook](https://pre-commit.com/) that checks and formats code with `flake8`, `yapf` and `isort`, lints `markdown files`, fixes `trailing whitespaces`, `end-of-files`, `double-quoted-strings`, `python-encoding-pragma` and `mixed-line-ending`, and sorts `requirements.txt` automatically on every commit.
The config for the pre-commit hook is stored in [.pre-commit-config](./.pre-commit-config.yaml).
After you clone the repository, you will need to install and initialize the pre-commit hook:
```shell
pip install -U pre-commit
```
From the repository folder, run:
```shell
pre-commit install
```
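Once installed, the hooks run automatically on every `git commit`. To run all hooks against the entire codebase at once (for example, before opening a PR), you can use pre-commit's standard `run` command:
```shell
pre-commit run --all-files
```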
If you encounter an issue when installing markdownlint, try the following steps to install Ruby:
```shell
# install rvm
curl -L https://get.rvm.io | bash -s -- --autolibs=read-fail
[[ -s "$HOME/.rvm/scripts/rvm" ]] && source "$HOME/.rvm/scripts/rvm"
rvm autolibs disable
# install ruby
rvm install 2.7.1
```
Or refer to [this repo](https://github.com/innerlee/setup) and run [`zzruby.sh`](https://github.com/innerlee/setup/blob/master/zzruby.sh) according to its instructions.
After this, the code linters and formatter will be enforced on every commit.
> Before you create a PR, make sure that your code lints and is formatted by yapf.
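For instance, you can invoke the linter and formatter manually over the whole repository (standard flake8/yapf invocations; the isort flag below follows the 4.x syntax pinned in the hook config):
```shell
flake8 .        # lint
yapf -r -i .    # format in place, recursively
isort -rc .     # sort imports (isort 4.x recursive flag)
```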
### C++ and CUDA
We follow the [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).

View File

@@ -1,32 +1,50 @@
exclude: ^tests/data/
repos:
- repo: https://gitlab.com/pycqa/flake8.git
rev: 3.8.0
rev: 3.8.3
hooks:
- id: flake8
- repo: https://github.com/asottile/seed-isort-config
rev: v2.1.0
rev: v2.2.0
hooks:
- id: seed-isort-config
- repo: https://github.com/timothycrosley/isort
rev: 4.3.21
hooks:
- id: isort
- id: isort
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.29.0
rev: v0.30.0
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.5.0
rev: v3.1.0
hooks:
- id: trailing-whitespace
- id: check-yaml
- id: end-of-file-fixer
- id: requirements-txt-fixer
- id: double-quote-string-fixer
- id: check-merge-conflict
- id: fix-encoding-pragma
args: ["--remove"]
- id: mixed-line-ending
args: ["--fix=lf"]
- repo: https://github.com/jumanjihouse/pre-commit-hooks
rev: 2.1.4
hooks:
- id: markdownlint
args: ["-r", "~MD002,~MD013,~MD024,~MD029,~MD033,~MD034,~MD036"]
args: ["-r", "~MD002,~MD013,~MD029,~MD033,~MD034",
"-t", "allow_different_nesting"]
- repo: https://github.com/myint/docformatter
rev: v1.3.1
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
# - repo: local
# hooks:
# - id: clang-format
# name: clang-format
# description: Format files with ClangFormat
# entry: clang-format -style=google -i
# language: system
# files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$

View File

@@ -10,7 +10,7 @@
## Introduction
MMClassification is an open source image classification toolbox based on PyTorch. It is
a part of the [OpenMMLab](https://open-mmlab.github.io/) project.
a part of the [OpenMMLab](https://openmmlab.com/) project.
Documentation: https://mmclassification.readthedocs.io/en/latest/

View File

@@ -38,7 +38,7 @@ def calculate_confusion_matrix(pred, target):
def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
"""Calculate precision, recall and f1 score according to the prediction and
target.
target.
Args:
pred (torch.Tensor | np.array): The model prediction with shape (N, C).
@@ -204,8 +204,8 @@ def f1_score(pred, target, average_mode='macro', thrs=None):
def support(pred, target, average_mode='macro'):
"""Calculate the total number of occurrences of each label according to
the prediction and target.
"""Calculate the total number of occurrences of each label according to the
prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction with shape (N, C).
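A minimal usage sketch of the metric above (editor's illustration; the import path is assumed, and the function is taken to return the three scores in docstring order):
```python
import torch
from mmcls.core.evaluation import precision_recall_f1  # assumed import path

pred = torch.tensor([[0.9, 0.1, 0.0],   # 4 samples, 3 classes
                     [0.2, 0.7, 0.1],
                     [0.3, 0.3, 0.4],
                     [0.8, 0.1, 0.1]])
target = torch.tensor([0, 1, 2, 1])     # ground-truth class indices
precision, recall, f1 = precision_recall_f1(pred, target, average_mode='macro')
```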

View File

@@ -3,7 +3,7 @@ import torch
def average_precision(pred, target):
""" Calculate the average precision for a single class
"""Calculate the average precision for a single class.
AP summarizes a precision-recall curve as the weighted mean of maximum
precisions obtained for any r'>r, where r is the recall:
@@ -43,7 +43,7 @@ def average_precision(pred, target):
def mAP(pred, target):
""" Calculate the mean average precision with respect of classes
"""Calculate the mean average precision with respect of classes.
Args:
pred (torch.Tensor | np.ndarray): The model prediction with shape
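A usage sketch for the multi-label case (editor's illustration; import path assumed):
```python
import torch
from mmcls.core.evaluation import mAP  # assumed import path

pred = torch.tensor([[0.9, 0.2], [0.1, 0.8], [0.7, 0.6]])  # scores, shape (N, C)
target = torch.tensor([[1, 0], [0, 1], [1, 1]])            # binary labels, shape (N, C)
print(mAP(pred, target))  # mean of the per-class average precisions
```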

View File

@@ -6,8 +6,8 @@ import torch
def average_performance(pred, target, thr=None, k=None):
"""Calculate CP, CR, CF1, OP, OR, OF1, where C stands for per-class
average, O stands for overall average, P stands for precision, R
stands for recall and F1 stands for F1-score
average, O stands for overall average, P stands for precision, R stands for
recall and F1 stands for F1-score.
Args:
pred (torch.Tensor | np.ndarray): The model prediction with shape
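A usage sketch (editor's illustration; import path assumed, return order taken from the docstring above):
```python
import torch
from mmcls.core.evaluation import average_performance  # assumed import path

pred = torch.rand(16, 5)               # multi-label scores, shape (N, C)
target = torch.randint(0, 2, (16, 5))  # binary labels, shape (N, C)
# thr: a label counts as positive when its score exceeds the threshold
CP, CR, CF1, OP, OR, OF1 = average_performance(pred, target, thr=0.5)
```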

View File

@@ -102,8 +102,7 @@ class CIFAR10(BaseDataset):
@DATASETS.register_module()
class CIFAR100(CIFAR10):
"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
"""
"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset."""
base_folder = 'cifar-100-python'
url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'

View File

@@ -84,8 +84,7 @@ class MNIST(BaseDataset):
@DATASETS.register_module()
class FashionMNIST(MNIST):
"""`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_
Dataset.
"""
Dataset."""
resource_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' # noqa: E501
resources = {
@@ -110,8 +109,9 @@ def get_int(b):
def open_maybe_compressed_file(path):
"""Return a file object that possibly decompresses 'path' on the fly.
Decompression occurs when argument `path` is a string
and ends with '.gz' or '.xz'.
Decompression occurs when argument `path` is a string and ends with '.gz'
or '.xz'.
"""
if not isinstance(path, str):
return path
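The hunk above is truncated; a self-contained sketch of the behavior the docstring describes (not the verbatim mmcls code) could look like:
```python
import gzip
import lzma

def open_maybe_compressed_file(path):
    """Open ``path``, transparently decompressing '.gz' and '.xz' files."""
    if not isinstance(path, str):
        return path  # already a file object
    if path.endswith('.gz'):
        return gzip.open(path, 'rb')
    if path.endswith('.xz'):
        return lzma.open(path, 'rb')
    return open(path, 'rb')
```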
@@ -125,9 +125,10 @@ def open_maybe_compressed_file(path):
def read_sn3_pascalvincent_tensor(path, strict=True):
"""Read a SN3 file in "Pascal Vincent" format
(Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object.
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-
io.lsh').
Argument may be a filename, compressed filename, or file object.
"""
# typemap
if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):

View File

@@ -7,8 +7,7 @@ from .base_dataset import BaseDataset
class MultiLabelDataset(BaseDataset):
""" Multi-label Dataset.
"""
"""Multi-label Dataset."""
def get_cat_ids(self, idx):
"""Get category ids by index.

View File

@@ -15,9 +15,10 @@ def random_negative(value, random_negative_prob):
@PIPELINES.register_module()
class AutoAugment(object):
"""Auto augmentation.
This data augmentation is proposed in `AutoAugment: Learning Augmentation
Policies from Data <https://arxiv.org/abs/1805.09501>`_.
"""Auto augmentation. This data augmentation is proposed in `AutoAugment:
Learning Augmentation Policies from Data.
<https://arxiv.org/abs/1805.09501>`_.
Args:
policies (list[list[dict]]): The policies of auto augmentation. Each
@@ -53,9 +54,9 @@ class AutoAugment(object):
@PIPELINES.register_module()
class RandAugment(object):
"""Random augmentation.
This data augmentation is proposed in `RandAugment: Practical automated
data augmentation with a reduced search space
"""Random augmentation. This data augmentation is proposed in `RandAugment:
Practical automated data augmentation with a reduced search space.
<https://arxiv.org/abs/1909.13719>`_.
Args:

View File

@@ -106,8 +106,7 @@ class ToNumpy(object):
@PIPELINES.register_module()
class Collect(object):
"""
Collect data from the loader relevant to the specific task.
"""Collect data from the loader relevant to the specific task.
This is usually the last stage of the data loader pipeline. Typically keys
is set to some subset of "img" and "gt_label".

View File

@@ -272,7 +272,6 @@ class RandomGrayscale(object):
- If input image is 1 channel: grayscale version is 1 channel.
- If input image is 3 channel: grayscale version is 3 channel
with r == g == b.
"""
def __init__(self, gray_prob=0.1):
@@ -661,6 +660,7 @@ class Albu(object):
def albu_builder(self, cfg):
"""Import a module from albumentations.
It inherits some of :func:`build_from_cfg` logic.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
@@ -692,7 +692,9 @@ class Albu(object):
@staticmethod
def mapper(d, keymap):
"""Dictionary mapper. Renames keys according to keymap provided.
"""Dictionary mapper.
Renames keys according to keymap provided.
Args:
d (dict): old dict
keymap (dict): {'old_key':'new_key'}
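A minimal sketch of such a keymap-based renamer (editor's illustration, not the verbatim method; keys absent from `keymap` keep their original names):
```python
def mapper(d, keymap):
    """Rename the keys of ``d`` according to ``keymap``."""
    return {keymap.get(k, k): v for k, v in d.items()}

print(mapper({'img': 0, 'gt_label': 1}, {'img': 'image'}))
# -> {'image': 0, 'gt_label': 1}
```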

View File

@@ -10,8 +10,7 @@ from .multi_label import MultiLabelDataset
@DATASETS.register_module()
class VOC(MultiLabelDataset):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset.
"""
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset."""
CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
@@ -26,7 +25,7 @@ class VOC(MultiLabelDataset):
raise ValueError('Cannot infer dataset year from img_prefix.')
def load_annotations(self):
"""Load annotations
"""Load annotations.
Returns:
list[dict]: Annotation info from XML file.

View File

@@ -8,17 +8,15 @@ from mmcv.runner import load_checkpoint
class BaseBackbone(nn.Module, metaclass=ABCMeta):
"""Base backbone.
This class defines the basic functions of a backbone.
Any backbone that inherits this class should at least
define its own `forward` function.
This class defines the basic functions of a backbone. Any backbone that
inherits this class should at least define its own `forward` function.
"""
def __init__(self):
super(BaseBackbone, self).__init__()
def init_weights(self, pretrained=None):
"""Init backbone weights
"""Init backbone weights.
Args:
pretrained (str | None): If pretrained is a string, then it
@@ -38,7 +36,7 @@ class BaseBackbone(nn.Module, metaclass=ABCMeta):
@abstractmethod
def forward(self, x):
"""Forward computation
"""Forward computation.
Args:
x (tensor | tuple[tensor]): x could be a Torch.tensor or a tuple of
@@ -47,7 +45,7 @@ class BaseBackbone(nn.Module, metaclass=ABCMeta):
pass
def train(self, mode=True):
"""Set module status before forward computation
"""Set module status before forward computation.
Args:
mode (bool): Whether it is train_mode or test_mode
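A minimal subclass sketch (editor's illustration; the import path is assumed from this hunk, and only `forward` must be overridden):
```python
import torch
import torch.nn as nn
from mmcls.models.backbones.base_backbone import BaseBackbone  # assumed path

class TinyBackbone(BaseBackbone):

    def __init__(self):
        super(TinyBackbone, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x):  # the only abstract method
        return self.conv(x)

feats = TinyBackbone()(torch.rand(1, 3, 32, 32))  # -> shape (1, 8, 32, 32)
```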

View File

@@ -199,7 +199,7 @@ class MobileNetV2(BaseBackbone):
self.layers.append('conv2')
def make_layer(self, out_channels, num_blocks, stride, expand_ratio):
""" Stack InvertedResidual blocks to build a layer for MobileNetV2.
"""Stack InvertedResidual blocks to build a layer for MobileNetV2.
Args:
out_channels (int): out_channels of block.

View File

@@ -12,7 +12,7 @@ from .base_backbone import BaseBackbone
@BACKBONES.register_module()
class MobileNetv3(BaseBackbone):
""" MobileNetv3 backbone
"""MobileNetv3 backbone.
Args:
arch (str): Architechture of mobilnetv3, from {small, big}.

View File

@@ -273,7 +273,7 @@ class RegNet(ResNet):
return widths, groups
def get_stages_from_blocks(self, widths):
"""Gets widths/stage_blocks of network at each stage
"""Gets widths/stage_blocks of network at each stage.
Args:
widths (list[int]): Width in each stage.

View File

@@ -634,13 +634,13 @@ class ResNet(BaseBackbone):
@BACKBONES.register_module()
class ResNetV1d(ResNet):
"""ResNetV1d variant described in
`Bag of Tricks <https://arxiv.org/pdf/1812.01187.pdf>`_.
"""ResNetV1d variant described in `Bag of Tricks.
Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv
in the input stem with three 3x3 convs. And in the downsampling block,
a 2x2 avg_pool with stride 2 is added before conv, whose stride is
changed to 1.
<https://arxiv.org/pdf/1812.01187.pdf>`_.
Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
the input stem with three 3x3 convs. And in the downsampling block, a 2x2
avg_pool with stride 2 is added before conv, whose stride is changed to 1.
"""
def __init__(self, **kwargs):

View File

@@ -271,7 +271,7 @@ class ShuffleNetV1(BaseBackbone):
f'{type(pretrained)}')
def make_layer(self, out_channels, num_blocks, first_block=False):
""" Stack ShuffleUnit blocks to make a layer.
"""Stack ShuffleUnit blocks to make a layer.
Args:
out_channels (int): out_channels of the block.

View File

@@ -222,7 +222,7 @@ class ShuffleNetV2(BaseBackbone):
act_cfg=act_cfg))
def _make_layer(self, out_channels, num_blocks):
""" Stack blocks to make a layer.
"""Stack blocks to make a layer.
Args:
out_channels (int): out_channels of the block.

View File

@@ -20,7 +20,7 @@ except ImportError:
class BaseClassifier(nn.Module, metaclass=ABCMeta):
"""Base class for classifiers"""
"""Base class for classifiers."""
def __init__(self):
super(BaseClassifier, self).__init__()
@@ -81,13 +81,14 @@ class BaseClassifier(nn.Module, metaclass=ABCMeta):
@auto_fp16(apply_to=('img', ))
def forward(self, img, return_loss=True, **kwargs):
"""
Calls either forward_train or forward_test depending on whether
return_loss=True. Note this setting will change the expected inputs.
When `return_loss=True`, img and img_meta are single-nested (i.e.
Tensor and List[dict]), and when `resturn_loss=False`, img and img_meta
should be double nested (i.e. List[Tensor], List[List[dict]]), with
the outer list indicating test time augmentations.
"""Calls either forward_train or forward_test depending on whether
return_loss=True.
Note this setting will change the expected inputs. When
`return_loss=True`, img and img_meta are single-nested (i.e. Tensor and
List[dict]), and when `resturn_loss=False`, img and img_meta should be
double nested (i.e. List[Tensor], List[List[dict]]), with the outer
list indicating test time augmentations.
"""
if return_loss:
return self.forward_train(img, **kwargs)

View File

@@ -44,8 +44,7 @@ class ImageClassifier(BaseClassifier):
self.head.init_weights()
def extract_feat(self, img):
"""Directly extract features from the backbone + neck
"""
"""Directly extract features from the backbone + neck."""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)

View File

@@ -4,9 +4,7 @@ import torch.nn as nn
class BaseHead(nn.Module, metaclass=ABCMeta):
"""Base head.
"""
"""Base head."""
def __init__(self):
super(BaseHead, self).__init__()

View File

@@ -11,7 +11,6 @@ class MultiLabelClsHead(BaseHead):
Args:
loss (dict): Config of classification loss.
"""
def __init__(self,

View File

@@ -15,7 +15,6 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
num_classes (int): Number of categories.
in_channels (int): Number of channels in the input feature map.
loss (dict): Config of classification loss.
"""
def __init__(self,

View File

@@ -69,7 +69,7 @@ def accuracy_torch(pred, target, topk=1, thrs=None):
def accuracy(pred, target, topk=1, thrs=None):
"""Calculate accuracy according to the prediction and target
"""Calculate accuracy according to the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction.
@@ -112,7 +112,7 @@ def accuracy(pred, target, topk=1, thrs=None):
class Accuracy(nn.Module):
def __init__(self, topk=(1, )):
"""Module to calculate the accuracy
"""Module to calculate the accuracy.
Args:
topk (tuple): The criterion used to calculate the
@@ -122,7 +122,7 @@ class Accuracy(nn.Module):
self.topk = topk
def forward(self, pred, target):
"""Forward function to calculate accuracy
"""Forward function to calculate accuracy.
Args:
pred (torch.Tensor): Prediction of models.
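A usage sketch for the module (editor's illustration; import path assumed, and `forward` is taken to return one accuracy per entry of `topk`):
```python
import torch
from mmcls.models.losses import Accuracy  # assumed import path

acc = Accuracy(topk=(1, 5))
pred = torch.rand(8, 10)             # scores, shape (N, num_classes)
target = torch.randint(0, 10, (8,))  # class indices, shape (N,)
top1, top5 = acc(pred, target)
```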

View File

@@ -13,7 +13,7 @@ def asymmetric_loss(pred,
clip=0.05,
reduction='mean',
avg_factor=None):
"""asymmetric loss
"""asymmetric loss.
Please refer to the `paper <https://arxiv.org/abs/2009.14119>`_ for
details.
@@ -63,7 +63,7 @@ def asymmetric_loss(pred,
@LOSSES.register_module()
class AsymmetricLoss(nn.Module):
"""asymmetric loss
"""asymmetric loss.
Args:
gamma_pos (float): positive focusing parameter.
@@ -74,7 +74,7 @@ class AsymmetricLoss(nn.Module):
reduction (str): The method used to reduce the loss into
a scalar.
loss_weight (float): Weight of loss. Defaults to 1.0.
"""
"""
def __init__(self,
gamma_pos=0.0,
@@ -95,8 +95,7 @@ class AsymmetricLoss(nn.Module):
weight=None,
avg_factor=None,
reduction_override=None):
"""asymmetric loss
"""
"""asymmetric loss."""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
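For reference, a condensed sketch of the asymmetric loss from the linked paper (standard ASL with asymmetric focusing and probability shifting; parameter defaults here are illustrative, and this is not the verbatim mmcls implementation):
```python
import torch

def asymmetric_loss_sketch(pred, target, gamma_pos=0.0, gamma_neg=4.0, clip=0.05):
    p = torch.sigmoid(pred)
    p_neg = (p - clip).clamp(min=0)  # probability shifting for negatives
    loss_pos = target * (1 - p).pow(gamma_pos) * torch.log(p.clamp(min=1e-8))
    loss_neg = (1 - target) * p_neg.pow(gamma_neg) * torch.log((1 - p_neg).clamp(min=1e-8))
    return -(loss_pos + loss_neg).mean()
```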

View File

@@ -103,7 +103,7 @@ def binary_cross_entropy(pred,
@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
"""Cross entropy loss
"""Cross entropy loss.
Args:
use_sigmoid (bool): Whether the prediction uses sigmoid

View File

@@ -12,7 +12,7 @@ def sigmoid_focal_loss(pred,
alpha=0.25,
reduction='mean',
avg_factor=None):
"""Sigmoid focal loss
"""Sigmoid focal loss.
Args:
pred (torch.Tensor): The prediction with shape (N, *).
@@ -52,7 +52,7 @@ def sigmoid_focal_loss(pred,
@LOSSES.register_module()
class FocalLoss(nn.Module):
"""Focal loss
"""Focal loss.
Args:
gamma (float): Focusing parameter in focal loss.
@@ -82,7 +82,7 @@ class FocalLoss(nn.Module):
weight=None,
avg_factor=None,
reduction_override=None):
"""Sigmoid focal loss
"""Sigmoid focal loss.
Args:
pred (torch.Tensor): The prediction with shape (N, *).

View File

@@ -38,10 +38,8 @@ class LabelSmoothLoss(CrossEntropyLoss):
self._eps = np.finfo(np.float32).eps
def generate_one_hot_like_label(self, label):
"""
This function takes one-hot or index label vectors and computes
one-hot like label vectors (float)
"""
"""This function takes one-hot or index label vectors and computes one-
hot like label vectors (float)"""
label_shape_list = list(label.size())
# check if targets are inputted as class integers
if len(label_shape_list) == 1 or (len(label_shape_list) == 2
@@ -50,11 +48,9 @@ class LabelSmoothLoss(CrossEntropyLoss):
return label.float()
def smooth_label(self, one_hot_like_label):
"""
This function takes one-hot like target vectors and
computes smoothed target vectors (normalized)
according to the loss's smoothing parameter
"""
"""This function takes one-hot like target vectors and computes
smoothed target vectors (normalized) according to the loss's smoothing
parameter."""
assert self.num_classes > 0
one_hot_like_label /= self._eps + one_hot_like_label.sum(
dim=1, keepdim=True)
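In standard label-smoothing terms, the two steps amount to the following sketch (editor's illustration with `epsilon` as the smoothing parameter; not the verbatim method):
```python
import torch

def smooth_label(one_hot_like_label, epsilon, num_classes):
    # normalize each row, then blend with the uniform distribution
    label = one_hot_like_label / one_hot_like_label.sum(dim=1, keepdim=True)
    return label * (1 - epsilon) + epsilon / num_classes
```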

View File

@@ -100,8 +100,8 @@ def weighted_loss(loss_func):
def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor:
"""This function converts target class indices to one-hot vectors,
given the number of classes.
"""This function converts target class indices to one-hot vectors, given
the number of classes.
Args:
targets (Tensor): The ground truth label of the prediction
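An equivalent-behavior sketch with plain PyTorch (`torch.nn.functional.one_hot` is a standard API):
```python
import torch
import torch.nn.functional as F

targets = torch.tensor([0, 2, 1])                    # class indices, shape (N,)
one_hot = F.one_hot(targets, num_classes=3).float()  # shape (N, classes)
```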

View File

@@ -8,10 +8,9 @@ from ..builder import NECKS
class GlobalAveragePooling(nn.Module):
"""Global Average Pooling neck.
Note that we use `view` to remove extra channel after pooling.
We do not use `squeeze` as it will also remove the batch dimension
when the tensor has a batch dimension of size 1, which can lead to
unexpected errors.
Note that we use `view` to remove extra channel after pooling. We do not
use `squeeze` as it will also remove the batch dimension when the tensor
has a batch dimension of size 1, which can lead to unexpected errors.
"""
def __init__(self):
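The pitfall that the note describes is easy to reproduce with a batch of size 1:
```python
import torch
import torch.nn.functional as F

x = F.adaptive_avg_pool2d(torch.rand(1, 512, 7, 7), 1)  # shape (1, 512, 1, 1)
print(x.squeeze().shape)            # torch.Size([512])   - batch dim removed too
print(x.view(x.size(0), -1).shape)  # torch.Size([1, 512]) - batch dim preserved
```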

View File

@ -6,7 +6,7 @@ from .se_layer import SELayer
class InvertedResidual(nn.Module):
"""Inverted Residual Block
"""Inverted Residual Block.
Args:
in_channels (int): The input channels of this Module.
@@ -31,7 +31,6 @@ class InvertedResidual(nn.Module):
Returns:
Tensor: The output tensor.
"""
def __init__(self,

View File

@@ -6,7 +6,7 @@ from torch.distributions.beta import Beta
class BaseMixupLayer(object, metaclass=ABCMeta):
"""Base class for MixupLayer"""
"""Base class for MixupLayer."""
def __init__(self):
super(BaseMixupLayer, self).__init__()

View File

@@ -5,6 +5,7 @@ __version__ = '0.10.0'
def parse_version_info(version_str):
"""Parse a version string into a tuple.
Args:
version_str (str): The version string.
Returns:
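A sketch of such a parser, assuming purely numeric dot-separated versions (the real implementation may also handle pre-release suffixes such as `rc`):
```python
def parse_version_info(version_str):
    return tuple(int(x) for x in version_str.split('.'))

print(parse_version_info('0.10.0'))  # -> (0, 10, 0)
```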

View File

@@ -371,7 +371,7 @@ def test_bottleneck_reslayer():
def test_resnet():
"""Test resnet backbone"""
"""Test resnet backbone."""
with pytest.raises(KeyError):
# ResNet depth should be in [18, 34, 50, 101, 152]
ResNet(20)

View File

@@ -120,7 +120,7 @@ def test_res_layer():
def test_seresnet():
"""Test resnet backbone"""
"""Test resnet backbone."""
with pytest.raises(KeyError):
# SEResNet depth should be in [50, 101, 152]
SEResNet(20)

View File

@@ -15,7 +15,7 @@ def check_norm_state(modules, train_state):
def test_vgg():
"""Test VGG backbone"""
"""Test VGG backbone."""
with pytest.raises(KeyError):
# VGG depth should be in [11, 13, 16, 19]
VGG(18)

View File

@@ -56,8 +56,7 @@ def get_layer_maps(layer_num, with_bn):
def convert(src, dst, layer_num, with_bn=False):
"""Convert keys in torchvision pretrained VGG models to mmcls
style."""
"""Convert keys in torchvision pretrained VGG models to mmcls style."""
# load pytorch model
assert os.path.isfile(src), f'no checkpoint found at {src}'