mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Imporve] Using train_step
instead of forward
in PreciseBNHook (#964)
* fix precise BN hook when using MLU * fix unit tests
This commit is contained in:
parent
b366897889
commit
e54cfd6951
@ -107,7 +107,7 @@ def update_bn_stats(model: nn.Module,
|
||||
prog_bar = mmcv.ProgressBar(num_iter)
|
||||
|
||||
for data in itertools.islice(loader, num_iter):
|
||||
model(**data)
|
||||
model.train_step(data)
|
||||
for i, bn in enumerate(bn_layers):
|
||||
running_means[i] += bn.running_mean / num_iter
|
||||
running_vars[i] += bn.running_var / num_iter
|
||||
|
@ -10,6 +10,7 @@ from mmcv.utils.logging import print_log
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmcls.core.hook import PreciseBNHook
|
||||
from mmcls.models.classifiers import BaseClassifier
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
@ -41,7 +42,7 @@ class BiggerDataset(ExampleDataset):
|
||||
return 12
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
class ExampleModel(BaseClassifier):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -52,7 +53,17 @@ class ExampleModel(nn.Module):
|
||||
def forward(self, imgs, return_loss=False):
|
||||
return self.bn(self.conv(imgs))
|
||||
|
||||
def train_step(self, data_batch, optimizer, **kwargs):
|
||||
def simple_test(self, img, img_metas=None, **kwargs):
|
||||
return {}
|
||||
|
||||
def extract_feat(self, img, stage='neck'):
|
||||
return ()
|
||||
|
||||
def forward_train(self, img, gt_label, **kwargs):
|
||||
return {'loss': 0.5}
|
||||
|
||||
def train_step(self, data_batch, optimizer=None, **kwargs):
|
||||
self.forward(**data_batch)
|
||||
outputs = {
|
||||
'loss': 0.5,
|
||||
'log_vars': {
|
||||
@ -234,10 +245,8 @@ def test_precise_bn():
|
||||
mean = np.mean([np.mean(batch) for batch in imgs_list])
|
||||
# bassel correction used in Pytorch, therefore ddof=1
|
||||
var = np.mean([np.var(batch, ddof=1) for batch in imgs_list])
|
||||
assert np.equal(mean, np.array(
|
||||
model.bn.running_mean)), (mean, np.array(model.bn.running_mean))
|
||||
assert np.equal(var, np.array(
|
||||
model.bn.running_var)), (var, np.array(model.bn.running_var))
|
||||
assert np.equal(mean, model.bn.running_mean)
|
||||
assert np.equal(var, model.bn.running_var)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
|
Loading…
x
Reference in New Issue
Block a user