Try to fix lint issue (#199)

* try to fix lint

* upgrade yapf version

* use another way to bypass yapf

* update docstring
This commit is contained in:
Wenwei Zhang 2022-04-26 13:53:00 +08:00 committed by GitHub
parent 9d0d7536c8
commit 96f3d97fc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 3 deletions

View File

@ -63,7 +63,7 @@ jobs:
build_cu102:
machine:
image: ubuntu-1604-cuda-10.1:201909-23 # the actual version of cuda is 10.2
image: ubuntu-1604-cuda-10.1:201909-23 # the actual version of cuda is 10.2
resource_class: gpu.nvidia.small
steps:
- checkout

View File

@ -8,6 +8,7 @@ TORCH_VERSION = torch.__version__
def is_rocm_pytorch() -> bool:
"""Check whether the PyTorch is compiled on ROCm."""
is_rocm = False
if TORCH_VERSION != 'parrots':
try:
@ -20,6 +21,7 @@ def is_rocm_pytorch() -> bool:
def _get_cuda_home() -> Optional[str]:
"""Obtain the path of CUDA home."""
if TORCH_VERSION == 'parrots':
from parrots.utils.build_extension import CUDA_HOME
else:
@ -32,6 +34,7 @@ def _get_cuda_home() -> Optional[str]:
def get_build_config():
"""Obtain the build information of PyTorch or Parrots."""
if TORCH_VERSION == 'parrots':
from parrots.config import get_build_info
return get_build_info()
@ -40,6 +43,8 @@ def get_build_config():
def _get_conv() -> tuple:
"""A wrapper to obtain base classes of Conv layers from PyTorch or
Parrots."""
if TORCH_VERSION == 'parrots':
from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
else:
@ -48,6 +53,7 @@ def _get_conv() -> tuple:
def _get_dataloader() -> tuple:
"""A wrapper to obtain DataLoader class from PyTorch or Parrots."""
if TORCH_VERSION == 'parrots':
from torch.utils.data import DataLoader, PoolDataLoader
else:
@ -57,6 +63,7 @@ def _get_dataloader() -> tuple:
def _get_extension():
"""A wrapper to obtain extension class from PyTorch or Parrots."""
if TORCH_VERSION == 'parrots':
from parrots.utils.build_extension import BuildExtension, Extension
CppExtension = partial(Extension, cuda=False)
@ -68,6 +75,8 @@ def _get_extension():
def _get_pool() -> tuple:
"""A wrapper to obtain base classes of pooling layers from PyTorch or
Parrots."""
if TORCH_VERSION == 'parrots':
from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd,
@ -80,6 +89,8 @@ def _get_pool() -> tuple:
def _get_norm() -> tuple:
"""A wrapper to obtain base classes of normalization layers from PyTorch or
Parrots."""
if TORCH_VERSION == 'parrots':
from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
SyncBatchNorm_ = torch.nn.SyncBatchNorm2d

View File

@ -4,6 +4,14 @@ __version__ = '0.0.1'
def parse_version_info(version_str):
"""Parse the version information.
Args:
version_str (str): version string like '0.0.1'.
Returns:
tuple: version information contains major, minor, micro version.
"""
version_info = []
for x in version_str.split('.'):
if x.isdigit():

View File

@ -36,8 +36,11 @@ class ToyModel(nn.Module):
self.linear = nn.Linear(2, 1)
def forward(self, data_batch, return_loss=False):
inputs, labels = zip(
*map(lambda x: (x['inputs'], x['data_sample']), data_batch))
inputs, labels = [], []
for x in data_batch:
inputs.append(x['inputs'])
labels.append(x['data_sample'])
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
inputs = torch.stack(inputs).to(device)
labels = torch.stack(labels).to(device)