Mirror of https://github.com/open-mmlab/mmclassification.git
[Improve] Use train_step instead of forward in PreciseBNHook (#964)

* Fix the precise BN hook when using MLU
* Fix the unit tests
commit e54cfd6951 (parent b366897889)
@@ -107,7 +107,7 @@ def update_bn_stats(model: nn.Module,
     prog_bar = mmcv.ProgressBar(num_iter)
 
     for data in itertools.islice(loader, num_iter):
-        model(**data)
+        model.train_step(data)
         for i, bn in enumerate(bn_layers):
             running_means[i] += bn.running_mean / num_iter
             running_vars[i] += bn.running_var / num_iter
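The core fix is in this hunk: update_bn_stats now pushes each batch through the model's train_step instead of unpacking it into forward, so the data follows the same entry point the runner uses during training (per the commit description, calling forward(**data) directly broke the hook on MLU). A minimal standalone sketch of that accumulation loop, assuming only a model exposing train_step(data) and a loader yielding dict batches; the helper name is illustrative, not the hook's exact code:

import itertools

import torch
import torch.nn as nn


def accumulate_bn_stats(model, loader, num_iter):
    # Illustrative restatement of the hunk above: feed num_iter batches
    # through train_step and average the BN running statistics.
    bn_layers = [
        m for m in model.modules()
        if isinstance(m, nn.modules.batchnorm._BatchNorm)
    ]
    running_means = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_vars = [torch.zeros_like(bn.running_var) for bn in bn_layers]
    for data in itertools.islice(loader, num_iter):
        model.train_step(data)  # same entry point the runner uses in training
        for i, bn in enumerate(bn_layers):
            running_means[i] += bn.running_mean / num_iter
            running_vars[i] += bn.running_var / num_iter
    return running_means, running_vars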
@@ -10,6 +10,7 @@ from mmcv.utils.logging import print_log
 from torch.utils.data import DataLoader, Dataset
 
 from mmcls.core.hook import PreciseBNHook
+from mmcls.models.classifiers import BaseClassifier
 
 
 class ExampleDataset(Dataset):
@@ -41,7 +42,7 @@ class BiggerDataset(ExampleDataset):
         return 12
 
 
-class ExampleModel(nn.Module):
+class ExampleModel(BaseClassifier):
 
     def __init__(self):
        super().__init__()
@@ -52,7 +53,17 @@ class ExampleModel(nn.Module):
     def forward(self, imgs, return_loss=False):
         return self.bn(self.conv(imgs))
 
-    def train_step(self, data_batch, optimizer, **kwargs):
+    def simple_test(self, img, img_metas=None, **kwargs):
+        return {}
+
+    def extract_feat(self, img, stage='neck'):
+        return ()
+
+    def forward_train(self, img, gt_label, **kwargs):
+        return {'loss': 0.5}
+
+    def train_step(self, data_batch, optimizer=None, **kwargs):
+        self.forward(**data_batch)
         outputs = {
             'loss': 0.5,
             'log_vars': {
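With ExampleModel now a BaseClassifier subclass, the test stub must supply the classifier interface (simple_test, extract_feat, forward_train) even though only train_step matters here: the hook can only observe BN statistics if train_step actually routes the batch through the BN layers. A self-contained sketch of that requirement outside mmcls (TinyModel and its batch layout are made up for illustration):

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in: train_step routes the batch through BN."""

    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm1d(4)

    def forward(self, imgs):
        return self.bn(imgs)

    def train_step(self, data_batch, optimizer=None, **kwargs):
        self.forward(**data_batch)
        return {'loss': 0.5, 'log_vars': {'loss': 0.5}, 'num_samples': 1}


model = TinyModel().train()
model.train_step({'imgs': torch.randn(8, 4)})
print(model.bn.running_mean)  # non-zero: train_step reached the BN layer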
@@ -234,10 +245,8 @@ def test_precise_bn():
     mean = np.mean([np.mean(batch) for batch in imgs_list])
     # bassel correction used in Pytorch, therefore ddof=1
     var = np.mean([np.var(batch, ddof=1) for batch in imgs_list])
-    assert np.equal(mean, np.array(
-        model.bn.running_mean)), (mean, np.array(model.bn.running_mean))
-    assert np.equal(var, np.array(
-        model.bn.running_var)), (var, np.array(model.bn.running_var))
+    assert np.equal(mean, model.bn.running_mean)
+    assert np.equal(var, model.bn.running_var)
 
 
 @pytest.mark.skipif(
     not torch.cuda.is_available(), reason='requires CUDA support')
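The simplified asserts still rely on PyTorch storing an unbiased (Bessel-corrected) running variance in BatchNorm, which is why the NumPy reference uses ddof=1. A quick standalone check of that convention (not part of the test file; shapes are arbitrary):

import numpy as np
import torch
import torch.nn as nn

x = torch.randn(16, 4)
bn = nn.BatchNorm1d(4, momentum=None)  # momentum=None: exact cumulative average
bn.train()
bn(x)

# running_var matches the Bessel-corrected (ddof=1) sample variance,
# running_mean matches the plain batch mean.
assert np.allclose(bn.running_var.numpy(),
                   np.var(x.numpy(), axis=0, ddof=1), atol=1e-5)
assert np.allclose(bn.running_mean.numpy(),
                   x.numpy().mean(axis=0), atol=1e-5)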