[Enhance] Remove unnecessary calls and lazily import to speed import performance (#837)
* [Enhance] Remove unnecessary calls to speed import performance * lazily import matplotlib * minor refinementpull/819/head^2
parent
fcd783fcb2
commit
c89d4ef815
|
@ -12,7 +12,6 @@ from tempfile import TemporaryDirectory
|
|||
from typing import Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
import mmengine
|
||||
from mmengine.dist import get_dist_info
|
||||
|
@ -112,6 +111,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
|
|||
|
||||
|
||||
def get_torchvision_models():
|
||||
import torchvision
|
||||
if digit_version(torchvision.__version__) < digit_version('0.13.0a0'):
|
||||
model_urls = dict()
|
||||
# When the version of torchvision is lower than 0.13, the model url is
|
||||
|
|
|
@ -4,7 +4,6 @@ import os.path as osp
|
|||
import subprocess
|
||||
import sys
|
||||
from collections import OrderedDict, defaultdict
|
||||
from distutils import errors
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
@ -47,6 +46,8 @@ def collect_env():
|
|||
- OpenCV (optional): OpenCV version.
|
||||
- MMENGINE: MMENGINE version.
|
||||
"""
|
||||
from distutils import errors
|
||||
|
||||
env_info = OrderedDict()
|
||||
env_info['sys.platform'] = sys.platform
|
||||
env_info['Python'] = sys.version.replace('\n', '')
|
||||
|
|
|
@ -103,7 +103,6 @@ def _get_norm() -> tuple:
|
|||
|
||||
_ConvNd, _ConvTransposeMixin = _get_conv()
|
||||
DataLoader, PoolDataLoader = _get_dataloader()
|
||||
BuildExtension, CppExtension, CUDAExtension = _get_extension()
|
||||
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
|
||||
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
|
||||
|
||||
|
|
|
@ -3,9 +3,6 @@ import importlib
|
|||
import os.path as osp
|
||||
import subprocess
|
||||
|
||||
import pkg_resources
|
||||
from pkg_resources import get_distribution
|
||||
|
||||
|
||||
def is_installed(package: str) -> bool:
|
||||
"""Check package whether installed.
|
||||
|
@ -13,6 +10,12 @@ def is_installed(package: str) -> bool:
|
|||
Args:
|
||||
package (str): Name of package to be checked.
|
||||
"""
|
||||
# When executing `import mmengine.runner`,
|
||||
# pkg_resources will be imported and it takes too much time.
|
||||
# Therefore, import it in function scope to save time.
|
||||
import pkg_resources
|
||||
from pkg_resources import get_distribution
|
||||
|
||||
# refresh the pkg_resources
|
||||
# more datails at https://github.com/pypa/setuptools/issues/373
|
||||
importlib.reload(pkg_resources)
|
||||
|
@ -33,6 +36,8 @@ def get_installed_path(package: str) -> str:
|
|||
>>> get_installed_path('mmcls')
|
||||
>>> '.../lib/python3.7/site-packages/mmcls'
|
||||
"""
|
||||
from pkg_resources import get_distribution
|
||||
|
||||
# if the package name is not the same as module name, module name should be
|
||||
# inferred. For example, mmcv-full is the package name, but mmcv is module
|
||||
# name. If we want to get the installed path of mmcv-full, we should concat
|
||||
|
@ -51,6 +56,7 @@ def package2module(package: str):
|
|||
Args:
|
||||
package (str): Package to infer module name.
|
||||
"""
|
||||
from pkg_resources import get_distribution
|
||||
pkg = get_distribution(package)
|
||||
if pkg.has_metadata('top_level.txt'):
|
||||
module_name = pkg.get_metadata('top_level.txt').split('\n')[0]
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from typing import Any, List, Optional, Tuple, Type, Union
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union
|
||||
|
||||
import cv2
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.backend_bases import CloseEvent
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
|
||||
|
||||
def tensor2ndarray(value: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
|
||||
|
@ -131,6 +130,7 @@ def color_str2rgb(color: str) -> tuple:
|
|||
Returns:
|
||||
tuple: RGB color.
|
||||
"""
|
||||
import matplotlib
|
||||
rgb_color: tuple = matplotlib.colors.to_rgb(color)
|
||||
rgb_color = tuple(int(c * 255) for c in rgb_color)
|
||||
return rgb_color
|
||||
|
@ -186,6 +186,8 @@ def wait_continue(figure, timeout: int = 0, continue_key: str = ' ') -> int:
|
|||
int: If zero, means time out or the user pressed ``continue_key``,
|
||||
and if one, means the user closed the show figure.
|
||||
""" # noqa: E501
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.backend_bases import CloseEvent
|
||||
is_inline = 'inline' in plt.get_backend()
|
||||
if is_inline:
|
||||
# If use inline backend, interactive input and timeout is no use.
|
||||
|
@ -226,7 +228,7 @@ def wait_continue(figure, timeout: int = 0, continue_key: str = ' ') -> int:
|
|||
return 0 # Quit for continue.
|
||||
|
||||
|
||||
def img_from_canvas(canvas: FigureCanvasAgg) -> np.ndarray:
|
||||
def img_from_canvas(canvas: 'FigureCanvasAgg') -> np.ndarray:
|
||||
"""Get RGB image from ``FigureCanvasAgg``.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -4,16 +4,9 @@ import warnings
|
|||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
from matplotlib.collections import (LineCollection, PatchCollection,
|
||||
PolyCollection)
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.patches import Circle
|
||||
from matplotlib.pyplot import new_figure_manager
|
||||
|
||||
from mmengine.config import Config
|
||||
from mmengine.dist import master_only
|
||||
|
@ -240,6 +233,7 @@ class Visualizer(ManagerMixin):
|
|||
continue_key (str): The key for users to continue. Defaults to
|
||||
the space key.
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
is_inline = 'inline' in plt.get_backend()
|
||||
img = self.get_image() if drawn_img is None else drawn_img
|
||||
self._init_manager(win_name)
|
||||
|
@ -302,7 +296,8 @@ class Visualizer(ManagerMixin):
|
|||
Returns:
|
||||
tuple: build canvas figure and axes.
|
||||
"""
|
||||
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
from matplotlib.figure import Figure
|
||||
fig = Figure(**fig_cfg)
|
||||
ax = fig.add_subplot()
|
||||
ax.axis(False)
|
||||
|
@ -318,6 +313,8 @@ class Visualizer(ManagerMixin):
|
|||
Args:
|
||||
win_name (str): The window name.
|
||||
"""
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.pyplot import new_figure_manager
|
||||
if getattr(self, 'manager', None) is None:
|
||||
self.manager = new_figure_manager(
|
||||
num=1, FigureClass=Figure, **self.fig_show_cfg)
|
||||
|
@ -546,6 +543,7 @@ class Visualizer(ManagerMixin):
|
|||
If ``line_widths`` is single value, all the lines will
|
||||
have the same linewidth. Defaults to 2.
|
||||
"""
|
||||
from matplotlib.collections import LineCollection
|
||||
check_type('x_datas', x_datas, (np.ndarray, torch.Tensor))
|
||||
x_datas = tensor2ndarray(x_datas)
|
||||
check_type('y_datas', y_datas, (np.ndarray, torch.Tensor))
|
||||
|
@ -614,6 +612,8 @@ class Visualizer(ManagerMixin):
|
|||
alpha (Union[int, float]): The transparency of circles.
|
||||
Defaults to 0.8.
|
||||
"""
|
||||
from matplotlib.collections import PatchCollection
|
||||
from matplotlib.patches import Circle
|
||||
check_type('center', center, (np.ndarray, torch.Tensor))
|
||||
center = tensor2ndarray(center)
|
||||
check_type('radius', radius, (np.ndarray, torch.Tensor))
|
||||
|
@ -760,6 +760,7 @@ class Visualizer(ManagerMixin):
|
|||
alpha (Union[int, float]): The transparency of polygons.
|
||||
Defaults to 0.8.
|
||||
"""
|
||||
from matplotlib.collections import PolyCollection
|
||||
check_type('polygons', polygons, (list, np.ndarray, torch.Tensor))
|
||||
edge_colors = color_val_matplotlib(edge_colors) # type: ignore
|
||||
face_colors = color_val_matplotlib(face_colors) # type: ignore
|
||||
|
@ -916,6 +917,7 @@ class Visualizer(ManagerMixin):
|
|||
Returns:
|
||||
np.ndarray: RGB image.
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
assert isinstance(featmap,
|
||||
torch.Tensor), (f'`featmap` should be torch.Tensor,'
|
||||
f' but got {type(featmap)}')
|
||||
|
|
Loading…
Reference in New Issue