[Improve] Use `train_step` instead of `forward` in PreciseBNHook (#964)

* Fix the precise BN hook when running on MLU devices

* Fix the unit tests accordingly
Ezra-Yu 2022-08-11 15:02:25 +08:00 committed by GitHub
parent b366897889
commit e54cfd6951
2 changed files with 16 additions and 7 deletions
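
The gist of the change: `PreciseBNHook` used to push each batch through the model with `model(**data)`, i.e. the hook unpacked the data dict into `forward()` itself. When the model is wrapped (by `MMDataParallel` or its MLU counterpart), the wrapper expects to receive the whole batch so it can handle scattering and device placement; calling `forward` directly bypasses it. `model.train_step(data)` hands the dict over intact. A minimal sketch of the difference, with a hypothetical stand-in model (names and shapes are illustrative, not from the repo):

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Hypothetical stand-in for an MMCV-style classifier."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, 3)
        self.bn = nn.BatchNorm2d(4)

    def forward(self, imgs):
        return self.bn(self.conv(imgs))

    def train_step(self, data_batch, optimizer=None, **kwargs):
        # The model owns the unpacking, so a parallel wrapper can
        # intercept the whole dict before this point.
        return self.forward(**data_batch)

model = TinyModel()
data = {'imgs': torch.randn(2, 3, 8, 8)}  # one batch from the loader

model(**data)           # old call: the caller unpacks the dict
model.train_step(data)  # new call: the dict is passed through whole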


@@ -107,7 +107,7 @@ def update_bn_stats(model: nn.Module,
     prog_bar = mmcv.ProgressBar(num_iter)
     for data in itertools.islice(loader, num_iter):
-        model(**data)
+        model.train_step(data)
         for i, bn in enumerate(bn_layers):
             running_means[i] += bn.running_mean / num_iter
             running_vars[i] += bn.running_var / num_iter
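
For context, the loop above is the core of `update_bn_stats`, which recomputes each BN layer's statistics over `num_iter` batches. A hedged sketch of the surrounding logic, reconstructed from the hunk (the momentum save/restore and layer collection are assumed, not shown in this diff):

import itertools

import torch.nn as nn

def update_bn_stats_sketch(model, loader, num_iter):
    # Collect the BN layers whose statistics we want to recompute.
    bn_layers = [m for m in model.modules()
                 if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    # With momentum=1.0, running_mean/running_var hold exactly the
    # current batch's statistics after each forward pass.
    momentums = [bn.momentum for bn in bn_layers]
    for bn in bn_layers:
        bn.momentum = 1.0

    model.train()
    running_means = [0.0] * len(bn_layers)
    running_vars = [0.0] * len(bn_layers)
    for data in itertools.islice(loader, num_iter):
        model.train_step(data)  # the line this commit changes
        for i, bn in enumerate(bn_layers):
            running_means[i] += bn.running_mean / num_iter
            running_vars[i] += bn.running_var / num_iter

    # Write the averaged statistics back and restore the old momentum.
    for i, bn in enumerate(bn_layers):
        bn.running_mean = running_means[i]
        bn.running_var = running_vars[i]
        bn.momentum = momentums[i]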


@@ -10,6 +10,7 @@ from mmcv.utils.logging import print_log
 from torch.utils.data import DataLoader, Dataset

 from mmcls.core.hook import PreciseBNHook
+from mmcls.models.classifiers import BaseClassifier


 class ExampleDataset(Dataset):
@@ -41,7 +42,7 @@ class BiggerDataset(ExampleDataset):
         return 12


-class ExampleModel(nn.Module):
+class ExampleModel(BaseClassifier):

     def __init__(self):
         super().__init__()
@@ -52,7 +53,17 @@ class ExampleModel(nn.Module):
     def forward(self, imgs, return_loss=False):
         return self.bn(self.conv(imgs))

-    def train_step(self, data_batch, optimizer, **kwargs):
+    def simple_test(self, img, img_metas=None, **kwargs):
+        return {}
+
+    def extract_feat(self, img, stage='neck'):
+        return ()
+
+    def forward_train(self, img, gt_label, **kwargs):
+        return {'loss': 0.5}
+
+    def train_step(self, data_batch, optimizer=None, **kwargs):
+        self.forward(**data_batch)
         outputs = {
             'loss': 0.5,
             'log_vars': {
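
With these stubs the test model satisfies `BaseClassifier`'s abstract interface (`extract_feat`, `forward_train`, `simple_test`), and its `train_step` now runs a real forward pass, so the BN running statistics update just as they do under the hook. A quick sanity check of that behaviour (the batch key matches `forward(self, imgs, ...)`; the tensor shape is assumed, not taken from the test fixtures):

import torch

model = ExampleModel()
model.train()  # running stats only update in training mode

batch = {'imgs': torch.randn(2, 1, 8, 8)}  # shape assumed
outputs = model.train_step(batch)  # calls forward(), updating self.bn
assert 'loss' in outputs and 'log_vars' in outputs
assert not torch.allclose(model.bn.running_mean,
                          torch.zeros_like(model.bn.running_mean))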
@@ -234,10 +245,8 @@ def test_precise_bn():
         mean = np.mean([np.mean(batch) for batch in imgs_list])
         # Bessel's correction is used in PyTorch, therefore ddof=1
         var = np.mean([np.var(batch, ddof=1) for batch in imgs_list])
-        assert np.equal(mean, np.array(
-            model.bn.running_mean)), (mean, np.array(model.bn.running_mean))
-        assert np.equal(var, np.array(
-            model.bn.running_var)), (var, np.array(model.bn.running_var))
+        assert np.equal(mean, model.bn.running_mean)
+        assert np.equal(var, model.bn.running_var)


 @pytest.mark.skipif(
     not torch.cuda.is_available(), reason='requires CUDA support')
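
On the simplified assertions in the last hunk: PyTorch keeps the unbiased variance in `running_var`, which is why the expected value is computed with `ddof=1`. A standalone sketch of that equivalence, using `momentum=None` so the running stats are a plain cumulative average over batches (the data here is made up):

import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
imgs_list = [torch.randn(4, 1, 8, 8) for _ in range(3)]

bn = nn.BatchNorm2d(1, momentum=None)  # None -> cumulative moving average
bn.train()
for imgs in imgs_list:
    bn(imgs)

# Unbiased (Bessel-corrected) variance, matching running_var's update.
mean = np.mean([np.mean(imgs.numpy()) for imgs in imgs_list])
var = np.mean([np.var(imgs.numpy(), ddof=1) for imgs in imgs_list])

assert np.allclose(mean, bn.running_mean.numpy(), atol=1e-5)
assert np.allclose(var, bn.running_var.numpy(), atol=1e-5)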