mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Try to fix lint issue (#199)
* try to fix lint * upgrade yapf version * use another way to bypass yapf * update docstring
This commit is contained in:
parent
9d0d7536c8
commit
96f3d97fc4
@ -63,7 +63,7 @@ jobs:
|
||||
|
||||
build_cu102:
|
||||
machine:
|
||||
image: ubuntu-1604-cuda-10.1:201909-23 # the actual version of cuda is 10.2
|
||||
image: ubuntu-1604-cuda-10.1:201909-23 # the actual version of cuda is 10.2
|
||||
resource_class: gpu.nvidia.small
|
||||
steps:
|
||||
- checkout
|
||||
|
@ -8,6 +8,7 @@ TORCH_VERSION = torch.__version__
|
||||
|
||||
|
||||
def is_rocm_pytorch() -> bool:
|
||||
"""Check whether the PyTorch is compiled on ROCm."""
|
||||
is_rocm = False
|
||||
if TORCH_VERSION != 'parrots':
|
||||
try:
|
||||
@ -20,6 +21,7 @@ def is_rocm_pytorch() -> bool:
|
||||
|
||||
|
||||
def _get_cuda_home() -> Optional[str]:
|
||||
"""Obtain the path of CUDA home."""
|
||||
if TORCH_VERSION == 'parrots':
|
||||
from parrots.utils.build_extension import CUDA_HOME
|
||||
else:
|
||||
@ -32,6 +34,7 @@ def _get_cuda_home() -> Optional[str]:
|
||||
|
||||
|
||||
def get_build_config():
|
||||
"""Obtain the build information of PyTorch or Parrots."""
|
||||
if TORCH_VERSION == 'parrots':
|
||||
from parrots.config import get_build_info
|
||||
return get_build_info()
|
||||
@ -40,6 +43,8 @@ def get_build_config():
|
||||
|
||||
|
||||
def _get_conv() -> tuple:
|
||||
"""A wrapper to obtain base classes of Conv layers from PyTorch or
|
||||
Parrots."""
|
||||
if TORCH_VERSION == 'parrots':
|
||||
from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
|
||||
else:
|
||||
@ -48,6 +53,7 @@ def _get_conv() -> tuple:
|
||||
|
||||
|
||||
def _get_dataloader() -> tuple:
|
||||
"""A wrapper to obtain DataLoader class from PyTorch or Parrots."""
|
||||
if TORCH_VERSION == 'parrots':
|
||||
from torch.utils.data import DataLoader, PoolDataLoader
|
||||
else:
|
||||
@ -57,6 +63,7 @@ def _get_dataloader() -> tuple:
|
||||
|
||||
|
||||
def _get_extension():
|
||||
"""A wrapper to obtain extension class from PyTorch or Parrots."""
|
||||
if TORCH_VERSION == 'parrots':
|
||||
from parrots.utils.build_extension import BuildExtension, Extension
|
||||
CppExtension = partial(Extension, cuda=False)
|
||||
@ -68,6 +75,8 @@ def _get_extension():
|
||||
|
||||
|
||||
def _get_pool() -> tuple:
|
||||
"""A wrapper to obtain base classes of pooling layers from PyTorch or
|
||||
Parrots."""
|
||||
if TORCH_VERSION == 'parrots':
|
||||
from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
|
||||
_AdaptiveMaxPoolNd, _AvgPoolNd,
|
||||
@ -80,6 +89,8 @@ def _get_pool() -> tuple:
|
||||
|
||||
|
||||
def _get_norm() -> tuple:
|
||||
"""A wrapper to obtain base classes of normalization layers from PyTorch or
|
||||
Parrots."""
|
||||
if TORCH_VERSION == 'parrots':
|
||||
from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
|
||||
SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
|
||||
|
@ -4,6 +4,14 @@ __version__ = '0.0.1'
|
||||
|
||||
|
||||
def parse_version_info(version_str):
|
||||
"""Parse the version information.
|
||||
|
||||
Args:
|
||||
version_str (str): version string like '0.0.1'.
|
||||
|
||||
Returns:
|
||||
tuple: version information contains major, minor, micro version.
|
||||
"""
|
||||
version_info = []
|
||||
for x in version_str.split('.'):
|
||||
if x.isdigit():
|
||||
|
@ -36,8 +36,11 @@ class ToyModel(nn.Module):
|
||||
self.linear = nn.Linear(2, 1)
|
||||
|
||||
def forward(self, data_batch, return_loss=False):
|
||||
inputs, labels = zip(
|
||||
*map(lambda x: (x['inputs'], x['data_sample']), data_batch))
|
||||
inputs, labels = [], []
|
||||
for x in data_batch:
|
||||
inputs.append(x['inputs'])
|
||||
labels.append(x['data_sample'])
|
||||
|
||||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
inputs = torch.stack(inputs).to(device)
|
||||
labels = torch.stack(labels).to(device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user