[Refactor] Update almost all tools and add unit tests for these tools. (#1418)

* [Refactor] Update almost all tools and add unit tests for these tools.

* Fix Windows UT.
Ma Zerun 2023-03-17 10:50:51 +08:00 committed by GitHub
parent 8875e9da92
commit 4f5b38f225
26 changed files with 852 additions and 262 deletions

View File

@@ -200,8 +200,8 @@ workflows:
- dev-1.x
- build_cpu:
name: minimum_version_cpu
torch: 1.6.0
torchvision: 0.7.0
torch: 1.8.0
torchvision: 0.9.0
python: 3.7.16
requires:
- lint
@@ -231,11 +231,11 @@ workflows:
jobs:
- build_cuda:
name: minimum_version_gpu
torch: 1.6.0
torch: 1.8.0
# Use double quotation mark to explicitly specify its type
# as string instead of number
cuda: "10.1"
cuda: "10.2"
filters:
branches:
only:
- dev-1.x
- pretrain

View File

@@ -1,12 +1,40 @@
_base_ = [
'../_base_/models/vit-base-p16.py',
'../_base_/datasets/imagenet_bs64_clip_384.py',
'../_base_/datasets/imagenet_bs64_pil_resize.py',
'../_base_/schedules/imagenet_bs4096_AdamW.py',
'../_base_/default_runtime.py'
]
# model setting/mnt/lustre/lirongjie/tmp/clip_ckpt/trans_ckpt
model = dict(backbone=dict(pre_norm=True, ), )
# model setting
model = dict(backbone=dict(pre_norm=True))
# data settings
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=384,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=384),
dict(type='PackInputs'),
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
# schedule setting
optim_wrapper = dict(clip_grad=dict(max_norm=1.0))
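
The five sibling configs below repeat this pattern at other input scales and patch sizes. A minimal sketch of how such a config is consumed (it mirrors the registry-based pipeline build used by tools/visualization/vis_cam.py later in this diff; the config path is illustrative):

from mmcv.transforms import Compose
from mmengine.config import Config
from mmengine.registry import init_default_scope
from mmpretrain.registry import TRANSFORMS

init_default_scope('mmpretrain')  # resolve 'LoadImageFromFile' etc. in mmpretrain's registry
cfg = Config.fromfile('configs/clip/vit-base-p16_in1k-384px.py')  # illustrative path
pipeline = Compose([TRANSFORMS.build(t) for t in cfg.test_dataloader.dataset.pipeline])
data = pipeline({'img_path': 'demo/demo.JPEG'})  # packed inputs ready for a model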

View File

@@ -1,12 +1,40 @@
_base_ = [
'../_base_/models/vit-base-p16.py',
'../_base_/datasets/imagenet_bs64_clip_448.py',
'../_base_/datasets/imagenet_bs64_pil_resize.py',
'../_base_/schedules/imagenet_bs4096_AdamW.py',
'../_base_/default_runtime.py'
]
# model setting/mnt/lustre/lirongjie/tmp/clip_ckpt/trans_ckpt
model = dict(backbone=dict(pre_norm=True, ), )
# model setting
model = dict(backbone=dict(pre_norm=True))
# data settings
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=448,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=448,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=448),
dict(type='PackInputs'),
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
# schedule setting
optim_wrapper = dict(clip_grad=dict(max_norm=1.0))

View File

@@ -1,12 +1,40 @@
_base_ = [
'../_base_/models/vit-base-p16.py',
'../_base_/datasets/imagenet_bs64_clip_224.py',
'../_base_/datasets/imagenet_bs64_pil_resize.py',
'../_base_/schedules/imagenet_bs4096_AdamW.py',
'../_base_/default_runtime.py'
]
# model setting/mnt/lustre/lirongjie/tmp/clip_ckpt/trans_ckpt
model = dict(backbone=dict(pre_norm=True, ), )
# model setting
model = dict(backbone=dict(pre_norm=True))
# data settings
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=224,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs'),
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
# schedule setting
optim_wrapper = dict(clip_grad=dict(max_norm=1.0))

View File

@@ -1,12 +1,40 @@
_base_ = [
'../_base_/models/vit-base-p32.py',
'../_base_/datasets/imagenet_bs64_clip_384.py',
'../_base_/datasets/imagenet_bs64_pil_resize.py',
'../_base_/schedules/imagenet_bs4096_AdamW.py',
'../_base_/default_runtime.py'
]
# model setting/mnt/lustre/lirongjie/tmp/clip_ckpt/trans_ckpt
model = dict(backbone=dict(pre_norm=True, ), )
# model setting
model = dict(backbone=dict(pre_norm=True))
# data settings
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=384,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=384),
dict(type='PackInputs'),
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
# schedule setting
optim_wrapper = dict(clip_grad=dict(max_norm=1.0))

View File

@@ -1,12 +1,40 @@
_base_ = [
'../_base_/models/vit-base-p32.py',
'../_base_/datasets/imagenet_bs64_clip_448.py',
'../_base_/datasets/imagenet_bs64_pil_resize.py',
'../_base_/schedules/imagenet_bs4096_AdamW.py',
'../_base_/default_runtime.py'
]
# model setting/mnt/lustre/lirongjie/tmp/clip_ckpt/trans_ckpt
model = dict(backbone=dict(pre_norm=True, ), )
# model setting
model = dict(backbone=dict(pre_norm=True))
# data settings
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=448,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=448,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=448),
dict(type='PackInputs'),
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
# schedule setting
optim_wrapper = dict(clip_grad=dict(max_norm=1.0))

View File

@@ -1,12 +1,40 @@
_base_ = [
'../_base_/models/vit-base-p32.py',
'../_base_/datasets/imagenet_bs64_clip_224.py',
'../_base_/datasets/imagenet_bs64_pil_resize.py',
'../_base_/schedules/imagenet_bs4096_AdamW.py',
'../_base_/default_runtime.py'
]
# model setting/mnt/lustre/lirongjie/tmp/clip_ckpt/trans_ckpt
model = dict(backbone=dict(pre_norm=True, ), )
# model setting
model = dict(backbone=dict(pre_norm=True))
# data settings
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=224,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs'),
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
# schedule setting
optim_wrapper = dict(clip_grad=dict(max_norm=1.0))

View File

@@ -25,7 +25,9 @@ def main():
# build the model from a config file and a checkpoint file
try:
inferencer = ImageClassificationInferencer(args.model, args.checkpoint)
pretrained = args.checkpoint or True
inferencer = ImageClassificationInferencer(
args.model, pretrained=pretrained)
except ValueError:
raise ValueError(
f'Unavailable model "{args.model}", you can specify find a model '
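
For context, a minimal sketch of the new inferencer call path; the model name comes from the tests added below, and the result keys are an assumption about the inferencer's output format:

from mmpretrain import ImageClassificationInferencer

# `pretrained` may be a checkpoint path, or True to fetch the default weights.
inferencer = ImageClassificationInferencer(
    'mobilevit-xxsmall_3rdparty_in1k', pretrained=True)
result = inferencer('demo/demo.JPEG')[0]
print(result['pred_class'])  # e.g. 'sea snake', as asserted in tests/test_tools.py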

View File

@@ -1,3 +1,4 @@
albumentations>=0.3.2 --no-binary qudida,albumentations
colorama
requests
albumentations>=0.3.2 --no-binary qudida,albumentations # For Albumentations data transform
grad-cam >= 1.3.7 # For CAM visualization
requests # For torchserve
scikit-learn # For t-SNE visualization and unit tests.

View File

@@ -1,9 +1,3 @@
codecov
flake8
interrogate
isort==4.3.21
mmdet>=3.0.0rc0
pytest
scikit-learn
xdoctest >= 0.10.0
yapf

View File

@@ -0,0 +1 @@
../../color.jpg

View File

@@ -0,0 +1 @@
../../color.jpg

View File

@@ -0,0 +1 @@
../../../color.jpg

View File

@@ -1,10 +0,0 @@
{"a": "b"}
{"mode": "train", "epoch": 1, "iter": 10, "lr": 0.01309, "memory": 0, "data_time": 0.0072, "time": 0.00727}
{"mode": "train", "epoch": 1, "iter": 20, "lr": 0.02764, "memory": 0, "data_time": 0.00044, "time": 0.00046}
{"mode": "train", "epoch": 1, "iter": 30, "lr": 0.04218, "memory": 0, "data_time": 0.00028, "time": 0.0003}
{"mode": "train", "epoch": 1, "iter": 40, "lr": 0.05673, "memory": 0, "data_time": 0.00027, "time": 0.00029}
{"mode": "train", "epoch": 2, "iter": 10, "lr": 0.17309, "memory": 0, "data_time": 0.00048, "time": 0.0005}
{"mode": "train", "epoch": 2, "iter": 20, "lr": 0.18763, "memory": 0, "data_time": 0.00038, "time": 0.0004}
{"mode": "train", "epoch": 2, "iter": 30, "lr": 0.20218, "memory": 0, "data_time": 0.00037, "time": 0.00039}
{"mode": "train", "epoch": 3, "iter": 10, "lr": 0.33305, "memory": 0, "data_time": 0.00045, "time": 0.00046}
{"mode": "train", "epoch": 3, "iter": 20, "lr": 0.34759, "memory": 0, "data_time": 0.0003, "time": 0.00032}

View File

@@ -0,0 +1,21 @@
{"lr": 0.1, "data_time": 0.0061125516891479496, "loss": 2.6531384229660033, "time": 0.14429793357849122, "epoch": 1, "step": 10}
{"lr": 0.1, "data_time": 0.00030262470245361327, "loss": 2.9456406116485594, "time": 0.0219132661819458, "epoch": 1, "step": 20}
{"lr": 0.1, "data_time": 0.00022499561309814454, "loss": 3.1025198698043823, "time": 0.021793675422668458, "epoch": 1, "step": 30}
{"lr": 0.1, "data_time": 0.00023109912872314452, "loss": 2.5765398740768433, "time": 0.021819329261779784, "epoch": 1, "step": 40}
{"lr": 0.1, "data_time": 0.00023169517517089843, "loss": 2.671005058288574, "time": 0.02181088924407959, "epoch": 1, "step": 50}
{"lr": 0.1, "data_time": 0.00021798610687255858, "loss": 2.5273321866989136, "time": 0.021781444549560547, "epoch": 1, "step": 60}
{"accuracy/top1": 18.80000114440918, "step": 1}
{"lr": 0.1, "data_time": 0.0007575273513793946, "loss": 2.3254310727119445, "time": 0.02237672805786133, "epoch": 2, "step": 73}
{"lr": 0.1, "data_time": 0.0002459049224853516, "loss": 2.194095492362976, "time": 0.021792054176330566, "epoch": 2, "step": 83}
{"lr": 0.1, "data_time": 0.00027666091918945315, "loss": 2.207821953296661, "time": 0.021822547912597655, "epoch": 2, "step": 93}
{"lr": 0.1, "data_time": 0.00025298595428466795, "loss": 2.090667963027954, "time": 0.02178535461425781, "epoch": 2, "step": 103}
{"lr": 0.1, "data_time": 0.0002483367919921875, "loss": 2.18342148065567, "time": 0.021893739700317383, "epoch": 2, "step": 113}
{"lr": 0.1, "data_time": 0.00030078887939453123, "loss": 2.2274346113204957, "time": 0.022345948219299316, "epoch": 2, "step": 123}
{"accuracy/top1": 21.100000381469727, "step": 2}
{"lr": 0.1, "data_time": 0.0008128643035888672, "loss": 2.017984461784363, "time": 0.02267434597015381, "epoch": 3, "step": 136}
{"lr": 0.1, "data_time": 0.00023736953735351563, "loss": 2.0648953437805178, "time": 0.02174344062805176, "epoch": 3, "step": 146}
{"lr": 0.1, "data_time": 0.00024063587188720702, "loss": 2.0859395623207093, "time": 0.022107195854187012, "epoch": 3, "step": 156}
{"lr": 0.1, "data_time": 0.0002336740493774414, "loss": 2.1662048220634462, "time": 0.021825361251831054, "epoch": 3, "step": 166}
{"lr": 0.1, "data_time": 0.0002296924591064453, "loss": 2.1007142066955566, "time": 0.021821355819702147, "epoch": 3, "step": 176}
{"lr": 0.1, "data_time": 0.00023157596588134765, "loss": 2.0436240792274476, "time": 0.021722936630249025, "epoch": 3, "step": 186}
{"accuracy/top1": 25.600000381469727, "step": 3}

View File

@@ -5,12 +5,17 @@ import random
from unittest import TestCase
from unittest.mock import ANY, call, patch
import albumentations
import mmengine
import numpy as np
import pytest
from mmpretrain.registry import TRANSFORMS
try:
import albumentations
except ImportError:
albumentations = None
def construct_toy_data():
img = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
@@ -666,6 +671,8 @@ class TestAlbumentations(TestCase):
DEFAULT_ARGS = dict(
type='Albumentations', transforms=[dict(type='ChannelShuffle', p=1)])
@pytest.mark.skipif(
albumentations is None, reason='No Albumentations module.')
def test_assertion(self):
# Test with non-list transforms
with self.assertRaises(AssertionError):
@@ -697,6 +704,8 @@ class TestAlbumentations(TestCase):
cfg['keymap'] = []
TRANSFORMS.build(cfg)
@pytest.mark.skipif(
albumentations is None, reason='No Albumentations module.')
def test_transform(self):
ori_img = np.random.randint(0, 256, (256, 256, 3), np.uint8)
results = dict(img=copy.deepcopy(ori_img))
@@ -795,6 +804,8 @@ class TestAlbumentations(TestCase):
assert min(ablu_result['img'].shape[:2]) == 400
assert ablu_result['img_shape'] == (400, 600)
@pytest.mark.skipif(
albumentations is None, reason='No Albumentations module.')
def test_repr(self):
cfg = copy.deepcopy(self.DEFAULT_ARGS)
transform = TRANSFORMS.build(cfg)
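
The hunks above interleave context and added lines; consolidated, the optional-dependency pattern they introduce looks like this:

from unittest import TestCase
import pytest

try:
    import albumentations
except ImportError:
    albumentations = None  # tests are skipped instead of erroring

class TestAlbumentations(TestCase):
    @pytest.mark.skipif(
        albumentations is None, reason='No Albumentations module.')
    def test_transform(self):
        ...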

tests/test_tools.py 100644 (418 lines added)
View File

@@ -0,0 +1,418 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
import tempfile
from collections import OrderedDict
from pathlib import Path
from subprocess import PIPE, Popen
from unittest import TestCase
import mmengine
import torch
from mmengine.config import Config
from mmpretrain import ModelHub, get_model
from mmpretrain.structures import DataSample
MMPRE_ROOT = Path(__file__).parent.parent
ASSETS_ROOT = Path(__file__).parent / 'data'
class TestImageDemo(TestCase):
def setUp(self):
self.tmpdir = tempfile.TemporaryDirectory()
self.dir = Path(self.tmpdir.name)
def tearDown(self):
self.tmpdir.cleanup()
def test_run(self):
command = [
'python',
'demo/image_demo.py',
'demo/demo.JPEG',
'mobilevit-xxsmall_3rdparty_in1k',
'--device',
'cpu',
]
p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
out, _ = p.communicate()
self.assertIn('sea snake', out.decode())
class TestAnalyzeLogs(TestCase):
def setUp(self):
self.log_file = ASSETS_ROOT / 'vis_data.json'
self.tmpdir = tempfile.TemporaryDirectory()
self.out_file = Path(self.tmpdir.name) / 'out.png'
def tearDown(self):
self.tmpdir.cleanup()
def test_run(self):
command = [
'python',
'tools/analysis_tools/analyze_logs.py',
'cal_train_time',
self.log_file,
]
p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
out, _ = p.communicate()
self.assertIn('slowest epoch 2, average time is 0.0219', out.decode())
command = [
'python',
'tools/analysis_tools/analyze_logs.py',
'plot_curve',
self.log_file,
'--keys',
'accuracy/top1',
'--out',
str(self.out_file),
]
p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
out, _ = p.communicate()
self.assertIn(str(self.log_file), out.decode())
self.assertIn(str(self.out_file), out.decode())
self.assertTrue(self.out_file.exists())
class TestAnalyzeResults(TestCase):
def setUp(self):
self.tmpdir = tempfile.TemporaryDirectory()
self.dir = Path(self.tmpdir.name)
dataset_cfg = dict(
type='CustomDataset',
data_root=str(ASSETS_ROOT / 'dataset'),
)
config = Config(dict(test_dataloader=dict(dataset=dataset_cfg)))
self.config_file = self.dir / 'config.py'
config.dump(self.config_file)
results = [{
'gt_label': 1,
'pred_label': 0,
'pred_score': [0.9, 0.1],
'sample_idx': 0,
}, {
'gt_label': 0,
'pred_label': 0,
'pred_score': [0.9, 0.1],
'sample_idx': 1,
}]
self.result_file = self.dir / 'results.pkl'
mmengine.dump(results, self.result_file)
def tearDown(self):
self.tmpdir.cleanup()
def test_run(self):
command = [
'python',
'tools/analysis_tools/analyze_results.py',
self.config_file,
self.result_file,
'--out-dir',
self.tmpdir.name,
]
p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
p.communicate()
self.assertTrue((self.dir / 'success/2.jpeg.png').exists())
self.assertTrue((self.dir / 'fail/1.JPG.png').exists())
class TestPrintConfig(TestCase):
def setUp(self):
self.tmpdir = tempfile.TemporaryDirectory()
self.config_file = MMPRE_ROOT / 'configs/resnet/resnet18_8xb32_in1k.py'
def tearDown(self):
self.tmpdir.cleanup()
def test_run(self):
command = [
'python',
'tools/misc/print_config.py',
self.config_file,
]
p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
out, _ = p.communicate()
out = out.decode().strip().replace('\r\n', '\n')
self.assertEqual(out,
Config.fromfile(self.config_file).pretty_text.strip())
class TestVerifyDataset(TestCase):
def setUp(self):
self.tmpdir = tempfile.TemporaryDirectory()
self.dir = Path(self.tmpdir.name)
dataset_cfg = dict(
type='CustomDataset',
ann_file=str(self.dir / 'ann.txt'),
pipeline=[dict(type='LoadImageFromFile')],
data_root=str(ASSETS_ROOT / 'dataset'),
)
ann_file = '\n'.join(['a/2.JPG 0', 'b/2.jpeg 1', 'b/subb/3.jpg 1'])
(self.dir / 'ann.txt').write_text(ann_file)
config = Config(dict(train_dataloader=dict(dataset=dataset_cfg)))
self.config_file = Path(self.tmpdir.name) / 'config.py'
config.dump(self.config_file)
def tearDown(self):
self.tmpdir.cleanup()
def test_run(self):
command = [
'python',
'tools/misc/verify_dataset.py',
self.config_file,
'--out-path',
self.dir / 'log.log',
]
p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
out, _ = p.communicate()
self.assertIn(
f"{ASSETS_ROOT/'dataset/a/2.JPG'} cannot be read correctly",
out.decode().strip())
self.assertTrue((self.dir / 'log.log').exists())
class TestEvalMetric(TestCase):
def setUp(self):
self.tmpdir = tempfile.TemporaryDirectory()
self.dir = Path(self.tmpdir.name)
results = [
DataSample().set_gt_label(1).set_pred_label(0).set_pred_score(
[0.6, 0.3, 0.1]).to_dict(),
DataSample().set_gt_label(0).set_pred_label(0).set_pred_score(
[0.6, 0.3, 0.1]).to_dict(),
]
self.result_file = self.dir / 'results.pkl'
mmengine.dump(results, self.result_file)
def tearDown(self):
self.tmpdir.cleanup()
def test_run(self):
command = [
'python',
'tools/analysis_tools/eval_metric.py',
self.result_file,
'--metric',
'type=Accuracy',
'topk=1,2',
]
p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
out, _ = p.communicate()
self.assertIn('accuracy/top1', out.decode())
self.assertIn('accuracy/top2', out.decode())
class TestVisScheduler(TestCase):
def setUp(self):
self.tmpdir = tempfile.TemporaryDirectory()
self.dir = Path(self.tmpdir.name)
config = Config.fromfile(MMPRE_ROOT /
'configs/resnet/resnet18_8xb32_in1k.py')
config.param_scheduler = [
dict(
type='LinearLR',
start_factor=0.01,
by_epoch=True,
end=1,
convert_to_iter_based=True),
dict(type='CosineAnnealingLR', by_epoch=True, begin=1),
]
config.work_dir = str(self.dir)
config.train_cfg.max_epochs = 2
self.config_file = Path(self.tmpdir.name) / 'config.py'
config.dump(self.config_file)
def tearDown(self):
self.tmpdir.cleanup()
def test_run(self):
command = [
'python',
'tools/visualization/vis_scheduler.py',
self.config_file,
'--dataset-size',
'100',
'--not-show',
'--save-path',
str(self.dir / 'out.png'),
]
p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
p.communicate()
self.assertTrue((self.dir / 'out.png').exists())
class TestPublishModel(TestCase):
def setUp(self):
self.tmpdir = tempfile.TemporaryDirectory()
self.dir = Path(self.tmpdir.name)
ckpt = dict(
state_dict=OrderedDict({
'a': torch.tensor(1.),
}),
ema_state_dict=OrderedDict({
'step': 1,
'module.a': torch.tensor(2.),
}))
self.ckpt_file = self.dir / 'ckpt.pth'
torch.save(ckpt, self.ckpt_file)
def tearDown(self):
self.tmpdir.cleanup()
def test_run(self):
command = [
'python',
'tools/model_converters/publish_model.py',
self.ckpt_file,
self.ckpt_file,
'--dataset-type',
'ImageNet',
'--no-ema',
]
p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
out, _ = p.communicate()
self.assertIn('and drop the EMA weights.', out.decode())
self.assertIn('Successfully generated', out.decode())
output_ckpt = re.findall(r'ckpt_\d{8}-\w{8}.pth', out.decode())
self.assertGreater(len(output_ckpt), 0)
output_ckpt = output_ckpt[0]
self.assertTrue((self.dir / output_ckpt).exists())
# The input file won't be overridden.
self.assertTrue(self.ckpt_file.exists())
class TestVisCam(TestCase):
def setUp(self):
self.tmpdir = tempfile.TemporaryDirectory()
self.dir = Path(self.tmpdir.name)
model = get_model('mobilevit-xxsmall_3rdparty_in1k')
self.config_file = self.dir / 'config.py'
model.config.dump(self.config_file)
self.ckpt_file = self.dir / 'ckpt.pth'
torch.save(model.state_dict(), self.ckpt_file)
def tearDown(self):
self.tmpdir.cleanup()
def test_run(self):
command = [
'python',
'tools/visualization/vis_cam.py',
ASSETS_ROOT / 'color.jpg',
self.config_file,
self.ckpt_file,
'--save-path',
self.dir / 'cam.jpg',
]
p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
out, _ = p.communicate()
self.assertIn('backbone.conv_1x1_exp.bn', out.decode())
self.assertTrue((self.dir / 'cam.jpg').exists())
class TestConfusionMatrix(TestCase):
def setUp(self):
self.tmpdir = tempfile.TemporaryDirectory()
self.dir = Path(self.tmpdir.name)
self.config_file = MMPRE_ROOT / 'configs/resnet/resnet18_8xb32_in1k.py'
results = [
DataSample().set_gt_label(1).set_pred_label(0).set_pred_score(
[0.6, 0.3, 0.1]).to_dict(),
DataSample().set_gt_label(0).set_pred_label(0).set_pred_score(
[0.6, 0.3, 0.1]).to_dict(),
]
self.result_file = self.dir / 'results.pkl'
mmengine.dump(results, self.result_file)
def tearDown(self):
self.tmpdir.cleanup()
def test_run(self):
command = [
'python',
'tools/analysis_tools/confusion_matrix.py',
self.config_file,
self.result_file,
'--out',
self.dir / 'result.pkl',
]
Popen(command, cwd=MMPRE_ROOT, stdout=PIPE).wait()
result = mmengine.load(self.dir / 'result.pkl')
torch.testing.assert_allclose(
result, torch.tensor([
[1, 0, 0],
[1, 0, 0],
[0, 0, 0],
]))
class TestVisTsne(TestCase):
def setUp(self):
self.tmpdir = tempfile.TemporaryDirectory()
self.dir = Path(self.tmpdir.name)
config = ModelHub.get('mobilevit-xxsmall_3rdparty_in1k').config
test_dataloader = dict(
batch_size=1,
dataset=dict(
type='CustomDataset',
data_root=str(ASSETS_ROOT / 'dataset'),
pipeline=config.test_dataloader.dataset.pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
)
config.test_dataloader = mmengine.ConfigDict(test_dataloader)
self.config_file = self.dir / 'config.py'
config.dump(self.config_file)
def tearDown(self):
self.tmpdir.cleanup()
def test_run(self):
command = [
'python',
'tools/visualization/vis_tsne.py',
self.config_file,
'--work-dir',
self.dir,
'--perplexity',
'2',
]
Popen(command, cwd=MMPRE_ROOT, stdout=PIPE).wait()
self.assertTrue(len(list(self.dir.glob('tsne_*/feat_*.png'))) > 0)
class TestGetFlops(TestCase):
def test_run(self):
command = [
'python',
'tools/analysis_tools/get_flops.py',
'mobilevit-xxsmall_3rdparty_in1k',
]
ret_code = Popen(command, cwd=MMPRE_ROOT).wait()
self.assertEqual(ret_code, 0)
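
Every test in this new file follows the same pattern: run a tool as a subprocess from the repository root and assert on its stdout or output files. A condensed sketch (the run_tool helper is hypothetical, not part of this PR):

from pathlib import Path
from subprocess import PIPE, Popen

MMPRE_ROOT = Path(__file__).parent.parent

def run_tool(*argv):
    """Hypothetical helper: run a repository tool, return its decoded stdout."""
    p = Popen(['python', *map(str, argv)], cwd=MMPRE_ROOT, stdout=PIPE)
    out, _ = p.communicate()
    return out.decode()

# e.g. assert 'accuracy/top1' in run_tool('tools/analysis_tools/eval_metric.py', ...)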

View File

@@ -105,16 +105,12 @@ def plot_curve_helper(log_dicts, metrics, args, legend):
def plot_curve(log_dicts, args):
"""Plot train metric-iter graph."""
# set backend and style
if args.backend is not None:
plt.switch_backend(args.backend)
# set style
try:
import seaborn as sns
sns.set_style(args.style)
except ImportError:
print("Attention: The plot style won't be applied because 'seaborn' "
'package is not installed, please install it if you want better '
'show style.')
pass
# set plot window size
wind_w, wind_h = args.window_size.split('*')
@@ -161,9 +157,10 @@ def add_plot_parser(subparsers):
default=None,
help='legend of each plot')
parser_plt.add_argument(
'--backend', type=str, default=None, help='backend of plt')
parser_plt.add_argument(
'--style', type=str, default='whitegrid', help='style of plt')
'--style',
type=str,
default='whitegrid',
help='style of the figure, need `seaborn` package.')
parser_plt.add_argument('--out', type=str, default=None)
parser_plt.add_argument(
'--window-size',

View File

@@ -9,7 +9,7 @@ import torch
from mmengine import DictAction
from mmpretrain.datasets import build_dataset
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
from mmpretrain.visualization import UniversalVisualizer
@@ -18,7 +18,8 @@ def parse_args():
description='MMCls evaluate prediction success/fail')
parser.add_argument('config', help='test config file path')
parser.add_argument('result', help='test result json/pkl file')
parser.add_argument('--out-dir', help='dir to store output files')
parser.add_argument(
'--out-dir', required=True, help='dir to store output files')
parser.add_argument(
'--topk',
default=20,
@@ -51,15 +52,12 @@ def save_imgs(result_dir, folder_name, results, dataset, rescale_factor=None):
vis.dataset_meta = {'classes': dataset.CLASSES}
# save imgs
for result in results:
data_sample = ClsDataSample()\
.set_gt_label(result['gt_label'])\
.set_pred_label(result['pred_label'])\
.set_pred_score(result['pred_scores'])
data_info = dataset.get_data_info(result['sample_idx'])
dump_infos = []
for data_sample in results:
data_info = dataset.get_data_info(data_sample.sample_idx)
if 'img' in data_info:
img = data_info['img']
name = str(result['sample_idx'])
name = str(data_sample.sample_idx)
elif 'img_path' in data_info:
img = mmcv.imread(data_info['img_path'], channel_order='rgb')
name = Path(data_info['img_path']).name
@@ -70,19 +68,20 @@ def save_imgs(result_dir, folder_name, results, dataset, rescale_factor=None):
vis.visualize_cls(
img, data_sample, out_file=osp.join(full_dir, name + '.png'))
for k, v in result.items():
dump = dict()
for k, v in data_sample.items():
if isinstance(v, torch.Tensor):
result[k] = v.tolist()
dump[k] = v.tolist()
else:
dump[k] = v
dump_infos.append(dump)
mmengine.dump(results, osp.join(full_dir, folder_name + '.json'))
mmengine.dump(dump_infos, osp.join(full_dir, folder_name + '.json'))
def main():
args = parse_args()
# load test results
outputs = mmengine.load(args.result)
cfg = mmengine.Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
@@ -91,27 +90,25 @@ def main():
cfg.test_dataloader.dataset.pipeline = []
dataset = build_dataset(cfg.test_dataloader.dataset)
outputs_list = list()
for i in range(len(outputs)):
output = dict()
output['sample_idx'] = outputs[i]['sample_idx']
output['gt_label'] = outputs[i]['gt_label']['label']
output['pred_score'] = float(
torch.max(outputs[i]['pred_label']['score']).item())
output['pred_scores'] = outputs[i]['pred_label']['score']
output['pred_label'] = outputs[i]['pred_label']['label']
outputs_list.append(output)
results = list()
for result in mmengine.load(args.result):
data_sample = DataSample()
data_sample.set_metainfo({'sample_idx': result['sample_idx']})
data_sample.set_gt_label(result['gt_label'])
data_sample.set_pred_label(result['pred_label'])
data_sample.set_pred_score(result['pred_score'])
results.append(data_sample)
# sort result
outputs_list = sorted(outputs_list, key=lambda x: x['pred_score'])
results = sorted(results, key=lambda x: torch.max(x.pred_score))
success = list()
fail = list()
for output in outputs_list:
if output['pred_label'] == output['gt_label']:
success.append(output)
for data_sample in results:
if (data_sample.pred_label == data_sample.gt_label).all():
success.append(data_sample)
else:
fail.append(output)
fail.append(data_sample)
success = success[:args.topk]
fail = fail[:args.topk]
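
The refactor replaces the ad-hoc result dicts with DataSample objects; a minimal sketch of the new representation (values are illustrative):

from mmpretrain.structures import DataSample

sample = DataSample()
sample.set_metainfo({'sample_idx': 0})
sample.set_gt_label(1)
sample.set_pred_label(0)
sample.set_pred_score([0.9, 0.1])
# Labels are stored as tensors, hence the `.all()` comparison in the loop above.
correct = (sample.pred_label == sample.gt_label).all()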

View File

@@ -29,7 +29,7 @@ def parse_args():
dest='metric_options',
help='The metric config, the key-value pair in xxx=yyy format will be '
'parsed as the metric config items. You can specify multiple metrics '
'by use multiple `--metric-options`. For list type value, you can use '
'by use multiple `--metric`. For list type value, you can use '
'"key=[a,b]" or "key=a,b", and it also allows nested list/tuple '
'values, e.g. "key=[(a,b),(c,d)]".')
args = parser.parse_args()

View File

@@ -1,18 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import torch
from mmengine.analysis import get_model_complexity_info
try:
from fvcore.nn import (ActivationCountAnalysis, FlopCountAnalysis,
flop_count_str, flop_count_table, parameter_count)
except ImportError:
print('You may need to install fvcore for flops computation, '
'and you can use `pip install fvcore` to set up the environment')
from fvcore.nn.print_model_statistics import _format_size
from mmengine import Config
from mmpretrain.models import build_classifier
from mmpretrain import get_model
def parse_args():
@@ -29,9 +20,7 @@ def parse_args():
def main():
args = parse_args()
if len(args.shape) == 1:
input_shape = (3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
@@ -39,39 +28,30 @@
else:
raise ValueError('invalid input shape')
cfg = Config.fromfile(args.config)
model = build_classifier(cfg.model)
model = get_model(args.config)
model.eval()
if hasattr(model, 'extract_feat'):
model.forward = model.extract_feat
else:
raise NotImplementedError(
'FLOPs counter is currently not currently supported with {}'.
format(model.__class__.__name__))
inputs = (torch.randn((1, *input_shape)), )
flops_ = FlopCountAnalysis(model, inputs)
activations_ = ActivationCountAnalysis(model, inputs)
flops = _format_size(flops_.total())
activations = _format_size(activations_.total())
params = _format_size(parameter_count(model)[''])
flop_table = flop_count_table(
flops=flops_,
activations=activations_,
show_param_shapes=True,
analysis_results = get_model_complexity_info(
model,
input_shape,
)
flop_str = flop_count_str(flops=flops_, activations=activations_)
print('\n' + flop_str)
print('\n' + flop_table)
flops = analysis_results['flops_str']
params = analysis_results['params_str']
activations = analysis_results['activations_str']
out_table = analysis_results['out_table']
out_arch = analysis_results['out_arch']
print(out_table)
print(out_arch)
split_line = '=' * 30
print(f'{split_line}\nInput shape: {input_shape}\n'
f'Flops: {flops}\nParams: {params}\n'
f'Activation: {activations}\n{split_line}')
print('!!!Only the backbone network is counted in FLOPs analysis.')
print('!!!Please be cautious if you use the results in papers. '
'You may need to check if all ops are supported and verify that the '
'flops computation is correct.')
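
The tool now delegates to mmengine's complexity analysis; equivalent standalone usage, with the model name borrowed from the tests above:

from mmengine.analysis import get_model_complexity_info
from mmpretrain import get_model

model = get_model('mobilevit-xxsmall_3rdparty_in1k')
model.eval()
results = get_model_complexity_info(model, input_shape=(3, 224, 224))
print(results['flops_str'], results['params_str'], results['activations_str'])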

View File

@@ -166,7 +166,7 @@ def train_single_fold(cfg, num_splits, fold, resume_ckpt=None):
dataset=dataset,
fold=fold,
num_splits=num_splits,
seed=cfg.seed,
seed=cfg.kfold_split_seed,
test_mode=test_mode,
)
@@ -203,13 +203,13 @@ def train_single_fold(cfg, num_splits, fold, resume_ckpt=None):
class SaveInfoHook(Hook):
def after_train_epoch(self, runner):
try:
last_ckpt = find_latest_checkpoint(cfg.work_dir)
exp_info = dict(
fold=fold, last_ckpt=last_ckpt, seed=runner.seed)
fold=fold,
last_ckpt=last_ckpt,
kfold_split_seed=cfg.kfold_split_seed,
)
dump(exp_info, osp.join(root_dir, EXP_INFO_FILE))
except OSError:
pass
runner.register_hook(SaveInfoHook(), 'LOWEST')
@@ -226,17 +226,14 @@ def main():
# merge cli arguments to config
cfg = merge_args(cfg, args)
# set preprocess configs to model
cfg.model.setdefault('data_preprocessor', cfg.get('data_preprocessor', {}))
# set the unify random seed
cfg.seed = args.seed or sync_random_seed()
cfg.kfold_split_seed = args.seed or sync_random_seed()
# resume from the previous experiment
if args.resume:
experiment_info = load(osp.join(cfg.work_dir, EXP_INFO_FILE))
resume_fold = experiment_info['fold']
cfg.seed = experiment_info['seed']
cfg.kfold_split_seed = experiment_info['kfold_split_seed']
resume_ckpt = experiment_info.get('last_ckpt', None)
else:
resume_fold = 0

View File

@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import fcntl
import multiprocessing
import os
from pathlib import Path
@@ -10,6 +10,8 @@ from mmengine import (Config, DictAction, track_parallel_progress,
from mmpretrain.datasets import build_dataset
from mmpretrain.registry import TRANSFORMS
file_lock = multiprocessing.Lock()
def parse_args():
parser = argparse.ArgumentParser(description='Verify Dataset')
@@ -66,12 +68,11 @@ class DatasetValidator():
except Exception:
with open(self.log_file_path, 'a') as f:
# add file lock to prevent multi-process writing errors
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
filepath = os.path.join(item['img_prefix'],
item['img_info']['filename'])
filepath = str(Path(item['img_path']))
file_lock.acquire()
f.write(filepath + '\n')
file_lock.release()
print(f'{filepath} cannot be read correctly, please check it.')
# Release files lock automatic using with
def __len__(self):
return len(self.dataset)
@@ -81,12 +82,12 @@ def print_info(log_file_path):
"""print some information and do extra action."""
print()
with open(log_file_path, 'r') as f:
context = f.read().strip()
if context == '':
content = f.read().strip()
if content == '':
print('There is no broken file found.')
os.remove(log_file_path)
else:
num_file = len(context.split('\n'))
num_file = len(content.split('\n'))
print(f'{num_file} broken files found, name list save in file:'
f'{log_file_path}')
print()
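
fcntl is POSIX-only, which is what broke the Windows unit tests; condensed, the replacement pattern is the following (note a module-level lock is only shared when worker processes inherit it, e.g. under fork-based multiprocessing):

import multiprocessing

file_lock = multiprocessing.Lock()  # cross-platform, unlike fcntl.flock

def log_broken_file(log_file_path, filepath):
    # Serialize appends from multiple workers to the same log file.
    with file_lock:  # equivalent to the explicit acquire()/release() above
        with open(log_file_path, 'a') as f:
            f.write(filepath + '\n')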

View File

@@ -1,11 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import datetime
import subprocess
import hashlib
import shutil
import warnings
from collections import OrderedDict
from pathlib import Path
import torch
from mmcv import digit_version
import mmpretrain
def parse_args():
@@ -13,23 +17,72 @@ def parse_args():
description='Process a checkpoint to be published')
parser.add_argument('in_file', help='input checkpoint filename')
parser.add_argument('out_file', help='output checkpoint filename')
parser.add_argument(
'--no-ema',
action='store_true',
help='Use keys in `ema_state_dict` (no-ema keys).')
parser.add_argument(
'--dataset-type',
type=str,
help='The type of the dataset. If the checkpoint is converted '
'from other repository, this option is used to fill the dataset '
'meta information to the published checkpoint, like "ImageNet", '
'"CIFAR10" and others.')
args = parser.parse_args()
return args
def process_checkpoint(in_file, out_file):
def process_checkpoint(in_file, out_file, args):
checkpoint = torch.load(in_file, map_location='cpu')
# remove optimizer for smaller file size
if 'optimizer' in checkpoint:
del checkpoint['optimizer']
# if it is necessary to remove some sensitive data in checkpoint['meta'],
# add the code here.
if digit_version(torch.__version__) >= digit_version('1.6'):
torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
else:
torch.save(checkpoint, out_file)
# remove unnecessary fields for smaller file size
for key in ['optimizer', 'param_schedulers', 'hook_msgs', 'message_hub']:
checkpoint.pop(key, None)
sha = subprocess.check_output(['sha256sum', out_file]).decode()
# For checkpoint converted from the official weight
if 'state_dict' not in checkpoint:
checkpoint = dict(state_dict=checkpoint)
meta = checkpoint.get('meta', {})
meta.setdefault('mmpretrain_version', mmpretrain.__version__)
# handle dataset meta information
if args.dataset_type is not None:
from mmpretrain.registry import DATASETS
dataset_class = DATASETS.get(args.dataset_type)
dataset_meta = getattr(dataset_class, 'METAINFO', {})
else:
dataset_meta = {}
meta.setdefault('dataset_meta', dataset_meta)
if len(meta['dataset_meta']) == 0:
warnings.warn('Missing dataset meta information.')
checkpoint['meta'] = meta
ema_state_dict = OrderedDict()
if 'ema_state_dict' in checkpoint:
for k, v in checkpoint['ema_state_dict'].items():
# The ema state dict has some extra fields
if k.startswith('module.'):
origin_k = k[len('module.'):]
assert origin_k in checkpoint['state_dict']
ema_state_dict[origin_k] = v
del checkpoint['ema_state_dict']
print('The input checkpoint has EMA weights, ', end='')
if args.no_ema:
# The values stored in `ema_state_dict` are the original values.
print('and drop the EMA weights.')
assert ema_state_dict.keys() <= checkpoint['state_dict'].keys()
checkpoint['state_dict'].update(ema_state_dict)
else:
print('and use the EMA weights.')
temp_out_file = Path(out_file).with_name('temp_' + Path(out_file).name)
torch.save(checkpoint, temp_out_file)
with open(temp_out_file, 'rb') as f:
sha = hashlib.sha256(f.read()).hexdigest()[:8]
if out_file.endswith('.pth'):
out_file_name = out_file[:-4]
else:
@@ -37,7 +90,7 @@ def process_checkpoint(in_file, out_file):
current_date = datetime.datetime.now().strftime('%Y%m%d')
final_file = out_file_name + f'_{current_date}-{sha[:8]}.pth'
subprocess.Popen(['mv', out_file, final_file])
shutil.move(temp_out_file, final_file)
print(f'Successfully generated the publish-ckpt as {final_file}.')
@@ -48,7 +101,7 @@ def main():
if not out_dir.exists():
raise ValueError(f'Directory {out_dir} does not exist, '
'please generate it manually.')
process_checkpoint(args.in_file, args.out_file)
process_checkpoint(args.in_file, args.out_file, args)
if __name__ == '__main__':
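
The renaming scheme now hashes the checkpoint with hashlib instead of shelling out to sha256sum, which Windows lacks; condensed, with an illustrative file name:

import datetime
import hashlib

with open('ckpt.pth', 'rb') as f:  # illustrative checkpoint path
    sha = hashlib.sha256(f.read()).hexdigest()[:8]
date = datetime.datetime.now().strftime('%Y%m%d')
final_file = f'ckpt_{date}-{sha}.pth'  # matches r'ckpt_\d{8}-\w{8}.pth' in tests/test_tools.py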

View File

@@ -3,24 +3,24 @@ import argparse
import copy
import math
import pkg_resources
import re
from functools import partial
from pathlib import Path
import mmcv
import numpy as np
import torch.nn as nn
from mmcv.transforms import Compose
from mmengine.config import Config, DictAction
from mmengine.dataset import default_collate
from mmengine.registry import init_default_scope
from mmengine.utils import to_2tuple
from torch.nn import BatchNorm1d, BatchNorm2d, GroupNorm, LayerNorm
from mmengine.utils.dl_utils import is_norm
from mmpretrain import digit_version
from mmpretrain.apis import init_model
from mmpretrain.apis import get_model
from mmpretrain.registry import TRANSFORMS
try:
from pytorch_grad_cam import (EigenCAM, EigenGradCAM, GradCAM,
GradCAMPlusPlus, LayerCAM, XGradCAM)
import pytorch_grad_cam as cam
from pytorch_grad_cam.activations_and_gradients import \
ActivationsAndGradients
from pytorch_grad_cam.utils.image import show_cam_on_image
@@ -28,15 +28,14 @@ except ImportError:
raise ImportError('Please run `pip install "grad-cam>=1.3.6"` to install '
'3rd party package pytorch_grad_cam.')
# Supported grad-cam type map
# Alias name
METHOD_MAP = {
'gradcam': GradCAM,
'gradcam++': GradCAMPlusPlus,
'xgradcam': XGradCAM,
'eigencam': EigenCAM,
'eigengradcam': EigenGradCAM,
'layercam': LayerCAM,
'gradcam++': cam.GradCAMPlusPlus,
}
METHOD_MAP.update({
cam_class.__name__.lower(): cam_class
for cam_class in cam.base_cam.BaseCAM.__subclasses__()
})
def parse_args():
@@ -113,33 +112,21 @@ def parse_args():
return args
def build_reshape_transform(model, args):
def reshape_transform(tensor, model, args):
"""Build reshape_transform for `cam.activations_and_grads`, which is
necessary for ViT-like networks."""
# ViT_based_Transformers have an additional clstoken in features
if not args.vit_like:
def check_shape(tensor):
assert len(tensor.size()) != 3, \
(f"The input feature's shape is {tensor.size()}, and it seems "
'to have been flattened or from a vit-like network. '
"Please use `--vit-like` if it's from a vit-like network.")
if tensor.ndim == 4:
# For (B, C, H, W)
return tensor
elif tensor.ndim == 3:
if not args.vit_like:
raise ValueError(f"The tensor shape is {tensor.shape}, if it's a "
'vit-like backbone, please specify `--vit-like`.')
# For (B, L, C)
num_extra_tokens = args.num_extra_tokens or getattr(
model.backbone, 'num_extra_tokens', 1)
return check_shape
if args.num_extra_tokens is not None:
num_extra_tokens = args.num_extra_tokens
elif hasattr(model.backbone, 'num_extra_tokens'):
num_extra_tokens = model.backbone.num_extra_tokens
else:
num_extra_tokens = 1
def _reshape_transform(tensor):
"""reshape_transform helper."""
assert len(tensor.size()) == 3, \
(f"The input feature's shape is {tensor.size()}, "
'and the feature seems not from a vit-like network?')
tensor = tensor[:, num_extra_tokens:, :]
# get heat_map_height and heat_map_width, preset input is a square
heat_map_area = tensor.size()[1]
@@ -149,13 +136,13 @@ def build_reshape_transform(model, args):
f'minus num-extra-tokens ({num_extra_tokens}) is {heat_map_area},'
' which is not a perfect square number. Please check if you used '
'a wrong num-extra-tokens.')
# (B, L, C) -> (B, H, W, C)
result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
# Bring the channels to the first dimension, like in CNNs.
result = result.transpose(2, 3).transpose(1, 2)
# (B, H, W, C) -> (B, C, H, W)
result = result.permute(0, 3, 1, 2)
return result
return _reshape_transform
else:
raise ValueError(f'Unsupported tensor shape {tensor.shape}.')
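
A worked example of the (B, L, C) -> (B, C, H, W) reshape above, for a ViT feature map with one extra (class) token:

import torch

tensor = torch.randn(2, 1 + 14 * 14, 768)     # (B, L, C), num_extra_tokens = 1
tensor = tensor[:, 1:, :]                      # strip the extra token
height = width = int(tensor.size(1)**0.5)      # 14; L minus the extra tokens must be a square
result = tensor.reshape(2, height, width, 768).permute(0, 3, 1, 2)
assert result.shape == (2, 768, 14, 14)        # (B, C, H, W)
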
def init_cam(method, model, target_layers, use_cuda, reshape_transform):
@@ -175,37 +162,12 @@
def get_layer(layer_str, model):
"""get model layer from given str."""
cur_layer = model
layer_names = layer_str.strip().split('.')
def get_children_by_name(model, name):
try:
return getattr(model, name)
except AttributeError as e:
for name, layer in model.named_modules():
if name == layer_str:
return layer
raise AttributeError(
e.args[0] +
'. Please use `--preview-model` to check keys at first.')
def get_children_by_eval(model, name):
try:
return eval(f'model{name}', {}, {'model': model})
except (AttributeError, IndexError) as e:
raise AttributeError(
e.args[0] +
'. Please use `--preview-model` to check keys at first.')
for layer_name in layer_names:
match_res = re.match('(?P<name>.+?)(?P<indices>(\\[.+\\])+)',
layer_name)
if match_res:
layer_name = match_res.groupdict()['name']
indices = match_res.groupdict()['indices']
cur_layer = get_children_by_name(cur_layer, layer_name)
cur_layer = get_children_by_eval(cur_layer, indices)
else:
cur_layer = get_children_by_name(cur_layer, layer_name)
return cur_layer
f'Cannot get the layer "{layer_str}". Please choose from: \n' +
'\n'.join(name for name, _ in model.named_modules()))
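
The hunk above mixes removed and added lines; consolidated, the new get_layer is:

def get_layer(layer_str, model):
    """Get a layer by its dotted name, e.g. 'backbone.conv_1x1_exp.bn'."""
    for name, layer in model.named_modules():
        if name == layer_str:
            return layer
    raise AttributeError(
        f'Cannot get the layer "{layer_str}". Please choose from: \n' +
        '\n'.join(name for name, _ in model.named_modules()))
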
def show_cam_grad(grayscale_cam, src_img, title, out_path=None):
@@ -224,39 +186,32 @@
def get_default_traget_layers(model, args):
"""get default target layers from given model, here choose nrom type layer
as default target layer."""
norm_layers = []
for m in model.backbone.modules():
if isinstance(m, (BatchNorm2d, LayerNorm, GroupNorm, BatchNorm1d)):
norm_layers.append(m)
if len(norm_layers) == 0:
raise ValueError(
'`--target-layers` is empty. Please use `--preview-model`'
' to check keys at first and then specify `target-layers`.')
# if the model is CNN model or Swin model, just use the last norm
# layer as the target-layer, if the model is ViT model, the final
# classification is done on the class token computed in the last
# attention block, the output will not be affected by the 14x14
# channels in the last layer. The gradient of the output with
# respect to them, will be 0! here use the last 3rd norm layer.
# means the first norm of the last decoder block.
norm_layers = [
(name, layer)
for name, layer in model.backbone.named_modules(prefix='backbone')
if is_norm(layer)
]
if args.vit_like:
if args.num_extra_tokens:
num_extra_tokens = args.num_extra_tokens
elif hasattr(model.backbone, 'num_extra_tokens'):
num_extra_tokens = model.backbone.num_extra_tokens
else:
raise AttributeError('Please set num_extra_tokens in backbone'
" or using 'num-extra-tokens'")
# For ViT models, the final classification is done on the class token.
# And the patch tokens and class tokens won't interact each other after
# the final attention layer. Therefore, we need to choose the norm
# layer before the last attention layer.
num_extra_tokens = args.num_extra_tokens or getattr(
model.backbone, 'num_extra_tokens', 1)
# if a vit-like backbone's num_extra_tokens bigger than 0, view it
# as a VisionTransformer backbone, eg. DeiT, T2T-ViT.
if num_extra_tokens >= 1:
out_type = getattr(model.backbone, 'out_type')
if out_type == 'cls_token' or num_extra_tokens > 0:
# Assume the backbone feature is class token.
name, layer = norm_layers[-3]
print('Automatically choose the last norm layer before the '
'final attention block as target_layer..')
return [norm_layers[-3]]
print('Automatically choose the last norm layer as target_layer.')
target_layers = [norm_layers[-1]]
return target_layers
f'final attention block "{name}" as the target layer.')
return [layer]
# For CNN models, use the last norm layer as the target-layer
name, layer = norm_layers[-1]
print('Automatically choose the last norm layer '
f'"{name}" as the target layer.')
return [layer]
def main():
@@ -265,16 +220,16 @@ def main():
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
init_default_scope('mmpretrain')
# build the model from a config file and a checkpoint file
model = init_model(cfg, args.checkpoint, device=args.device)
model: nn.Module = get_model(cfg, args.checkpoint, device=args.device)
if args.preview_model:
print(model)
print('\n Please remove `--preview-model` to get the CAM.')
return
# apply transform and perpare data
transforms = Compose(cfg.test_dataloader.dataset.pipeline)
transforms = Compose(
[TRANSFORMS.build(t) for t in cfg.test_dataloader.dataset.pipeline])
data = transforms({'img_path': args.img})
src_img = copy.deepcopy(data['inputs']).numpy().transpose(1, 2, 0)
data = model.data_preprocessor(default_collate([data]), False)
@@ -289,9 +244,8 @@ def main():
# init a cam grad calculator
use_cuda = ('cuda' in args.device)
reshape_transform = build_reshape_transform(model, args)
cam = init_cam(args.method, model, target_layers, use_cuda,
reshape_transform)
partial(reshape_transform, model=model, args=args))
# warp the target_category with ClassifierOutputTarget in grad_cam>=1.3.7,
# to fix the bug in #654.

View File

@@ -108,7 +105,10 @@ def parse_args():
'WARNING.')
parser.add_argument('--title', type=str, help='title of figure')
parser.add_argument(
'--style', type=str, default='whitegrid', help='style of plt')
'--style',
type=str,
default='whitegrid',
help='style of the figure, need `seaborn` package.')
parser.add_argument('--not-show', default=False, action='store_true')
parser.add_argument(
'--window-size',