[Enhance] Adapt test cases on Ascend NPU. (#1728)
parent 4d1dbafaa2
commit c5248b17b7
@@ -5,6 +5,7 @@ from unittest import TestCase

 import torch
 import torch.nn as nn
+from mmengine.device import get_device
 from mmengine.logging import MMLogger
 from mmengine.model import BaseModule
 from mmengine.optim import OptimWrapper
@@ -79,7 +80,7 @@ class TestDenseCLHook(TestCase):
         self.temp_dir.cleanup()

     def test_densecl_hook(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         dummy_dataset = DummyDataset()
         toy_model = ToyModel().to(device)
         densecl_hook = DenseCLHook(start_iters=1)
@@ -8,6 +8,7 @@ from unittest.mock import ANY, MagicMock, call

 import torch
 import torch.nn as nn
+from mmengine.device import get_device
 from mmengine.evaluator import Evaluator
 from mmengine.logging import MMLogger
 from mmengine.model import BaseModel
@@ -70,7 +71,7 @@ class TestEMAHook(TestCase):
         self.temp_dir.cleanup()

     def test_load_state_dict(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         model = SimpleModel().to(device)
         ema_hook = EMAHook()
         runner = Runner(
@@ -95,7 +96,7 @@ class TestEMAHook(TestCase):

     def test_evaluate_on_ema(self):

-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         model = SimpleModel().to(device)

         # Test validate on ema model
@@ -5,6 +5,7 @@ from unittest import TestCase

 import torch
 import torch.nn as nn
+from mmengine.device import get_device
 from mmengine.logging import MMLogger
 from mmengine.model import BaseModule
 from mmengine.runner import Runner
@@ -79,7 +80,7 @@ class TestSimSiamHook(TestCase):
         self.temp_dir.cleanup()

     def test_simsiam_hook(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         dummy_dataset = DummyDataset()
         toy_model = ToyModel().to(device)
         simsiam_hook = SimSiamHook(
@@ -5,6 +5,7 @@ from unittest import TestCase

 import torch
 import torch.nn as nn
+from mmengine.device import get_device
 from mmengine.logging import MMLogger
 from mmengine.model import BaseModule
 from mmengine.optim import OptimWrapper
@@ -86,7 +87,7 @@ class TestSwAVHook(TestCase):
         self.temp_dir.cleanup()

     def test_swav_hook(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         dummy_dataset = DummyDataset()
         toy_model = ToyModel().to(device)
         swav_hook = SwAVHook(
@@ -10,6 +10,7 @@ import torch
 import torch.nn as nn
 from mmcv.transforms import Compose
 from mmengine.dataset import BaseDataset, ConcatDataset, RepeatDataset
+from mmengine.device import get_device
 from mmengine.logging import MMLogger
 from mmengine.model import BaseDataPreprocessor, BaseModel
 from mmengine.optim import OptimWrapper
@@ -130,7 +131,7 @@ class TestSwitchRecipeHook(TestCase):
         self.assertIsNone(hook.schedule[1]['batch_augments'])

     def test_do_switch(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         model = SimpleModel().to(device)

         loss = CrossEntropyLoss(use_soft=True)
@@ -205,7 +206,7 @@ class TestSwitchRecipeHook(TestCase):
         # runner.train()

     def test_resume(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         model = SimpleModel().to(device)

         loss = CrossEntropyLoss(use_soft=True)
@@ -275,7 +276,7 @@ class TestSwitchRecipeHook(TestCase):
            logs.output)

     def test_switch_train_pipeline(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         model = SimpleModel().to(device)

         runner = Runner(
@@ -324,7 +325,7 @@ class TestSwitchRecipeHook(TestCase):
            pipeline)

     def test_switch_loss(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         model = SimpleModel().to(device)

         runner = Runner(
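
Every hunk applies the same pattern: the hard-coded selection between 'cuda:0' and 'cpu' is replaced by mmengine.device.get_device(), so the tests also run on an Ascend NPU when one is available. A minimal sketch of that pattern follows, assuming only what the diff shows (get_device() returns a device string that .to() accepts); the SimpleModel definition below is illustrative, not the class used in these tests.

import torch.nn as nn
from mmengine.device import get_device


class SimpleModel(nn.Module):
    """Illustrative stand-in for the test models in this diff."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)


# Before: device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# After: let mmengine pick the available backend (NPU, CUDA, ..., or CPU).
device = get_device()
model = SimpleModel().to(device)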