mmpretrain/tests/test_runtime/test_preciseBN_hook.py

265 lines
7.9 KiB
Python
Raw Normal View History

import numpy as np
import pytest
import torch
import torch.nn as nn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import EpochBasedRunner, IterBasedRunner, build_optimizer
from mmcv.utils import get_logger
from mmcv.utils.logging import print_log
from torch.utils.data import DataLoader, Dataset
from mmcls.core.hook import PreciseBNHook
class ExampleDataset(Dataset):
    """Single-sample dataset that always yields the same constant tensor."""

    def __init__(self):
        self.index = 0

    def __getitem__(self, idx):
        # ``idx`` is ignored: every lookup produces a fresh tensor holding 1.0.
        return {'imgs': torch.tensor([1.0], dtype=torch.float32)}

    def __len__(self):
        return 1
class BiggerDataset(ExampleDataset):
    """Twelve-sample dataset whose items are drawn from ``fixed_values``.

    The parent initializer is not invoked; only ``fixed_values`` is stored.

    Args:
        fixed_values: Indexable sequence with exactly ``len(self)`` entries;
            entry ``idx`` becomes the single element of sample ``idx``.
    """

    def __init__(self, fixed_values=range(0, 12)):
        # The provided values must cover the dataset one-to-one.
        assert len(self) == len(fixed_values)
        self.fixed_values = fixed_values

    def __getitem__(self, idx):
        value = self.fixed_values[idx]
        return {'imgs': torch.tensor([value], dtype=torch.float32)}

    def __len__(self):
        # a bigger dataset
        return 12
