# File: mmpretrain/tests/test_runtime/test_preciseBN_hook.py
#
# (266 lines, 8.0 KiB, Python)

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
import torch.nn as nn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import EpochBasedRunner, IterBasedRunner, build_optimizer
from mmcv.utils import get_logger
from mmcv.utils.logging import print_log
from torch.utils.data import DataLoader, Dataset
from mmcls.core.hook import PreciseBNHook
class ExampleDataset(Dataset):
    """Minimal single-sample dataset.

    Every index yields ``{'imgs': tensor([1.])}`` with dtype float32.
    """

    def __init__(self):
        # Not read by any of the code shown here; kept for compatibility.
        self.index = 0

    def __getitem__(self, idx):
        # ``idx`` is ignored: the one sample is the same for every index.
        sample = torch.tensor([1.0], dtype=torch.float32)
        return {'imgs': sample}

    def __len__(self):
        return 1
class BiggerDataset(ExampleDataset):
    """Twelve-sample dataset built from ``fixed_values``.

    Item ``i`` is ``{'imgs': tensor([fixed_values[i]])}`` as float32.
    ``ExampleDataset.__init__`` is intentionally not called.

    Args:
        fixed_values: Indexable sequence of exactly 12 numbers.
    """

    def __init__(self, fixed_values=range(0, 12)):
        # ``__len__`` is fixed at 12, so the provided values must match it.
        expected_len = len(self)
        assert expected_len == len(fixed_values)
        self.fixed_values = fixed_values

    def __getitem__(self, idx):
        value = self.fixed_values[idx]
        return {'imgs': torch.tensor([value], dtype=torch.float32)}

    def __len__(self):
        # Larger than ExampleDataset so that several batches are produced.
        return 12
class ExampleModel(nn.Module):
    """Tiny ``Linear(1, 1) -> BatchNorm1d(1)`` network used as the baseline.

    ``train_step`` returns a constant result dict. It neither runs a forward
    pass nor touches the optimizer.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in this order: conv, then bn.
        self.conv = nn.Linear(1, 1)
        self.bn = nn.BatchNorm1d(1)
        self.test_cfg = None

    def forward(self, imgs, return_loss=False):
        # ``return_loss`` is accepted for API compatibility but ignored.
        features = self.conv(imgs)
        return self.bn(features)

    def train_step(self, data_batch, optimizer, **kwargs):
        # Input-independent output; the arguments are unused.
        return dict(
            loss=0.5,
            log_vars=dict(accuracy=0.98),
            num_samples=1,
        )
class SingleBNModel(ExampleModel):
    """Model whose forward pass is a lone ``BatchNorm1d`` on the raw input.

    The ``conv`` created by ``ExampleModel.__init__`` stays registered but is
    not used by ``forward``. The BN layer therefore sees the dataset values
    directly.
    """

    def __init__(self):
        super().__init__()
        self.test_cfg = None
        # Replace the inherited BatchNorm1d with a freshly constructed one.
        self.bn = nn.BatchNorm1d(1)

    def forward(self, imgs, return_loss=False):
        normalized = self.bn(imgs)
        return normalized
class GNExampleModel(ExampleModel):
    """ExampleModel variant whose normalization layer is a GroupNorm.

    The attribute keeps the name ``bn`` so the inherited
    ``forward`` (``bn(conv(imgs))``) works unchanged. The layer itself is
    ``nn.GroupNorm(1, 1)``, not a BatchNorm.
    """

    def __init__(self):
        super().__init__()
        self.test_cfg = None
        # Re-create conv, then swap the inherited BatchNorm1d for GroupNorm.
        self.conv = nn.Linear(1, 1)
        self.bn = nn.GroupNorm(1, 1)
class NoBNExampleModel(ExampleModel):
    """ExampleModel variant with no normalization layer at all.

    ``ExampleModel.__init__`` registers ``self.bn`` as a ``BatchNorm1d``. If
    it were left in place, ``model.modules()`` would still yield a BatchNorm
    layer, and this model would not actually be BN-free. It is therefore
    deleted after the parent constructor runs. ``forward`` only applies
    ``conv``.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Linear(1, 1)
        # Remove the BatchNorm1d inherited from ExampleModel so the
        # "model without bn layer" test really has no BN submodule.
        del self.bn
        self.test_cfg = None

    def forward(self, imgs, return_loss=False):
        # ``return_loss`` is accepted for API compatibility but ignored.
        return self.conv(imgs)
def _build_runner(model, optimizer, logger, runner_cls=EpochBasedRunner):
    """Build a ``runner_cls`` with no batch processor and ``max_epochs=1``."""
    return runner_cls(
        model=model,
        batch_processor=None,
        optimizer=optimizer,
        logger=logger,
        max_epochs=1)


def _run_precise_bn(model, optimizer, logger, loader, num_samples):
    """Train ``model`` one epoch on ``loader`` with a fresh PreciseBNHook."""
    precise_bn_hook = PreciseBNHook(num_samples=num_samples)
    assert precise_bn_hook.num_samples == num_samples
    assert precise_bn_hook.interval == 1
    runner = _build_runner(model, optimizer, logger)
    runner.register_hook(precise_bn_hook)
    runner.run([loader], [('train', 1)])


