[Imporve] Using `train_step` instead of `forward` in PreciseBNHook (#964)
* fix precise BN hook when using MLU * fix unit testspull/942/head
parent
b366897889
commit
e54cfd6951
|
@ -107,7 +107,7 @@ def update_bn_stats(model: nn.Module,
|
|||
prog_bar = mmcv.ProgressBar(num_iter)
|
||||
|
||||
for data in itertools.islice(loader, num_iter):
|
||||
model(**data)
|
||||
model.train_step(data)
|
||||
for i, bn in enumerate(bn_layers):
|
||||
running_means[i] += bn.running_mean / num_iter
|
||||
running_vars[i] += bn.running_var / num_iter
|
||||
|
|
|
@ -10,6 +10,7 @@ from mmcv.utils.logging import print_log
|
|||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmcls.core.hook import PreciseBNHook
|
||||
from mmcls.models.classifiers import BaseClassifier
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
@ -41,7 +42,7 @@ class BiggerDataset(ExampleDataset):
|
|||
return 12
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
class ExampleModel(BaseClassifier):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -52,7 +53,17 @@ class ExampleModel(nn.Module):
|
|||
def forward(self, imgs, return_loss=False):
|
||||
return self.bn(self.conv(imgs))
|
||||
|
||||
def train_step(self, data_batch, optimizer, **kwargs):
|
||||
def simple_test(self, img, img_metas=None, **kwargs):
|
||||
return {}
|
||||
|
||||
def extract_feat(self, img, stage='neck'):
|
||||
return ()
|
||||
|
||||
def forward_train(self, img, gt_label, **kwargs):
|
||||
return {'loss': 0.5}
|
||||
|
||||
def train_step(self, data_batch, optimizer=None, **kwargs):
|
||||
self.forward(**data_batch)
|
||||
outputs = {
|
||||
'loss': 0.5,
|
||||
'log_vars': {
|
||||
|
@ -234,10 +245,8 @@ def test_precise_bn():
|
|||
mean = np.mean([np.mean(batch) for batch in imgs_list])
|
||||
# bassel correction used in Pytorch, therefore ddof=1
|
||||
var = np.mean([np.var(batch, ddof=1) for batch in imgs_list])
|
||||
assert np.equal(mean, np.array(
|
||||
model.bn.running_mean)), (mean, np.array(model.bn.running_mean))
|
||||
assert np.equal(var, np.array(
|
||||
model.bn.running_var)), (var, np.array(model.bn.running_var))
|
||||
assert np.equal(mean, model.bn.running_mean)
|
||||
assert np.equal(var, model.bn.running_var)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
|
|
Loading…
Reference in New Issue