266 lines
8.0 KiB
Python
266 lines
8.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
|
from mmcv.runner import EpochBasedRunner, IterBasedRunner, build_optimizer
|
|
from mmcv.utils import get_logger
|
|
from mmcv.utils.logging import print_log
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
from mmcls.core.hook import PreciseBNHook
|
|
|
|
|
|
class ExampleDataset(Dataset):
    """One-sample dataset that always yields the constant tensor ``[1.0]``.

    Every lookup returns a dict with a single ``'imgs'`` entry holding a
    float32 tensor; the requested index is ignored.
    """

    def __init__(self):
        # Plain attribute; nothing in this class reads it back.
        self.index = 0

    def __getitem__(self, idx):
        """Return the constant sample regardless of ``idx``."""
        sample = {'imgs': torch.tensor([1.0], dtype=torch.float32)}
        return sample

    def __len__(self):
        """Report a dataset length of one."""
        return 1
|
|
|
|
|
|
class BiggerDataset(ExampleDataset):
    """Twelve-sample dataset yielding caller-supplied values by index.

    Args:
        fixed_values: Indexable collection of exactly twelve numbers.
            Item ``idx`` is wrapped in a one-element float32 tensor.
    """

    def __init__(self, fixed_values=range(0, 12)):
        # The reported length is hard-coded, so the supplied values must
        # match it. Note that the parent initializer is not invoked here.
        assert len(self) == len(fixed_values)
        self.fixed_values = fixed_values

    def __getitem__(self, idx):
        """Return ``{'imgs': tensor([fixed_values[idx]])}``."""
        value = self.fixed_values[idx]
        return {'imgs': torch.tensor([value], dtype=torch.float32)}

    def __len__(self):
        # a bigger dataset
        return 12
|
|
|
|
|
|
class ExampleModel(nn.Module):
    """Minimal model: a 1->1 linear layer followed by 1-D batch norm.

    ``train_step`` does no computation; it returns a fixed result dict so
    that a runner can be driven without a real loss.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Linear(1, 1)
        self.bn = nn.BatchNorm1d(1)
        self.test_cfg = None

    def forward(self, imgs, return_loss=False):
        """Apply the linear layer, then batch norm.

        ``return_loss`` is accepted but not used.
        """
        projected = self.conv(imgs)
        return self.bn(projected)

    def train_step(self, data_batch, optimizer, **kwargs):
        """Return constant outputs; all arguments are ignored."""
        log_vars = dict(accuracy=0.98)
        return dict(loss=0.5, log_vars=log_vars, num_samples=1)
|
|
|
|
|
|
class SingleBNModel(ExampleModel):
    """Variant whose forward pass is the batch-norm layer alone.

    The linear layer created by the parent initializer still exists on
    the instance, but ``forward`` does not route inputs through it.
    """

    def __init__(self):
        super().__init__()
        # Both attributes are re-assigned after the parent has set them.
        self.test_cfg = None
        self.bn = nn.BatchNorm1d(1)

    def forward(self, imgs, return_loss=False):
        """Return ``imgs`` passed through batch norm only."""
        normalized = self.bn(imgs)
        return normalized
|
|
|
|
|
|
class GNExampleModel(ExampleModel):
    """Variant that stores a ``GroupNorm`` layer under the ``bn`` name.

    The inherited ``forward`` is reused, so inputs pass through the
    linear layer and then through group norm instead of batch norm.
    """

    def __init__(self):
        super().__init__()
        self.test_cfg = None
        # Fresh linear layer, replacing the one built by the parent.
        self.conv = nn.Linear(1, 1)
        # One group over the single channel.
        self.bn = nn.GroupNorm(1, 1)
|
|
|
|
|
|
class NoBNExampleModel(ExampleModel):
    """Variant whose forward pass applies only the linear layer.

    The ``bn`` module created by the parent initializer is left in place
    but is bypassed by ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.test_cfg = None
        # Fresh linear layer, replacing the one built by the parent.
        self.conv = nn.Linear(1, 1)

    def forward(self, imgs, return_loss=False):
        """Return the linear projection of ``imgs`` without normalization."""
        projected = self.conv(imgs)
        return projected