def test_precise_bn():
    """Exercise PreciseBNHook end to end.

    Covers argument validation, the runner-type check, plain,
    DataParallel, GroupNorm and BN-free models, and finally checks that the
    recomputed BN statistics equal the average of the per-batch statistics.
    """
    optimizer_cfg = dict(
        type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
    loader = DataLoader(ExampleDataset(), batch_size=2)
    model = ExampleModel()
    optimizer = build_optimizer(model, optimizer_cfg)
    logger = get_logger('precise_bn')
    runner = _build_runner(model, optimizer, logger)

    with pytest.raises(AssertionError):
        # num_samples must be larger than 0
        precise_bn_hook = PreciseBNHook(num_samples=-1)
        runner.register_hook(precise_bn_hook)
        runner.run([loader], [('train', 1)])

    with pytest.raises(AssertionError):
        # interval must be larger than 0
        precise_bn_hook = PreciseBNHook(interval=0)
        runner.register_hook(precise_bn_hook)
        runner.run([loader], [('train', 1)])

    with pytest.raises(AssertionError):
        # only support EpochBaseRunner
        # NOTE(review): this runner is built with ``max_epochs`` but no
        # ``max_iters``. If IterBasedRunner.run asserts on that itself, the
        # AssertionError may come from the runner rather than the hook's
        # runner-type check. Confirm against the mmcv version in use.
        runner = _build_runner(
            model, optimizer, logger, runner_cls=IterBasedRunner)
        precise_bn_hook = PreciseBNHook(interval=2)
        runner.register_hook(precise_bn_hook)
        print_log(runner)
        runner.run([loader], [('train', 1)])

    # One loader over the 12-sample dataset is reused by every case below;
    # each run iterates it afresh.
    bigger_loader = DataLoader(BiggerDataset(), batch_size=2)

    # test non-DDP model
    _run_precise_bn(model, optimizer, logger, bigger_loader, num_samples=4)

    # test DP model
    model = MMDataParallel(model)
    _run_precise_bn(model, optimizer, logger, bigger_loader, num_samples=4)

    # test model w/ gn layer
    _run_precise_bn(
        GNExampleModel(), optimizer, logger, bigger_loader, num_samples=4)

    # test model without bn layer
    _run_precise_bn(
        NoBNExampleModel(), optimizer, logger, bigger_loader, num_samples=4)

    # test how precise it is: num_samples covers the whole dataset
    model = SingleBNModel()
    _run_precise_bn(model, optimizer, logger, bigger_loader, num_samples=12)

    batches = [np.array(data['imgs']) for data in bigger_loader]
    # The expected statistics are averages of per-batch statistics.
    mean = np.mean([np.mean(batch) for batch in batches])
    # Bessel correction is used in PyTorch, therefore ddof=1.
    var = np.mean([np.var(batch, ddof=1) for batch in batches])
    running_mean = np.array(model.bn.running_mean)
    running_var = np.array(model.bn.running_var)
    # Tolerance-based comparison: the running stats are float32 tensors
    # accumulated over several iterations.
    assert np.allclose(mean, running_mean), (mean, running_mean)
    assert np.allclose(var, running_var), (var, running_var)
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_ddp_model_precise_bn():
    """Run PreciseBNHook on an ``MMDistributedDataParallel``-wrapped model.

    The optimizer and logger are built locally, because this test cannot
    see ``test_precise_bn``'s locals. DistributedDataParallel needs an
    initialized default process group, so a single-rank group is created for
    the duration of the test when none exists yet. Only a group created here
    is destroyed afterwards.
    """
    import os
    import tempfile

    import torch.distributed as dist

    if not dist.is_available():
        pytest.skip('requires torch.distributed')

    # test DDP model
    test_bigger_dataset = BiggerDataset()
    loader = DataLoader(test_bigger_dataset, batch_size=2)
    loaders = [loader]
    precise_bn_hook = PreciseBNHook(num_samples=5)
    assert precise_bn_hook.num_samples == 5
    assert precise_bn_hook.interval == 1

    with tempfile.TemporaryDirectory() as tmp_dir:
        owns_group = not dist.is_initialized()
        if owns_group:
            backend = 'nccl' if dist.is_nccl_available() else 'gloo'
            dist.init_process_group(
                backend=backend,
                init_method='file://' + os.path.join(tmp_dir, 'dist_init'),
                world_size=1,
                rank=0)
        try:
            model = ExampleModel().cuda()
            optimizer_cfg = dict(
                type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
            optimizer = build_optimizer(model, optimizer_cfg)
            logger = get_logger('precise_bn')
            model = MMDistributedDataParallel(
                model,
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=True)
            runner = EpochBasedRunner(
                model=model,
                batch_processor=None,
                optimizer=optimizer,
                logger=logger,
                max_epochs=1)
            runner.register_hook(precise_bn_hook)
            runner.run(loaders, [('train', 1)])
        finally:
            if owns_group:
                dist.destroy_process_group()