class ExampleModel(nn.Module):
    """Toy model: a 1-to-1 linear layer followed by 1-d batch norm."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Linear(1, 1)
        self.bn = nn.BatchNorm1d(1)
        self.test_cfg = None

    def forward(self, imgs, return_loss=False):
        """Apply the linear layer, then batch norm; ``return_loss`` is unused."""
        projected = self.conv(imgs)
        return self.bn(projected)

    def train_step(self, data_batch, optimizer, **kwargs):
        """Return a fixed training result regardless of the inputs."""
        return dict(
            loss=0.5,
            log_vars=dict(accuracy=0.98),
            num_samples=1,
        )
class SingleBNModel(ExampleModel):
    """Variant whose forward pass feeds the input straight into batch norm."""

    def __init__(self):
        super().__init__()
        # Swap the inherited layer for a freshly constructed BatchNorm1d.
        self.bn = nn.BatchNorm1d(1)
        self.test_cfg = None

    def forward(self, imgs, return_loss=False):
        """Normalize ``imgs`` with ``self.bn`` only; ``self.conv`` is skipped."""
        normalized = self.bn(imgs)
        return normalized
class GNExampleModel(ExampleModel):
    """Variant that stores a ``GroupNorm`` layer under the ``bn`` attribute.

    The inherited ``forward`` is reused, so the linear output is passed
    through group norm instead of batch norm.
    """

    def __init__(self):
        super().__init__()
        # Re-create the linear layer, then replace the normalization layer.
        self.conv = nn.Linear(1, 1)
        self.bn = nn.GroupNorm(1, 1)
        self.test_cfg = None
class NoBNExampleModel(ExampleModel):
    """Variant whose forward pass uses only the linear layer."""

    def __init__(self):
        super().__init__()
        # Re-create the linear layer on top of what the parent already built.
        self.conv = nn.Linear(1, 1)
        self.test_cfg = None

    def forward(self, imgs, return_loss=False):
        """Return the linear projection of ``imgs`` without normalization."""
        projected = self.conv(imgs)
        return projected
def _build_epoch_runner(model, optimizer, logger):
    """Create an ``EpochBasedRunner`` limited to a single epoch."""
    return EpochBasedRunner(
        model=model,
        batch_processor=None,
        optimizer=optimizer,
        logger=logger,
        max_epochs=1)


def _run_precise_bn_epoch(model, optimizer, logger, dataset, num_samples):
    """Train ``model`` for one epoch with a ``PreciseBNHook`` registered.

    Args:
        model: Model handed to the runner.
        optimizer: Optimizer handed to the runner.
        logger: Logger handed to the runner.
        dataset: Dataset wrapped in a ``DataLoader`` with ``batch_size=2``.
        num_samples: ``num_samples`` argument of the hook.

    Returns:
        list: The data loaders that were passed to ``runner.run``.
    """
    loaders = [DataLoader(dataset, batch_size=2)]
    precise_bn_hook = PreciseBNHook(num_samples=num_samples)
    assert precise_bn_hook.num_samples == num_samples
    assert precise_bn_hook.interval == 1
    runner = _build_epoch_runner(model, optimizer, logger)
    runner.register_hook(precise_bn_hook)
    runner.run(loaders, [('train', 1)])
    return loaders


def test_precise_bn():
    """Exercise ``PreciseBNHook`` argument checks, several models, and the
    values it writes into the batch-norm running statistics."""
    optimizer_cfg = dict(
        type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

    test_dataset = ExampleDataset()
    loader = DataLoader(test_dataset, batch_size=2)
    model = ExampleModel()
    optimizer = build_optimizer(model, optimizer_cfg)
    logger = get_logger('precise_bn')
    runner = _build_epoch_runner(model, optimizer, logger)

    with pytest.raises(AssertionError):
        # num_samples must be larger than 0
        precise_bn_hook = PreciseBNHook(num_samples=-1)
        runner.register_hook(precise_bn_hook)
        runner.run([loader], [('train', 1)])

    with pytest.raises(AssertionError):
        # interval must be larger than 0
        precise_bn_hook = PreciseBNHook(interval=0)
        runner.register_hook(precise_bn_hook)
        runner.run([loader], [('train', 1)])

    with pytest.raises(AssertionError):
        # interval must be larger than 0 (same check, on a fresh runner)
        runner = _build_epoch_runner(model, optimizer, logger)
        precise_bn_hook = PreciseBNHook(interval=0)
        runner.register_hook(precise_bn_hook)
        runner.run([loader], [('train', 1)])

    with pytest.raises(AssertionError):
        # only support EpochBaseRunner
        # NOTE(review): confirm the AssertionError caught here comes from the
        # hook's runner-type check and not from the IterBasedRunner itself
        # being configured with ``max_epochs``.
        runner = IterBasedRunner(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            logger=logger,
            max_epochs=1)
        precise_bn_hook = PreciseBNHook(interval=2)
        runner.register_hook(precise_bn_hook)
        print_log(runner)
        runner.run([loader], [('train', 1)])

    test_bigger_dataset = BiggerDataset()

    # test non-DDP model
    _run_precise_bn_epoch(model, optimizer, logger, test_bigger_dataset, 4)

    # test DP model
    _run_precise_bn_epoch(
        MMDataParallel(model), optimizer, logger, test_bigger_dataset, 4)

    # test model w/ gn layer
    _run_precise_bn_epoch(
        GNExampleModel(), optimizer, logger, test_bigger_dataset, 4)

    # test model without bn layer
    _run_precise_bn_epoch(
        NoBNExampleModel(), optimizer, logger, test_bigger_dataset, 4)

    # test how precise it is
    model = SingleBNModel()
    loaders = _run_precise_bn_epoch(
        model, optimizer, logger, test_bigger_dataset, 12)

    imgs_list = [
        np.array(data['imgs']) for loader in loaders for data in loader
    ]
    mean = np.mean([np.mean(batch) for batch in imgs_list])
    # Bessel correction used in Pytorch, therefore ddof=1
    var = np.mean([np.var(batch, ddof=1) for batch in imgs_list])

    running_mean = np.array(model.bn.running_mean)
    running_var = np.array(model.bn.running_var)
    # Compare with a tolerance: the expected values are float64 while the
    # running statistics are float32 buffers, so exact equality is fragile.
    assert np.allclose(mean, running_mean), (mean, running_mean)
    assert np.allclose(var, running_var), (var, running_var)
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_ddp_model_precise_bn():
    """Run one training epoch with ``PreciseBNHook`` on a DDP-wrapped model.

    Skipped when CUDA is unavailable.
    """
    # test DDP model
    test_bigger_dataset = BiggerDataset()
    loader = DataLoader(test_bigger_dataset, batch_size=2)
    loaders = [loader]
    precise_bn_hook = PreciseBNHook(num_samples=5)
    assert precise_bn_hook.num_samples == 5
    assert precise_bn_hook.interval == 1

    model = ExampleModel()
    # The optimizer and logger must be created inside this test: no
    # module-level ``optimizer`` or ``logger`` exists, so referencing them
    # without defining them here raises NameError.
    optimizer_cfg = dict(
        type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
    optimizer = build_optimizer(model, optimizer_cfg)
    logger = get_logger('precise_bn')

    # NOTE(review): wrapping in a distributed-data-parallel module presumably
    # requires an initialised torch.distributed process group - confirm that
    # the test environment sets one up.
    model = MMDistributedDataParallel(
        model.cuda(),
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False,
        find_unused_parameters=True)
    runner = EpochBasedRunner(
        model=model,
        batch_processor=None,
        optimizer=optimizer,
        logger=logger,
        max_epochs=1)
    runner.register_hook(precise_bn_hook)
    runner.run(loaders, [('train', 1)])