Merge remote-tracking branch 'origin/master' into dev

pull/992/head
mzr1996 2022-08-22 10:12:24 +08:00
commit 5ad3bed2cd
3 changed files with 12 additions and 2 deletions
mmcls/models/classifiers
tests/test_models
tools/analysis_tools

View File

@ -33,6 +33,13 @@ class ImageClassifier(BaseClassifier):
if augments_cfg is not None:
self.augments = Augments(augments_cfg)
def forward_dummy(self, img):
"""Used for computing network flops.
See `mmclassificaiton/tools/analysis_tools/get_flops.py`
"""
return self.extract_feat(img, stage='pre_logits')
def extract_feat(self, img, stage='neck'):
"""Directly extract features from the specified stage.

View File

@ -321,3 +321,6 @@ def test_classifier_extract_feat():
outs = model.extract_feats(multi_imgs, stage='pre_logits')
for out_per_img in outs:
assert out_per_img.shape == (1, 1024)
out = model.forward_dummy(torch.rand(1, 3, 224, 224))
assert out.shape == (1, 1024)

View File

@ -35,8 +35,8 @@ def main():
model = build_classifier(cfg.model)
model.eval()
if hasattr(model, 'extract_feat'):
model.forward = model.extract_feat
if hasattr(model, 'forward_dummy'):
model.forward = model.forward_dummy
else:
raise NotImplementedError(
'FLOPs counter is currently not currently supported with {}'.