Merge remote-tracking branch 'origin/master' into dev
commit
5ad3bed2cd
mmcls/models/classifiers
tests/test_models
tools/analysis_tools
|
@ -33,6 +33,13 @@ class ImageClassifier(BaseClassifier):
|
|||
if augments_cfg is not None:
|
||||
self.augments = Augments(augments_cfg)
|
||||
|
||||
def forward_dummy(self, img):
|
||||
"""Used for computing network flops.
|
||||
|
||||
See `mmclassificaiton/tools/analysis_tools/get_flops.py`
|
||||
"""
|
||||
return self.extract_feat(img, stage='pre_logits')
|
||||
|
||||
def extract_feat(self, img, stage='neck'):
|
||||
"""Directly extract features from the specified stage.
|
||||
|
||||
|
|
|
@ -321,3 +321,6 @@ def test_classifier_extract_feat():
|
|||
outs = model.extract_feats(multi_imgs, stage='pre_logits')
|
||||
for out_per_img in outs:
|
||||
assert out_per_img.shape == (1, 1024)
|
||||
|
||||
out = model.forward_dummy(torch.rand(1, 3, 224, 224))
|
||||
assert out.shape == (1, 1024)
|
||||
|
|
|
@ -35,8 +35,8 @@ def main():
|
|||
model = build_classifier(cfg.model)
|
||||
model.eval()
|
||||
|
||||
if hasattr(model, 'extract_feat'):
|
||||
model.forward = model.extract_feat
|
||||
if hasattr(model, 'forward_dummy'):
|
||||
model.forward = model.forward_dummy
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'FLOPs counter is currently not currently supported with {}'.
|
||||
|
|
Loading…
Reference in New Issue