|
|
|
|
|
|
def test_precise_bn():
    """Check ``PreciseBNHook`` argument handling and single-process runs.

    The test first expects ``AssertionError`` from four invalid setups
    (negative ``num_samples``, zero ``interval`` on an existing and on a
    fresh runner, and an ``IterBasedRunner``). It then trains for one
    epoch with a plain model, an ``MMDataParallel`` wrapper, a GroupNorm
    model and a model whose forward skips normalization. Finally it
    compares the running statistics of ``SingleBNModel`` against
    per-batch statistics computed with NumPy.
    """
    optimizer_cfg = {
        'type': 'SGD',
        'lr': 0.01,
        'momentum': 0.9,
        'weight_decay': 0.0001,
    }

    tiny_dataset = ExampleDataset()
    tiny_loader = DataLoader(tiny_dataset, batch_size=2)
    model = ExampleModel()
    optimizer = build_optimizer(model, optimizer_cfg)
    logger = get_logger('precise_bn')

    def make_runner(runner_cls, net):
        # Every runner in this test shares the optimizer, the logger and
        # a one-epoch budget; only the runner class and the model vary.
        return runner_cls(
            model=net,
            batch_processor=None,
            optimizer=optimizer,
            logger=logger,
            max_epochs=1)

    def run_one_epoch(dataset, num_samples, make_model):
        # Build a loader and a fresh hook, verify the hook's attributes,
        # then create the model and train it for a single epoch.
        # Returns the model and the loader list that was used.
        epoch_loaders = [DataLoader(dataset, batch_size=2)]
        hook = PreciseBNHook(num_samples=num_samples)
        assert hook.num_samples == num_samples
        assert hook.interval == 1
        net = make_model()
        epoch_runner = make_runner(EpochBasedRunner, net)
        epoch_runner.register_hook(hook)
        epoch_runner.run(epoch_loaders, [('train', 1)])
        return net, epoch_loaders

    runner = make_runner(EpochBasedRunner, model)

    # num_samples must be larger than 0
    with pytest.raises(AssertionError):
        bad_hook = PreciseBNHook(num_samples=-1)
        runner.register_hook(bad_hook)
        runner.run([tiny_loader], [('train', 1)])

    # interval must be larger than 0
    with pytest.raises(AssertionError):
        bad_hook = PreciseBNHook(interval=0)
        runner.register_hook(bad_hook)
        runner.run([tiny_loader], [('train', 1)])

    # interval must be larger than 0 (same check, on a fresh runner)
    with pytest.raises(AssertionError):
        runner = make_runner(EpochBasedRunner, model)
        bad_hook = PreciseBNHook(interval=0)
        runner.register_hook(bad_hook)
        runner.run([tiny_loader], [('train', 1)])

    # only support EpochBasedRunner
    # NOTE(review): the iteration-based runner is given ``max_epochs``
    # here; confirm the AssertionError comes from the hook rather than
    # from the runner's own argument checks.
    with pytest.raises(AssertionError):
        runner = make_runner(IterBasedRunner, model)
        iter_hook = PreciseBNHook(interval=2)
        runner.register_hook(iter_hook)
        print_log(runner)
        runner.run([tiny_loader], [('train', 1)])

    # test non-DDP model
    run_one_epoch(BiggerDataset(), 4, lambda: model)

    # test DP model
    bigger_dataset = BiggerDataset()
    run_one_epoch(bigger_dataset, 4, lambda: MMDataParallel(model))

    # test model w/ gn layer
    run_one_epoch(bigger_dataset, 4, GNExampleModel)

    # test model without bn layer
    run_one_epoch(bigger_dataset, 4, NoBNExampleModel)

    # test how precise it is
    bn_model, loaders = run_one_epoch(bigger_dataset, 12, SingleBNModel)

    batches = []
    for data_loader in loaders:
        batches.extend(np.array(batch['imgs']) for batch in data_loader)

    mean = np.mean([np.mean(batch) for batch in batches])
    # Bessel's correction is used in PyTorch, therefore ddof=1
    var = np.mean([np.var(batch, ddof=1) for batch in batches])

    running_mean = np.array(bn_model.bn.running_mean)
    assert np.equal(mean, running_mean), (mean, running_mean)
    running_var = np.array(bn_model.bn.running_var)
    assert np.equal(var, running_var), (var, running_var)
|
|
|
|
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_ddp_model_precise_bn():
    """Run ``PreciseBNHook`` for one epoch on a DDP-wrapped model.

    The test is skipped when CUDA is unavailable. It builds its own
    optimizer and logger, since neither name exists at module level.
    """
    # test DDP model
    test_bigger_dataset = BiggerDataset()
    loader = DataLoader(test_bigger_dataset, batch_size=2)
    loaders = [loader]
    precise_bn_hook = PreciseBNHook(num_samples=5)
    assert precise_bn_hook.num_samples == 5
    assert precise_bn_hook.interval == 1

    model = ExampleModel()
    # Build the optimizer from the bare model before it is wrapped, and
    # use the same settings and logger name as ``test_precise_bn``.
    optimizer_cfg = dict(
        type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
    optimizer = build_optimizer(model, optimizer_cfg)
    logger = get_logger('precise_bn')

    # NOTE(review): presumably requires an initialised torch.distributed
    # process group -- confirm how this is provided in CI.
    model = MMDistributedDataParallel(
        model.cuda(),
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False,
        find_unused_parameters=True)
    runner = EpochBasedRunner(
        model=model,
        batch_processor=None,
        optimizer=optimizer,
        logger=logger,
        max_epochs=1)
    runner.register_hook(precise_bn_hook)
    runner.run(loaders, [('train', 1)])
|