[Refactor] Update almost all tools and add unit tests for these tools. (#1418)

* [Refactor] Update almost all tools and add unit tests for these tools.
* Fix Windows UT.

parent 8875e9da92
commit 4f5b38f225

Changed paths: .circleci, demo, requirements, tests (data, test_datasets/test_transforms), tools (model_converters, visualization)
.circleci/config.yml

@@ -200,8 +200,8 @@ workflows:
       - dev-1.x
       - build_cpu:
           name: minimum_version_cpu
-          torch: 1.6.0
-          torchvision: 0.7.0
+          torch: 1.8.0
+          torchvision: 0.9.0
           python: 3.7.16
           requires:
             - lint

@@ -231,11 +231,11 @@ workflows:
   jobs:
     - build_cuda:
         name: minimum_version_gpu
-        torch: 1.6.0
+        torch: 1.8.0
        # Use double quotation mark to explicitly specify its type
        # as string instead of number
-        cuda: "10.1"
+        cuda: "10.2"
        filters:
          branches:
            only:
-             - dev-1.x
+             - pretrain
configs: CLIP ViT-base-p16, 384px (exact config paths were not captured in this extraction)

@@ -1,12 +1,40 @@
 _base_ = [
     '../_base_/models/vit-base-p16.py',
-    '../_base_/datasets/imagenet_bs64_clip_384.py',
+    '../_base_/datasets/imagenet_bs64_pil_resize.py',
     '../_base_/schedules/imagenet_bs4096_AdamW.py',
     '../_base_/default_runtime.py'
 ]

-# model setting/mnt/lustre/lirongjie/tmp/clip_ckpt/trans_ckpt
-model = dict(backbone=dict(pre_norm=True, ), )
+# model setting
+model = dict(backbone=dict(pre_norm=True))
+
+# data settings
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        scale=384,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+    dict(type='PackInputs'),
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeEdge',
+        scale=384,
+        edge='short',
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=384),
+    dict(type='PackInputs'),
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = dict(dataset=dict(pipeline=test_pipeline))

 # schedule setting
 optim_wrapper = dict(clip_grad=dict(max_norm=1.0))
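Note: this hunk and the five below drop the deleted `imagenet_bs64_clip_*` dataset bases and inline equivalent pipelines built from standard transforms. As a rough, illustrative sketch of what the new 384px test pipeline produces (assuming mmpretrain's `TRANSFORMS` registry and `mmcv.transforms.Compose`; this snippet is not part of the commit):

    # Illustrative check of the inlined test pipeline at 384px.
    from mmcv.transforms import Compose
    from mmpretrain.registry import TRANSFORMS

    test_pipeline = Compose([
        TRANSFORMS.build(dict(type='LoadImageFromFile')),
        TRANSFORMS.build(
            dict(
                type='ResizeEdge',
                scale=384,
                edge='short',
                backend='pillow',
                interpolation='bicubic')),
        TRANSFORMS.build(dict(type='CenterCrop', crop_size=384)),
        TRANSFORMS.build(dict(type='PackInputs')),
    ])
    results = test_pipeline(dict(img_path='demo/demo.JPEG'))
    print(results['inputs'].shape)  # expected: torch.Size([3, 384, 384])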
configs: CLIP ViT-base-p16, 448px

@@ -1,12 +1,40 @@
 _base_ = [
     '../_base_/models/vit-base-p16.py',
-    '../_base_/datasets/imagenet_bs64_clip_448.py',
+    '../_base_/datasets/imagenet_bs64_pil_resize.py',
     '../_base_/schedules/imagenet_bs4096_AdamW.py',
     '../_base_/default_runtime.py'
 ]

-# model setting/mnt/lustre/lirongjie/tmp/clip_ckpt/trans_ckpt
-model = dict(backbone=dict(pre_norm=True, ), )
+# model setting
+model = dict(backbone=dict(pre_norm=True))
+
+# data settings
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        scale=448,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+    dict(type='PackInputs'),
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeEdge',
+        scale=448,
+        edge='short',
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=448),
+    dict(type='PackInputs'),
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = dict(dataset=dict(pipeline=test_pipeline))

 # schedule setting
 optim_wrapper = dict(clip_grad=dict(max_norm=1.0))
configs: CLIP ViT-base-p16, 224px

@@ -1,12 +1,40 @@
 _base_ = [
     '../_base_/models/vit-base-p16.py',
-    '../_base_/datasets/imagenet_bs64_clip_224.py',
+    '../_base_/datasets/imagenet_bs64_pil_resize.py',
     '../_base_/schedules/imagenet_bs4096_AdamW.py',
     '../_base_/default_runtime.py'
 ]

-# model setting/mnt/lustre/lirongjie/tmp/clip_ckpt/trans_ckpt
-model = dict(backbone=dict(pre_norm=True, ), )
+# model setting
+model = dict(backbone=dict(pre_norm=True))
+
+# data settings
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        scale=224,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+    dict(type='PackInputs'),
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeEdge',
+        scale=224,
+        edge='short',
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='PackInputs'),
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = dict(dataset=dict(pipeline=test_pipeline))

 # schedule setting
 optim_wrapper = dict(clip_grad=dict(max_norm=1.0))
configs: CLIP ViT-base-p32, 384px

@@ -1,12 +1,40 @@
 _base_ = [
     '../_base_/models/vit-base-p32.py',
-    '../_base_/datasets/imagenet_bs64_clip_384.py',
+    '../_base_/datasets/imagenet_bs64_pil_resize.py',
     '../_base_/schedules/imagenet_bs4096_AdamW.py',
     '../_base_/default_runtime.py'
 ]

-# model setting/mnt/lustre/lirongjie/tmp/clip_ckpt/trans_ckpt
-model = dict(backbone=dict(pre_norm=True, ), )
+# model setting
+model = dict(backbone=dict(pre_norm=True))
+
+# data settings
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        scale=384,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+    dict(type='PackInputs'),
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeEdge',
+        scale=384,
+        edge='short',
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=384),
+    dict(type='PackInputs'),
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = dict(dataset=dict(pipeline=test_pipeline))

 # schedule setting
 optim_wrapper = dict(clip_grad=dict(max_norm=1.0))
configs: CLIP ViT-base-p32, 448px

@@ -1,12 +1,40 @@
 _base_ = [
     '../_base_/models/vit-base-p32.py',
-    '../_base_/datasets/imagenet_bs64_clip_448.py',
+    '../_base_/datasets/imagenet_bs64_pil_resize.py',
     '../_base_/schedules/imagenet_bs4096_AdamW.py',
     '../_base_/default_runtime.py'
 ]

-# model setting/mnt/lustre/lirongjie/tmp/clip_ckpt/trans_ckpt
-model = dict(backbone=dict(pre_norm=True, ), )
+# model setting
+model = dict(backbone=dict(pre_norm=True))
+
+# data settings
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        scale=448,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+    dict(type='PackInputs'),
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeEdge',
+        scale=448,
+        edge='short',
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=448),
+    dict(type='PackInputs'),
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = dict(dataset=dict(pipeline=test_pipeline))

 # schedule setting
 optim_wrapper = dict(clip_grad=dict(max_norm=1.0))
configs: CLIP ViT-base-p32, 224px

@@ -1,12 +1,40 @@
 _base_ = [
     '../_base_/models/vit-base-p32.py',
-    '../_base_/datasets/imagenet_bs64_clip_224.py',
+    '../_base_/datasets/imagenet_bs64_pil_resize.py',
     '../_base_/schedules/imagenet_bs4096_AdamW.py',
     '../_base_/default_runtime.py'
 ]

-# model setting/mnt/lustre/lirongjie/tmp/clip_ckpt/trans_ckpt
-model = dict(backbone=dict(pre_norm=True, ), )
+# model setting
+model = dict(backbone=dict(pre_norm=True))
+
+# data settings
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        scale=224,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+    dict(type='PackInputs'),
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeEdge',
+        scale=224,
+        edge='short',
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='PackInputs'),
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = dict(dataset=dict(pipeline=test_pipeline))

 # schedule setting
 optim_wrapper = dict(clip_grad=dict(max_norm=1.0))
demo/image_demo.py

@@ -25,7 +25,9 @@ def main():

     # build the model from a config file and a checkpoint file
     try:
-        inferencer = ImageClassificationInferencer(args.model, args.checkpoint)
+        pretrained = args.checkpoint or True
+        inferencer = ImageClassificationInferencer(
+            args.model, pretrained=pretrained)
     except ValueError:
         raise ValueError(
             f'Unavailable model "{args.model}", you can specify find a model '
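With this change the demo resolves weights through the inferencer: an explicit checkpoint path wins, otherwise `pretrained=True` loads the weights registered for the model name. A minimal sketch of the same call (model name taken from the new test suite; the commented line shows the explicit-checkpoint variant):

    from mmpretrain import ImageClassificationInferencer

    # No checkpoint given: fall back to the registered pre-trained weights.
    inferencer = ImageClassificationInferencer(
        'mobilevit-xxsmall_3rdparty_in1k', pretrained=True)
    # inferencer = ImageClassificationInferencer(cfg_path, pretrained=ckpt_path)
    result = inferencer('demo/demo.JPEG')[0]
    print(result)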
requirements/optional.txt (file name inferred from contents)

@@ -1,3 +1,4 @@
-albumentations>=0.3.2 --no-binary qudida,albumentations
-colorama
-requests
+albumentations>=0.3.2 --no-binary qudida,albumentations  # For Albumentations data transform
+grad-cam >= 1.3.7  # For CAM visualization
+requests  # For torchserve
+scikit-learn  # For t-SNE visualization and unit tests.
requirements/tests.txt (file name inferred from contents)

@@ -1,9 +1,3 @@
 codecov
-flake8
 interrogate
-isort==4.3.21
-mmdet>=3.0.0rc0
 pytest
-scikit-learn
-xdoctest >= 0.10.0
-yapf
tests/data: three new one-line symlink files pointing at the shared test image (the new verify-dataset test references dataset/a/2.JPG, dataset/b/2.jpeg and dataset/b/subb/3.jpg)

@@ -0,0 +1 @@
+../../color.jpg

@@ -0,0 +1 @@
+../../color.jpg

@@ -0,0 +1 @@
+../../../color.jpg
tests/data: old-format training log (file deleted; replaced by vis_data.json below)

@@ -1,10 +0,0 @@
-{"a": "b"}
-{"mode": "train", "epoch": 1, "iter": 10, "lr": 0.01309, "memory": 0, "data_time": 0.0072, "time": 0.00727}
-{"mode": "train", "epoch": 1, "iter": 20, "lr": 0.02764, "memory": 0, "data_time": 0.00044, "time": 0.00046}
-{"mode": "train", "epoch": 1, "iter": 30, "lr": 0.04218, "memory": 0, "data_time": 0.00028, "time": 0.0003}
-{"mode": "train", "epoch": 1, "iter": 40, "lr": 0.05673, "memory": 0, "data_time": 0.00027, "time": 0.00029}
-{"mode": "train", "epoch": 2, "iter": 10, "lr": 0.17309, "memory": 0, "data_time": 0.00048, "time": 0.0005}
-{"mode": "train", "epoch": 2, "iter": 20, "lr": 0.18763, "memory": 0, "data_time": 0.00038, "time": 0.0004}
-{"mode": "train", "epoch": 2, "iter": 30, "lr": 0.20218, "memory": 0, "data_time": 0.00037, "time": 0.00039}
-{"mode": "train", "epoch": 3, "iter": 10, "lr": 0.33305, "memory": 0, "data_time": 0.00045, "time": 0.00046}
-{"mode": "train", "epoch": 3, "iter": 20, "lr": 0.34759, "memory": 0, "data_time": 0.0003, "time": 0.00032}
tests/data/vis_data.json (new file; consumed by TestAnalyzeLogs below)

@@ -0,0 +1,21 @@
+{"lr": 0.1, "data_time": 0.0061125516891479496, "loss": 2.6531384229660033, "time": 0.14429793357849122, "epoch": 1, "step": 10}
+{"lr": 0.1, "data_time": 0.00030262470245361327, "loss": 2.9456406116485594, "time": 0.0219132661819458, "epoch": 1, "step": 20}
+{"lr": 0.1, "data_time": 0.00022499561309814454, "loss": 3.1025198698043823, "time": 0.021793675422668458, "epoch": 1, "step": 30}
+{"lr": 0.1, "data_time": 0.00023109912872314452, "loss": 2.5765398740768433, "time": 0.021819329261779784, "epoch": 1, "step": 40}
+{"lr": 0.1, "data_time": 0.00023169517517089843, "loss": 2.671005058288574, "time": 0.02181088924407959, "epoch": 1, "step": 50}
+{"lr": 0.1, "data_time": 0.00021798610687255858, "loss": 2.5273321866989136, "time": 0.021781444549560547, "epoch": 1, "step": 60}
+{"accuracy/top1": 18.80000114440918, "step": 1}
+{"lr": 0.1, "data_time": 0.0007575273513793946, "loss": 2.3254310727119445, "time": 0.02237672805786133, "epoch": 2, "step": 73}
+{"lr": 0.1, "data_time": 0.0002459049224853516, "loss": 2.194095492362976, "time": 0.021792054176330566, "epoch": 2, "step": 83}
+{"lr": 0.1, "data_time": 0.00027666091918945315, "loss": 2.207821953296661, "time": 0.021822547912597655, "epoch": 2, "step": 93}
+{"lr": 0.1, "data_time": 0.00025298595428466795, "loss": 2.090667963027954, "time": 0.02178535461425781, "epoch": 2, "step": 103}
+{"lr": 0.1, "data_time": 0.0002483367919921875, "loss": 2.18342148065567, "time": 0.021893739700317383, "epoch": 2, "step": 113}
+{"lr": 0.1, "data_time": 0.00030078887939453123, "loss": 2.2274346113204957, "time": 0.022345948219299316, "epoch": 2, "step": 123}
+{"accuracy/top1": 21.100000381469727, "step": 2}
+{"lr": 0.1, "data_time": 0.0008128643035888672, "loss": 2.017984461784363, "time": 0.02267434597015381, "epoch": 3, "step": 136}
+{"lr": 0.1, "data_time": 0.00023736953735351563, "loss": 2.0648953437805178, "time": 0.02174344062805176, "epoch": 3, "step": 146}
+{"lr": 0.1, "data_time": 0.00024063587188720702, "loss": 2.0859395623207093, "time": 0.022107195854187012, "epoch": 3, "step": 156}
+{"lr": 0.1, "data_time": 0.0002336740493774414, "loss": 2.1662048220634462, "time": 0.021825361251831054, "epoch": 3, "step": 166}
+{"lr": 0.1, "data_time": 0.0002296924591064453, "loss": 2.1007142066955566, "time": 0.021821355819702147, "epoch": 3, "step": 176}
+{"lr": 0.1, "data_time": 0.00023157596588134765, "loss": 2.0436240792274476, "time": 0.021722936630249025, "epoch": 3, "step": 186}
+{"accuracy/top1": 25.600000381469727, "step": 3}
tests/test_datasets/test_transforms (likely test_processing.py)

@@ -5,12 +5,17 @@ import random
 from unittest import TestCase
 from unittest.mock import ANY, call, patch

-import albumentations
 import mmengine
 import numpy as np
+import pytest

 from mmpretrain.registry import TRANSFORMS

+try:
+    import albumentations
+except ImportError:
+    albumentations = None
+

 def construct_toy_data():
     img = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],

@@ -666,6 +671,8 @@ class TestAlbumentations(TestCase):
     DEFAULT_ARGS = dict(
         type='Albumentations', transforms=[dict(type='ChannelShuffle', p=1)])

+    @pytest.mark.skipif(
+        albumentations is None, reason='No Albumentations module.')
     def test_assertion(self):
         # Test with non-list transforms
         with self.assertRaises(AssertionError):

@@ -697,6 +704,8 @@ class TestAlbumentations(TestCase):
         cfg['keymap'] = []
         TRANSFORMS.build(cfg)

+    @pytest.mark.skipif(
+        albumentations is None, reason='No Albumentations module.')
     def test_transform(self):
         ori_img = np.random.randint(0, 256, (256, 256, 3), np.uint8)
         results = dict(img=copy.deepcopy(ori_img))

@@ -795,6 +804,8 @@ class TestAlbumentations(TestCase):
         assert min(ablu_result['img'].shape[:2]) == 400
         assert ablu_result['img_shape'] == (400, 600)

+    @pytest.mark.skipif(
+        albumentations is None, reason='No Albumentations module.')
     def test_repr(self):
         cfg = copy.deepcopy(self.DEFAULT_ARGS)
         transform = TRANSFORMS.build(cfg)
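The guarded import plus `pytest.mark.skipif` is the usual pattern for optional test dependencies: the module still imports without the package, and only the dependent tests are skipped. In isolation (test name is illustrative):

    import pytest

    try:
        import albumentations
    except ImportError:
        albumentations = None


    @pytest.mark.skipif(
        albumentations is None, reason='No Albumentations module.')
    def test_albumentations_available():
        # Executed only when the optional package is importable.
        assert albumentations is not None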
tests/test_tools.py (new file; path implied by the MMPRE_ROOT definition below)

@@ -0,0 +1,418 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import re
+import tempfile
+from collections import OrderedDict
+from pathlib import Path
+from subprocess import PIPE, Popen
+from unittest import TestCase
+
+import mmengine
+import torch
+from mmengine.config import Config
+
+from mmpretrain import ModelHub, get_model
+from mmpretrain.structures import DataSample
+
+MMPRE_ROOT = Path(__file__).parent.parent
+ASSETS_ROOT = Path(__file__).parent / 'data'
+
+
+class TestImageDemo(TestCase):
+
+    def setUp(self):
+        self.tmpdir = tempfile.TemporaryDirectory()
+        self.dir = Path(self.tmpdir.name)
+
+    def tearDown(self):
+        self.tmpdir.cleanup()
+
+    def test_run(self):
+        command = [
+            'python',
+            'demo/image_demo.py',
+            'demo/demo.JPEG',
+            'mobilevit-xxsmall_3rdparty_in1k',
+            '--device',
+            'cpu',
+        ]
+        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
+        out, _ = p.communicate()
+        self.assertIn('sea snake', out.decode())
+
+
+class TestAnalyzeLogs(TestCase):
+
+    def setUp(self):
+        self.log_file = ASSETS_ROOT / 'vis_data.json'
+        self.tmpdir = tempfile.TemporaryDirectory()
+        self.out_file = Path(self.tmpdir.name) / 'out.png'
+
+    def tearDown(self):
+        self.tmpdir.cleanup()
+
+    def test_run(self):
+        command = [
+            'python',
+            'tools/analysis_tools/analyze_logs.py',
+            'cal_train_time',
+            self.log_file,
+        ]
+        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
+        out, _ = p.communicate()
+        self.assertIn('slowest epoch 2, average time is 0.0219', out.decode())
+
+        command = [
+            'python',
+            'tools/analysis_tools/analyze_logs.py',
+            'plot_curve',
+            self.log_file,
+            '--keys',
+            'accuracy/top1',
+            '--out',
+            str(self.out_file),
+        ]
+        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
+        out, _ = p.communicate()
+        self.assertIn(str(self.log_file), out.decode())
+        self.assertIn(str(self.out_file), out.decode())
+        self.assertTrue(self.out_file.exists())
+
+
+class TestAnalyzeResults(TestCase):
+
+    def setUp(self):
+        self.tmpdir = tempfile.TemporaryDirectory()
+        self.dir = Path(self.tmpdir.name)
+
+        dataset_cfg = dict(
+            type='CustomDataset',
+            data_root=str(ASSETS_ROOT / 'dataset'),
+        )
+        config = Config(dict(test_dataloader=dict(dataset=dataset_cfg)))
+        self.config_file = self.dir / 'config.py'
+        config.dump(self.config_file)
+
+        results = [{
+            'gt_label': 1,
+            'pred_label': 0,
+            'pred_score': [0.9, 0.1],
+            'sample_idx': 0,
+        }, {
+            'gt_label': 0,
+            'pred_label': 0,
+            'pred_score': [0.9, 0.1],
+            'sample_idx': 1,
+        }]
+        self.result_file = self.dir / 'results.pkl'
+        mmengine.dump(results, self.result_file)
+
+    def tearDown(self):
+        self.tmpdir.cleanup()
+
+    def test_run(self):
+        command = [
+            'python',
+            'tools/analysis_tools/analyze_results.py',
+            self.config_file,
+            self.result_file,
+            '--out-dir',
+            self.tmpdir.name,
+        ]
+        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
+        p.communicate()
+        self.assertTrue((self.dir / 'success/2.jpeg.png').exists())
+        self.assertTrue((self.dir / 'fail/1.JPG.png').exists())
+
+
+class TestPrintConfig(TestCase):
+
+    def setUp(self):
+        self.tmpdir = tempfile.TemporaryDirectory()
+        self.config_file = MMPRE_ROOT / 'configs/resnet/resnet18_8xb32_in1k.py'
+
+    def tearDown(self):
+        self.tmpdir.cleanup()
+
+    def test_run(self):
+        command = [
+            'python',
+            'tools/misc/print_config.py',
+            self.config_file,
+        ]
+        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
+        out, _ = p.communicate()
+        out = out.decode().strip().replace('\r\n', '\n')
+        self.assertEqual(out,
+                         Config.fromfile(self.config_file).pretty_text.strip())
+
+
+class TestVerifyDataset(TestCase):
+
+    def setUp(self):
+        self.tmpdir = tempfile.TemporaryDirectory()
+        self.dir = Path(self.tmpdir.name)
+        dataset_cfg = dict(
+            type='CustomDataset',
+            ann_file=str(self.dir / 'ann.txt'),
+            pipeline=[dict(type='LoadImageFromFile')],
+            data_root=str(ASSETS_ROOT / 'dataset'),
+        )
+        ann_file = '\n'.join(['a/2.JPG 0', 'b/2.jpeg 1', 'b/subb/3.jpg 1'])
+        (self.dir / 'ann.txt').write_text(ann_file)
+        config = Config(dict(train_dataloader=dict(dataset=dataset_cfg)))
+        self.config_file = Path(self.tmpdir.name) / 'config.py'
+        config.dump(self.config_file)
+
+    def tearDown(self):
+        self.tmpdir.cleanup()
+
+    def test_run(self):
+        command = [
+            'python',
+            'tools/misc/verify_dataset.py',
+            self.config_file,
+            '--out-path',
+            self.dir / 'log.log',
+        ]
+        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
+        out, _ = p.communicate()
+        self.assertIn(
+            f"{ASSETS_ROOT/'dataset/a/2.JPG'} cannot be read correctly",
+            out.decode().strip())
+        self.assertTrue((self.dir / 'log.log').exists())
+
+
+class TestEvalMetric(TestCase):
+
+    def setUp(self):
+        self.tmpdir = tempfile.TemporaryDirectory()
+        self.dir = Path(self.tmpdir.name)
+
+        results = [
+            DataSample().set_gt_label(1).set_pred_label(0).set_pred_score(
+                [0.6, 0.3, 0.1]).to_dict(),
+            DataSample().set_gt_label(0).set_pred_label(0).set_pred_score(
+                [0.6, 0.3, 0.1]).to_dict(),
+        ]
+        self.result_file = self.dir / 'results.pkl'
+        mmengine.dump(results, self.result_file)
+
+    def tearDown(self):
+        self.tmpdir.cleanup()
+
+    def test_run(self):
+        command = [
+            'python',
+            'tools/analysis_tools/eval_metric.py',
+            self.result_file,
+            '--metric',
+            'type=Accuracy',
+            'topk=1,2',
+        ]
+        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
+        out, _ = p.communicate()
+        self.assertIn('accuracy/top1', out.decode())
+        self.assertIn('accuracy/top2', out.decode())
+
+
+class TestVisScheduler(TestCase):
+
+    def setUp(self):
+        self.tmpdir = tempfile.TemporaryDirectory()
+        self.dir = Path(self.tmpdir.name)
+
+        config = Config.fromfile(MMPRE_ROOT /
+                                 'configs/resnet/resnet18_8xb32_in1k.py')
+        config.param_scheduler = [
+            dict(
+                type='LinearLR',
+                start_factor=0.01,
+                by_epoch=True,
+                end=1,
+                convert_to_iter_based=True),
+            dict(type='CosineAnnealingLR', by_epoch=True, begin=1),
+        ]
+        config.work_dir = str(self.dir)
+        config.train_cfg.max_epochs = 2
+        self.config_file = Path(self.tmpdir.name) / 'config.py'
+        config.dump(self.config_file)
+
+    def tearDown(self):
+        self.tmpdir.cleanup()
+
+    def test_run(self):
+        command = [
+            'python',
+            'tools/visualization/vis_scheduler.py',
+            self.config_file,
+            '--dataset-size',
+            '100',
+            '--not-show',
+            '--save-path',
+            str(self.dir / 'out.png'),
+        ]
+        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
+        p.communicate()
+        self.assertTrue((self.dir / 'out.png').exists())
+
+
+class TestPublishModel(TestCase):
+
+    def setUp(self):
+        self.tmpdir = tempfile.TemporaryDirectory()
+        self.dir = Path(self.tmpdir.name)
+
+        ckpt = dict(
+            state_dict=OrderedDict({
+                'a': torch.tensor(1.),
+            }),
+            ema_state_dict=OrderedDict({
+                'step': 1,
+                'module.a': torch.tensor(2.),
+            }))
+        self.ckpt_file = self.dir / 'ckpt.pth'
+        torch.save(ckpt, self.ckpt_file)
+
+    def tearDown(self):
+        self.tmpdir.cleanup()
+
+    def test_run(self):
+        command = [
+            'python',
+            'tools/model_converters/publish_model.py',
+            self.ckpt_file,
+            self.ckpt_file,
+            '--dataset-type',
+            'ImageNet',
+            '--no-ema',
+        ]
+        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
+        out, _ = p.communicate()
+        self.assertIn('and drop the EMA weights.', out.decode())
+        self.assertIn('Successfully generated', out.decode())
+        output_ckpt = re.findall(r'ckpt_\d{8}-\w{8}.pth', out.decode())
+        self.assertGreater(len(output_ckpt), 0)
+        output_ckpt = output_ckpt[0]
+        self.assertTrue((self.dir / output_ckpt).exists())
+        # The input file won't be overridden.
+        self.assertTrue(self.ckpt_file.exists())
+
+
+class TestVisCam(TestCase):
+
+    def setUp(self):
+        self.tmpdir = tempfile.TemporaryDirectory()
+        self.dir = Path(self.tmpdir.name)
+
+        model = get_model('mobilevit-xxsmall_3rdparty_in1k')
+        self.config_file = self.dir / 'config.py'
+        model.config.dump(self.config_file)
+
+        self.ckpt_file = self.dir / 'ckpt.pth'
+        torch.save(model.state_dict(), self.ckpt_file)
+
+    def tearDown(self):
+        self.tmpdir.cleanup()
+
+    def test_run(self):
+        command = [
+            'python',
+            'tools/visualization/vis_cam.py',
+            ASSETS_ROOT / 'color.jpg',
+            self.config_file,
+            self.ckpt_file,
+            '--save-path',
+            self.dir / 'cam.jpg',
+        ]
+        p = Popen(command, cwd=MMPRE_ROOT, stdout=PIPE)
+        out, _ = p.communicate()
+        self.assertIn('backbone.conv_1x1_exp.bn', out.decode())
+        self.assertTrue((self.dir / 'cam.jpg').exists())
+
+
+class TestConfusionMatrix(TestCase):
+
+    def setUp(self):
+        self.tmpdir = tempfile.TemporaryDirectory()
+        self.dir = Path(self.tmpdir.name)
+
+        self.config_file = MMPRE_ROOT / 'configs/resnet/resnet18_8xb32_in1k.py'
+
+        results = [
+            DataSample().set_gt_label(1).set_pred_label(0).set_pred_score(
+                [0.6, 0.3, 0.1]).to_dict(),
+            DataSample().set_gt_label(0).set_pred_label(0).set_pred_score(
+                [0.6, 0.3, 0.1]).to_dict(),
+        ]
+        self.result_file = self.dir / 'results.pkl'
+        mmengine.dump(results, self.result_file)
+
+    def tearDown(self):
+        self.tmpdir.cleanup()
+
+    def test_run(self):
+        command = [
+            'python',
+            'tools/analysis_tools/confusion_matrix.py',
+            self.config_file,
+            self.result_file,
+            '--out',
+            self.dir / 'result.pkl',
+        ]
+        Popen(command, cwd=MMPRE_ROOT, stdout=PIPE).wait()
+        result = mmengine.load(self.dir / 'result.pkl')
+        torch.testing.assert_allclose(
+            result, torch.tensor([
+                [1, 0, 0],
+                [1, 0, 0],
+                [0, 0, 0],
+            ]))
+
+
+class TestVisTsne(TestCase):
+
+    def setUp(self):
+        self.tmpdir = tempfile.TemporaryDirectory()
+        self.dir = Path(self.tmpdir.name)
+
+        config = ModelHub.get('mobilevit-xxsmall_3rdparty_in1k').config
+        test_dataloader = dict(
+            batch_size=1,
+            dataset=dict(
+                type='CustomDataset',
+                data_root=str(ASSETS_ROOT / 'dataset'),
+                pipeline=config.test_dataloader.dataset.pipeline,
+            ),
+            sampler=dict(type='DefaultSampler', shuffle=False),
+        )
+        config.test_dataloader = mmengine.ConfigDict(test_dataloader)
+        self.config_file = self.dir / 'config.py'
+        config.dump(self.config_file)
+
+    def tearDown(self):
+        self.tmpdir.cleanup()
+
+    def test_run(self):
+        command = [
+            'python',
+            'tools/visualization/vis_tsne.py',
+            self.config_file,
+            '--work-dir',
+            self.dir,
+            '--perplexity',
+            '2',
+        ]
+        Popen(command, cwd=MMPRE_ROOT, stdout=PIPE).wait()
+        self.assertTrue(len(list(self.dir.glob('tsne_*/feat_*.png'))) > 0)
+
+
+class TestGetFlops(TestCase):
+
+    def test_run(self):
+        command = [
+            'python',
+            'tools/analysis_tools/get_flops.py',
+            'mobilevit-xxsmall_3rdparty_in1k',
+        ]
+        ret_code = Popen(command, cwd=MMPRE_ROOT).wait()
+        self.assertEqual(ret_code, 0)
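Each test above shells out to the real CLI entry point with `Popen` and asserts on stdout or on the files it writes, so the tools are exercised end to end rather than through mocks. Assuming the file lands at tests/test_tools.py (implied by its `MMPRE_ROOT = Path(__file__).parent.parent`), single classes can be run with, e.g.:

    pytest tests/test_tools.py -k TestAnalyzeLogs -s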
tools/analysis_tools/analyze_logs.py

@@ -105,16 +105,12 @@ def plot_curve_helper(log_dicts, metrics, args, legend):

 def plot_curve(log_dicts, args):
     """Plot train metric-iter graph."""
-    # set backend and style
-    if args.backend is not None:
-        plt.switch_backend(args.backend)
+    # set style
     try:
         import seaborn as sns
         sns.set_style(args.style)
     except ImportError:
-        print("Attention: The plot style won't be applied because 'seaborn' "
-              'package is not installed, please install it if you want better '
-              'show style.')
+        pass

     # set plot window size
     wind_w, wind_h = args.window_size.split('*')

@@ -161,9 +157,10 @@ def add_plot_parser(subparsers):
         default=None,
         help='legend of each plot')
     parser_plt.add_argument(
-        '--backend', type=str, default=None, help='backend of plt')
-    parser_plt.add_argument(
-        '--style', type=str, default='whitegrid', help='style of plt')
+        '--style',
+        type=str,
+        default='whitegrid',
+        help='style of the figure, need `seaborn` package.')
     parser_plt.add_argument('--out', type=str, default=None)
     parser_plt.add_argument(
         '--window-size',
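After this change `plot_curve` no longer accepts `--backend`, and the seaborn style is applied only when the package is importable. Typical invocations, mirroring the new TestAnalyzeLogs test (file names are placeholders):

    python tools/analysis_tools/analyze_logs.py cal_train_time vis_data.json
    python tools/analysis_tools/analyze_logs.py plot_curve vis_data.json \
        --keys accuracy/top1 --style whitegrid --out out.png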
tools/analysis_tools/analyze_results.py

@@ -9,7 +9,7 @@ import torch
 from mmengine import DictAction

 from mmpretrain.datasets import build_dataset
-from mmpretrain.structures import ClsDataSample
+from mmpretrain.structures import DataSample
 from mmpretrain.visualization import UniversalVisualizer

@@ -18,7 +18,8 @@ def parse_args():
         description='MMCls evaluate prediction success/fail')
     parser.add_argument('config', help='test config file path')
     parser.add_argument('result', help='test result json/pkl file')
-    parser.add_argument('--out-dir', help='dir to store output files')
+    parser.add_argument(
+        '--out-dir', required=True, help='dir to store output files')
     parser.add_argument(
         '--topk',
         default=20,

@@ -51,15 +52,12 @@ def save_imgs(result_dir, folder_name, results, dataset, rescale_factor=None):
     vis.dataset_meta = {'classes': dataset.CLASSES}

     # save imgs
-    for result in results:
-        data_sample = ClsDataSample()\
-            .set_gt_label(result['gt_label'])\
-            .set_pred_label(result['pred_label'])\
-            .set_pred_score(result['pred_scores'])
-        data_info = dataset.get_data_info(result['sample_idx'])
+    dump_infos = []
+    for data_sample in results:
+        data_info = dataset.get_data_info(data_sample.sample_idx)
         if 'img' in data_info:
             img = data_info['img']
-            name = str(result['sample_idx'])
+            name = str(data_sample.sample_idx)
         elif 'img_path' in data_info:
             img = mmcv.imread(data_info['img_path'], channel_order='rgb')
             name = Path(data_info['img_path']).name

@@ -70,19 +68,20 @@ def save_imgs(result_dir, folder_name, results, dataset, rescale_factor=None):
         vis.visualize_cls(
             img, data_sample, out_file=osp.join(full_dir, name + '.png'))

-        for k, v in result.items():
+        dump = dict()
+        for k, v in data_sample.items():
             if isinstance(v, torch.Tensor):
-                result[k] = v.tolist()
+                dump[k] = v.tolist()
+            else:
+                dump[k] = v
+        dump_infos.append(dump)

-    mmengine.dump(results, osp.join(full_dir, folder_name + '.json'))
+    mmengine.dump(dump_infos, osp.join(full_dir, folder_name + '.json'))


 def main():
     args = parse_args()

-    # load test results
-    outputs = mmengine.load(args.result)
-
     cfg = mmengine.Config.fromfile(args.config)
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)

@@ -91,27 +90,25 @@ def main():
     cfg.test_dataloader.dataset.pipeline = []
     dataset = build_dataset(cfg.test_dataloader.dataset)

-    outputs_list = list()
-    for i in range(len(outputs)):
-        output = dict()
-        output['sample_idx'] = outputs[i]['sample_idx']
-        output['gt_label'] = outputs[i]['gt_label']['label']
-        output['pred_score'] = float(
-            torch.max(outputs[i]['pred_label']['score']).item())
-        output['pred_scores'] = outputs[i]['pred_label']['score']
-        output['pred_label'] = outputs[i]['pred_label']['label']
-        outputs_list.append(output)
+    results = list()
+    for result in mmengine.load(args.result):
+        data_sample = DataSample()
+        data_sample.set_metainfo({'sample_idx': result['sample_idx']})
+        data_sample.set_gt_label(result['gt_label'])
+        data_sample.set_pred_label(result['pred_label'])
+        data_sample.set_pred_score(result['pred_score'])
+        results.append(data_sample)

     # sort result
-    outputs_list = sorted(outputs_list, key=lambda x: x['pred_score'])
+    results = sorted(results, key=lambda x: torch.max(x.pred_score))

     success = list()
     fail = list()
-    for output in outputs_list:
-        if output['pred_label'] == output['gt_label']:
-            success.append(output)
+    for data_sample in results:
+        if (data_sample.pred_label == data_sample.gt_label).all():
+            success.append(data_sample)
         else:
-            fail.append(output)
+            fail.append(data_sample)

     success = success[:args.topk]
     fail = fail[:args.topk]
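The refactor swaps ad-hoc result dicts for `DataSample` objects throughout. A toy sketch of the objects the script now builds, and the two expressions it sorts and splits on (values are made up):

    import torch
    from mmpretrain.structures import DataSample

    sample = DataSample()
    sample.set_metainfo({'sample_idx': 0})
    sample.set_gt_label(1)
    sample.set_pred_label(0)
    sample.set_pred_score([0.9, 0.1])

    # Sort key used above: the maximum predicted score.
    print(torch.max(sample.pred_score))
    # Success/fail split: label tensors compared element-wise.
    print((sample.pred_label == sample.gt_label).all())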
tools/analysis_tools/eval_metric.py

@@ -29,7 +29,7 @@ def parse_args():
         dest='metric_options',
         help='The metric config, the key-value pair in xxx=yyy format will be '
         'parsed as the metric config items. You can specify multiple metrics '
-        'by use multiple `--metric-options`. For list type value, you can use '
+        'by use multiple `--metric`. For list type value, you can use '
         '"key=[a,b]" or "key=a,b", and it also allows nested list/tuple '
         'values, e.g. "key=[(a,b),(c,d)]".')
     args = parser.parse_args()
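Usage matching the renamed option, as exercised by the new TestEvalMetric (the result file is a placeholder):

    python tools/analysis_tools/eval_metric.py results.pkl \
        --metric type=Accuracy topk=1,2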
tools/analysis_tools/get_flops.py

@@ -1,18 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse

-import torch
-try:
-    from fvcore.nn import (ActivationCountAnalysis, FlopCountAnalysis,
-                           flop_count_str, flop_count_table, parameter_count)
-except ImportError:
-    print('You may need to install fvcore for flops computation, '
-          'and you can use `pip install fvcore` to set up the environment')
-from fvcore.nn.print_model_statistics import _format_size
-from mmengine import Config
-
-from mmpretrain.models import build_classifier
+from mmengine.analysis import get_model_complexity_info
+
+from mmpretrain import get_model


 def parse_args():

@@ -29,9 +20,7 @@ def parse_args():

 def main():
     args = parse_args()

     if len(args.shape) == 1:
         input_shape = (3, args.shape[0], args.shape[0])
     elif len(args.shape) == 2:

@@ -39,39 +28,30 @@ def main():
     else:
         raise ValueError('invalid input shape')

-    cfg = Config.fromfile(args.config)
-    model = build_classifier(cfg.model)
+    model = get_model(args.config)
     model.eval()

     if hasattr(model, 'extract_feat'):
         model.forward = model.extract_feat
     else:
         raise NotImplementedError(
             'FLOPs counter is currently not currently supported with {}'.
             format(model.__class__.__name__))

-    inputs = (torch.randn((1, *input_shape)), )
-    flops_ = FlopCountAnalysis(model, inputs)
-    activations_ = ActivationCountAnalysis(model, inputs)
-
-    flops = _format_size(flops_.total())
-    activations = _format_size(activations_.total())
-    params = _format_size(parameter_count(model)[''])
-
-    flop_table = flop_count_table(
-        flops=flops_,
-        activations=activations_,
-        show_param_shapes=True,
+    analysis_results = get_model_complexity_info(
+        model,
+        input_shape,
     )
-    flop_str = flop_count_str(flops=flops_, activations=activations_)
-
-    print('\n' + flop_str)
-    print('\n' + flop_table)
+    flops = analysis_results['flops_str']
+    params = analysis_results['params_str']
+    activations = analysis_results['activations_str']
+    out_table = analysis_results['out_table']
+    out_arch = analysis_results['out_arch']
+    print(out_table)
+    print(out_arch)
     split_line = '=' * 30
     print(f'{split_line}\nInput shape: {input_shape}\n'
           f'Flops: {flops}\nParams: {params}\n'
           f'Activation: {activations}\n{split_line}')
+    print('!!!Only the backbone network is counted in FLOPs analysis.')
     print('!!!Please be cautious if you use the results in papers. '
           'You may need to check if all ops are supported and verify that the '
           'flops computation is correct.')
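`mmengine.analysis.get_model_complexity_info` now does the counting and formatting that fvcore did before; the diff itself shows it returning `flops_str`, `params_str`, `activations_str`, `out_table` and `out_arch`. A minimal sketch of the same call outside the tool (model name reused from the new tests; random-init weights are fine for FLOPs counting):

    from mmengine.analysis import get_model_complexity_info
    from mmpretrain import get_model

    model = get_model('mobilevit-xxsmall_3rdparty_in1k')
    model.forward = model.extract_feat  # count feature extraction, as the tool does
    model.eval()

    results = get_model_complexity_info(model, input_shape=(3, 224, 224))
    print(results['flops_str'], results['params_str'])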
tools/kfold-cross-valid.py (path inferred)

@@ -166,7 +166,7 @@ def train_single_fold(cfg, num_splits, fold, resume_ckpt=None):
         dataset=dataset,
         fold=fold,
         num_splits=num_splits,
-        seed=cfg.seed,
+        seed=cfg.kfold_split_seed,
         test_mode=test_mode,
     )

@@ -203,13 +203,13 @@ def train_single_fold(cfg, num_splits, fold, resume_ckpt=None):
     class SaveInfoHook(Hook):

         def after_train_epoch(self, runner):
-            try:
-                last_ckpt = find_latest_checkpoint(cfg.work_dir)
-                exp_info = dict(
-                    fold=fold, last_ckpt=last_ckpt, seed=runner.seed)
-                dump(exp_info, osp.join(root_dir, EXP_INFO_FILE))
-            except OSError:
-                pass
+            last_ckpt = find_latest_checkpoint(cfg.work_dir)
+            exp_info = dict(
+                fold=fold,
+                last_ckpt=last_ckpt,
+                kfold_split_seed=cfg.kfold_split_seed,
+            )
+            dump(exp_info, osp.join(root_dir, EXP_INFO_FILE))

     runner.register_hook(SaveInfoHook(), 'LOWEST')

@@ -226,17 +226,14 @@ def main():
     # merge cli arguments to config
     cfg = merge_args(cfg, args)

-    # set preprocess configs to model
-    cfg.model.setdefault('data_preprocessor', cfg.get('data_preprocessor', {}))
-
     # set the unify random seed
-    cfg.seed = args.seed or sync_random_seed()
+    cfg.kfold_split_seed = args.seed or sync_random_seed()

     # resume from the previous experiment
     if args.resume:
         experiment_info = load(osp.join(cfg.work_dir, EXP_INFO_FILE))
         resume_fold = experiment_info['fold']
-        cfg.seed = experiment_info['seed']
+        cfg.kfold_split_seed = experiment_info['kfold_split_seed']
         resume_ckpt = experiment_info.get('last_ckpt', None)
     else:
         resume_fold = 0
tools/misc/verify_dataset.py

@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
-import fcntl
+import multiprocessing
 import os
 from pathlib import Path

@@ -10,6 +10,8 @@ from mmengine import (Config, DictAction, track_parallel_progress,
 from mmpretrain.datasets import build_dataset
 from mmpretrain.registry import TRANSFORMS

+file_lock = multiprocessing.Lock()
+

 def parse_args():
     parser = argparse.ArgumentParser(description='Verify Dataset')

@@ -66,12 +68,11 @@ class DatasetValidator():
         except Exception:
             with open(self.log_file_path, 'a') as f:
                 # add file lock to prevent multi-process writing errors
-                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
-                filepath = os.path.join(item['img_prefix'],
-                                        item['img_info']['filename'])
+                filepath = str(Path(item['img_path']))
+                file_lock.acquire()
                 f.write(filepath + '\n')
+                file_lock.release()
                 print(f'{filepath} cannot be read correctly, please check it.')
-                # Release files lock automatic using with

     def __len__(self):
         return len(self.dataset)

@@ -81,12 +82,12 @@ def print_info(log_file_path):
     """print some information and do extra action."""
     print()
     with open(log_file_path, 'r') as f:
-        context = f.read().strip()
-    if context == '':
+        content = f.read().strip()
+    if content == '':
         print('There is no broken file found.')
         os.remove(log_file_path)
     else:
-        num_file = len(context.split('\n'))
+        num_file = len(content.split('\n'))
         print(f'{num_file} broken files found, name list save in file:'
               f'{log_file_path}')
     print()
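The `fcntl` module exists only on Unix, so replacing the `fcntl.flock` call with a `multiprocessing.Lock` is what makes this tool and its new unit test runnable on Windows (the "Fix Windows UT" part of the commit message). The portable pattern in isolation (names are illustrative):

    import multiprocessing

    file_lock = multiprocessing.Lock()

    def append_line(path, line):
        # Serialize writers that share this lock, with no Unix-only APIs.
        with file_lock:
            with open(path, 'a') as f:
                f.write(line + '\n')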
tools/model_converters/publish_model.py

@@ -1,11 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
 import datetime
-import subprocess
+import hashlib
+import shutil
+import warnings
+from collections import OrderedDict
 from pathlib import Path

 import torch
-from mmcv import digit_version
+
+import mmpretrain


 def parse_args():

@@ -13,23 +17,72 @@ def parse_args():
         description='Process a checkpoint to be published')
     parser.add_argument('in_file', help='input checkpoint filename')
     parser.add_argument('out_file', help='output checkpoint filename')
+    parser.add_argument(
+        '--no-ema',
+        action='store_true',
+        help='Use keys in `ema_state_dict` (no-ema keys).')
+    parser.add_argument(
+        '--dataset-type',
+        type=str,
+        help='The type of the dataset. If the checkpoint is converted '
+        'from other repository, this option is used to fill the dataset '
+        'meta information to the published checkpoint, like "ImageNet", '
+        '"CIFAR10" and others.')
     args = parser.parse_args()
     return args


-def process_checkpoint(in_file, out_file):
+def process_checkpoint(in_file, out_file, args):
     checkpoint = torch.load(in_file, map_location='cpu')
-    # remove optimizer for smaller file size
-    if 'optimizer' in checkpoint:
-        del checkpoint['optimizer']
-    # if it is necessary to remove some sensitive data in checkpoint['meta'],
-    # add the code here.
-    if digit_version(torch.__version__) >= digit_version('1.6'):
-        torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
-    else:
-        torch.save(checkpoint, out_file)
+    # remove unnecessary fields for smaller file size
+    for key in ['optimizer', 'param_schedulers', 'hook_msgs', 'message_hub']:
+        checkpoint.pop(key, None)

-    sha = subprocess.check_output(['sha256sum', out_file]).decode()
+    # For checkpoint converted from the official weight
+    if 'state_dict' not in checkpoint:
+        checkpoint = dict(state_dict=checkpoint)
+
+    meta = checkpoint.get('meta', {})
+    meta.setdefault('mmpretrain_version', mmpretrain.__version__)
+
+    # handle dataset meta information
+    if args.dataset_type is not None:
+        from mmpretrain.registry import DATASETS
+        dataset_class = DATASETS.get(args.dataset_type)
+        dataset_meta = getattr(dataset_class, 'METAINFO', {})
+    else:
+        dataset_meta = {}
+
+    meta.setdefault('dataset_meta', dataset_meta)
+
+    if len(meta['dataset_meta']) == 0:
+        warnings.warn('Missing dataset meta information.')
+
+    checkpoint['meta'] = meta
+
+    ema_state_dict = OrderedDict()
+    if 'ema_state_dict' in checkpoint:
+        for k, v in checkpoint['ema_state_dict'].items():
+            # The ema static dict has some extra fields
+            if k.startswith('module.'):
+                origin_k = k[len('module.'):]
+                assert origin_k in checkpoint['state_dict']
+                ema_state_dict[origin_k] = v
+        del checkpoint['ema_state_dict']
+        print('The input checkpoint has EMA weights, ', end='')
+        if args.no_ema:
+            # The values stored in `ema_state_dict` is original values.
+            print('and drop the EMA weights.')
+            assert ema_state_dict.keys() <= checkpoint['state_dict'].keys()
+            checkpoint['state_dict'].update(ema_state_dict)
+        else:
+            print('and use the EMA weights.')
+
+    temp_out_file = Path(out_file).with_name('temp_' + Path(out_file).name)
+    torch.save(checkpoint, temp_out_file)
+
+    with open(temp_out_file, 'rb') as f:
+        sha = hashlib.sha256(f.read()).hexdigest()[:8]
     if out_file.endswith('.pth'):
         out_file_name = out_file[:-4]
     else:

@@ -37,7 +90,7 @@ def process_checkpoint(in_file, out_file):

     current_date = datetime.datetime.now().strftime('%Y%m%d')
     final_file = out_file_name + f'_{current_date}-{sha[:8]}.pth'
-    subprocess.Popen(['mv', out_file, final_file])
+    shutil.move(temp_out_file, final_file)

     print(f'Successfully generated the publish-ckpt as {final_file}.')

@@ -48,7 +101,7 @@ def main():
     if not out_dir.exists():
         raise ValueError(f'Directory {out_dir} does not exist, '
                          'please generate it manually.')
-    process_checkpoint(args.in_file, args.out_file)
+    process_checkpoint(args.in_file, args.out_file, args)


 if __name__ == '__main__':
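Swapping the `sha256sum` and `mv` subprocesses for `hashlib` and `shutil.move` removes another Unix-only dependency, and saving to a temporary name first means the input file is never overwritten. Usage as exercised by TestPublishModel, with the resulting naming rule (the date and hash below are illustrative):

    python tools/model_converters/publish_model.py ckpt.pth out.pth \
        --dataset-type ImageNet --no-ema
    # -> out_20230301-1a2b3c4d.pth  (YYYYMMDD + first 8 hex chars of SHA-256)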
|
@@ -3,24 +3,24 @@ import argparse
 import copy
 import math
 import pkg_resources
-import re
+from functools import partial
 from pathlib import Path

 import mmcv
 import numpy as np
+import torch.nn as nn
 from mmcv.transforms import Compose
 from mmengine.config import Config, DictAction
 from mmengine.dataset import default_collate
-from mmengine.registry import init_default_scope
 from mmengine.utils import to_2tuple
-from torch.nn import BatchNorm1d, BatchNorm2d, GroupNorm, LayerNorm
+from mmengine.utils.dl_utils import is_norm

 from mmpretrain import digit_version
-from mmpretrain.apis import init_model
+from mmpretrain.apis import get_model
+from mmpretrain.registry import TRANSFORMS

 try:
-    from pytorch_grad_cam import (EigenCAM, EigenGradCAM, GradCAM,
-                                  GradCAMPlusPlus, LayerCAM, XGradCAM)
+    import pytorch_grad_cam as cam
     from pytorch_grad_cam.activations_and_gradients import \
         ActivationsAndGradients
     from pytorch_grad_cam.utils.image import show_cam_on_image
@@ -28,15 +28,14 @@ except ImportError:
     raise ImportError('Please run `pip install "grad-cam>=1.3.6"` to install '
                       '3rd party package pytorch_grad_cam.')

-# Supported grad-cam type map
+# Alias name
 METHOD_MAP = {
-    'gradcam': GradCAM,
-    'gradcam++': GradCAMPlusPlus,
-    'xgradcam': XGradCAM,
-    'eigencam': EigenCAM,
-    'eigengradcam': EigenGradCAM,
-    'layercam': LayerCAM,
+    'gradcam++': cam.GradCAMPlusPlus,
 }
+METHOD_MAP.update({
+    cam_class.__name__.lower(): cam_class
+    for cam_class in cam.base_cam.BaseCAM.__subclasses__()
+})


 def parse_args():
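Deriving `METHOD_MAP` from `BaseCAM.__subclasses__()` keeps the alias table in sync with whatever CAM variants the installed grad-cam release defines, at the cost of only seeing direct subclasses. A toy version of the pattern that runs without grad-cam installed; the class names here are stand-ins:

class BaseCAM:
    """Stand-in for pytorch_grad_cam.base_cam.BaseCAM."""


class GradCAM(BaseCAM):
    pass


class XGradCAM(BaseCAM):
    pass


# Register every direct subclass under its lowercased class name, the same
# trick the diff uses; new subclasses are picked up with no manual edits.
METHOD_MAP = {
    cam_class.__name__.lower(): cam_class
    for cam_class in BaseCAM.__subclasses__()
}
print(METHOD_MAP)  # {'gradcam': <class 'GradCAM'>, 'xgradcam': <class 'XGradCAM'>}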
@@ -113,33 +112,21 @@ def parse_args():
     return args


-def build_reshape_transform(model, args):
+def reshape_transform(tensor, model, args):
     """Build reshape_transform for `cam.activations_and_grads`, which is
     necessary for ViT-like networks."""
     # ViT_based_Transformers have an additional clstoken in features
-    if not args.vit_like:
-
-        def check_shape(tensor):
-            assert len(tensor.size()) != 3, \
-                (f"The input feature's shape is {tensor.size()}, and it seems "
-                 'to have been flattened or from a vit-like network. '
-                 "Please use `--vit-like` if it's from a vit-like network.")
-            return tensor
-
-        return check_shape
-
-    if args.num_extra_tokens is not None:
-        num_extra_tokens = args.num_extra_tokens
-    elif hasattr(model.backbone, 'num_extra_tokens'):
-        num_extra_tokens = model.backbone.num_extra_tokens
-    else:
-        num_extra_tokens = 1
-
-    def _reshape_transform(tensor):
-        """reshape_transform helper."""
-        assert len(tensor.size()) == 3, \
-            (f"The input feature's shape is {tensor.size()}, "
-            'and the feature seems not from a vit-like network?')
+    if tensor.ndim == 4:
+        # For (B, C, H, W)
+        return tensor
+    elif tensor.ndim == 3:
+        if not args.vit_like:
+            raise ValueError(f"The tensor shape is {tensor.shape}, if it's a "
+                             'vit-like backbone, please specify `--vit-like`.')
+        # For (B, L, C)
+        num_extra_tokens = args.num_extra_tokens or getattr(
+            model.backbone, 'num_extra_tokens', 1)
         tensor = tensor[:, num_extra_tokens:, :]
         # get heat_map_height and heat_map_width, preset input is a square
         heat_map_area = tensor.size()[1]
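The new (B, L, C) branch assumes that, once the extra tokens are dropped, the remaining patch tokens fill a square grid. For ViT-B/16 on a 224x224 input that is 197 - 1 = 196 = 14 x 14 tokens, which is exactly what the perfect-square check in the next hunk verifies. A sketch of the same reshape on a toy tensor; the shapes are illustrative:

import math

import torch

B, L, C = 2, 197, 768          # e.g. ViT-B/16, 224x224 input, 1 class token
num_extra_tokens = 1
tokens = torch.randn(B, L, C)

patch_tokens = tokens[:, num_extra_tokens:, :]   # drop the class token
side = int(math.sqrt(patch_tokens.size(1)))      # 196 -> 14
assert side * side == patch_tokens.size(1), 'not a square token grid'
feature_map = patch_tokens.reshape(B, side, side, C).permute(0, 3, 1, 2)
print(feature_map.shape)  # torch.Size([2, 768, 14, 14])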
@@ -149,13 +136,13 @@ def build_reshape_transform(model, args):
             f'minus num-extra-tokens ({num_extra_tokens}) is {heat_map_area},'
             ' which is not a perfect square number. Please check if you used '
             'a wrong num-extra-tokens.')
+        # (B, L, C) -> (B, H, W, C)
         result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
-        # Bring the channels to the first dimension, like in CNNs.
-        result = result.transpose(2, 3).transpose(1, 2)
+        # (B, H, W, C) -> (B, C, H, W)
+        result = result.permute(0, 3, 1, 2)
         return result
-
-    return _reshape_transform
+    else:
+        raise ValueError(f'Unsupported tensor shape {tensor.shape}.')


 def init_cam(method, model, target_layers, use_cuda, reshape_transform):
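`permute(0, 3, 1, 2)` performs the (B, H, W, C) -> (B, C, H, W) move in one call and computes the same result as the chained `transpose` calls it replaces. A quick equivalence check:

import torch

t = torch.randn(2, 14, 14, 768)            # (B, H, W, C)
a = t.permute(0, 3, 1, 2)                  # new single-call form
b = t.transpose(2, 3).transpose(1, 2)      # old chained form
assert torch.equal(a, b) and a.shape == (2, 768, 14, 14)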
@@ -175,37 +162,12 @@ def init_cam(method, model, target_layers, use_cuda, reshape_transform):

 def get_layer(layer_str, model):
     """get model layer from given str."""
-    cur_layer = model
-    layer_names = layer_str.strip().split('.')
-
-    def get_children_by_name(model, name):
-        try:
-            return getattr(model, name)
-        except AttributeError as e:
-            raise AttributeError(
-                e.args[0] +
-                '. Please use `--preview-model` to check keys at first.')
-
-    def get_children_by_eval(model, name):
-        try:
-            return eval(f'model{name}', {}, {'model': model})
-        except (AttributeError, IndexError) as e:
-            raise AttributeError(
-                e.args[0] +
-                '. Please use `--preview-model` to check keys at first.')
-
-    for layer_name in layer_names:
-        match_res = re.match('(?P<name>.+?)(?P<indices>(\\[.+\\])+)',
-                             layer_name)
-        if match_res:
-            layer_name = match_res.groupdict()['name']
-            indices = match_res.groupdict()['indices']
-            cur_layer = get_children_by_name(cur_layer, layer_name)
-            cur_layer = get_children_by_eval(cur_layer, indices)
-        else:
-            cur_layer = get_children_by_name(cur_layer, layer_name)
-
-    return cur_layer
+    for name, layer in model.named_modules():
+        if name == layer_str:
+            return layer
+    raise AttributeError(
+        f'Cannot get the layer "{layer_str}". Please choose from: \n' +
+        '\n'.join(name for name, _ in model.named_modules()))


 def show_cam_grad(grayscale_cam, src_img, title, out_path=None):
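The rewritten `get_layer` relies on `named_modules()` already yielding every submodule under its dotted path, so a plain string comparison replaces the old regex-and-getattr walk. A small demonstration on a toy model; the module names are illustrative:

import torch.nn as nn

model = nn.Sequential()
model.add_module('backbone',
                 nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))


def get_layer(layer_str, model):
    # Dotted names like 'backbone.1' come straight from named_modules().
    for name, layer in model.named_modules():
        if name == layer_str:
            return layer
    raise AttributeError(f'Cannot get the layer "{layer_str}".')


print(get_layer('backbone.1', model))  # BatchNorm2d(8, ...)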
@@ -224,39 +186,32 @@ def show_cam_grad(grayscale_cam, src_img, title, out_path=None):

 def get_default_traget_layers(model, args):
     """get default target layers from given model, here choose nrom type layer
     as default target layer."""
-    norm_layers = []
-    for m in model.backbone.modules():
-        if isinstance(m, (BatchNorm2d, LayerNorm, GroupNorm, BatchNorm1d)):
-            norm_layers.append(m)
-    if len(norm_layers) == 0:
-        raise ValueError(
-            '`--target-layers` is empty. Please use `--preview-model`'
-            ' to check keys at first and then specify `target-layers`.')
-    # if the model is CNN model or Swin model, just use the last norm
-    # layer as the target-layer, if the model is ViT model, the final
-    # classification is done on the class token computed in the last
-    # attention block, the output will not be affected by the 14x14
-    # channels in the last layer. The gradient of the output with
-    # respect to them, will be 0! here use the last 3rd norm layer.
-    # means the first norm of the last decoder block.
+    norm_layers = [
+        (name, layer)
+        for name, layer in model.backbone.named_modules(prefix='backbone')
+        if is_norm(layer)
+    ]
     if args.vit_like:
-        if args.num_extra_tokens:
-            num_extra_tokens = args.num_extra_tokens
-        elif hasattr(model.backbone, 'num_extra_tokens'):
-            num_extra_tokens = model.backbone.num_extra_tokens
-        else:
-            raise AttributeError('Please set num_extra_tokens in backbone'
-                                 " or using 'num-extra-tokens'")
-
-        # if a vit-like backbone's num_extra_tokens bigger than 0, view it
-        # as a VisionTransformer backbone, eg. DeiT, T2T-ViT.
-        if num_extra_tokens >= 1:
+        # For ViT models, the final classification is done on the class token.
+        # And the patch tokens and class tokens won't interact each other after
+        # the final attention layer. Therefore, we need to choose the norm
+        # layer before the last attention layer.
+        num_extra_tokens = args.num_extra_tokens or getattr(
+            model.backbone, 'num_extra_tokens', 1)
+        out_type = getattr(model.backbone, 'out_type')
+        if out_type == 'cls_token' or num_extra_tokens > 0:
+            # Assume the backbone feature is class token.
+            name, layer = norm_layers[-3]
             print('Automatically choose the last norm layer before the '
-                  'final attention block as target_layer..')
-            return [norm_layers[-3]]
-    print('Automatically choose the last norm layer as target_layer.')
-    target_layers = [norm_layers[-1]]
-    return target_layers
+                  f'final attention block "{name}" as the target layer.')
+            return [layer]
+
+    # For CNN models, use the last norm layer as the target-layer
+    name, layer = norm_layers[-1]
+    print('Automatically choose the last norm layer '
+          f'"{name}" as the target layer.')
+    return [layer]


 def main():
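Collecting `(name, layer)` pairs through `named_modules(prefix='backbone')` and mmengine's `is_norm` both shortens the filter and lets the printed messages name the chosen layer. The `norm_layers[-3]` pick for ViT-like backbones presumably corresponds to the norm feeding the last attention block, given the usual norm1 -> attn -> norm2 -> ffn block layout plus a final norm. A sketch of the filtering step on a toy CNN backbone, assuming mmengine is installed:

import torch.nn as nn
from mmengine.utils.dl_utils import is_norm

backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
                         nn.Conv2d(8, 8, 3), nn.BatchNorm2d(8))

# Keep (dotted name, module) pairs for every normalization layer.
norm_layers = [
    (name, layer)
    for name, layer in backbone.named_modules(prefix='backbone')
    if is_norm(layer)
]
print([name for name, _ in norm_layers])  # ['backbone.1', 'backbone.4']
name, layer = norm_layers[-1]             # last norm as the CNN target layer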
@@ -265,16 +220,16 @@ def main():
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)

-    init_default_scope('mmpretrain')
     # build the model from a config file and a checkpoint file
-    model = init_model(cfg, args.checkpoint, device=args.device)
+    model: nn.Module = get_model(cfg, args.checkpoint, device=args.device)
     if args.preview_model:
         print(model)
         print('\n Please remove `--preview-model` to get the CAM.')
         return

     # apply transform and perpare data
-    transforms = Compose(cfg.test_dataloader.dataset.pipeline)
+    transforms = Compose(
+        [TRANSFORMS.build(t) for t in cfg.test_dataloader.dataset.pipeline])
     data = transforms({'img_path': args.img})
     src_img = copy.deepcopy(data['inputs']).numpy().transpose(1, 2, 0)
     data = model.data_preprocessor(default_collate([data]), False)
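Building each pipeline step through the `TRANSFORMS` registry turns the config's plain dicts into transform instances before `Compose` chains them, since `Compose` is no longer handed the raw config list. A rough sketch of what that line does, assuming mmpretrain is installed and the listed transforms are registered; the pipeline dicts and image path are illustrative:

from mmcv.transforms import Compose
from mmpretrain.registry import TRANSFORMS

pipeline_cfg = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=256, edge='short'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackInputs'),
]
# Each dict is resolved to a registered transform class and instantiated.
transforms = Compose([TRANSFORMS.build(t) for t in pipeline_cfg])
data = transforms({'img_path': 'demo/demo.JPEG'})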
@@ -289,9 +244,8 @@ def main():

     # init a cam grad calculator
     use_cuda = ('cuda' in args.device)
-    reshape_transform = build_reshape_transform(model, args)
     cam = init_cam(args.method, model, target_layers, use_cuda,
-                   reshape_transform)
+                   partial(reshape_transform, model=model, args=args))

     # warp the target_category with ClassifierOutputTarget in grad_cam>=1.3.7,
     # to fix the bug in #654.
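`pytorch_grad_cam` expects `reshape_transform` to be a single-argument callable, while the refactored function takes `(tensor, model, args)`; `functools.partial` pins the extra arguments so the signature matches again. A minimal illustration:

from functools import partial


def reshape_transform(tensor, model, args):
    # Stand-in for the tool's function; only the signature matters here.
    return tensor


one_arg_fn = partial(reshape_transform, model=None, args=None)
print(one_arg_fn('feature-map'))  # now behaves like fn(tensor)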
@@ -108,7 +108,10 @@ def parse_args():
         'WARNING.')
     parser.add_argument('--title', type=str, help='title of figure')
     parser.add_argument(
-        '--style', type=str, default='whitegrid', help='style of plt')
+        '--style',
+        type=str,
+        default='whitegrid',
+        help='style of the figure, need `seaborn` package.')
     parser.add_argument('--not-show', default=False, action='store_true')
     parser.add_argument(
         '--window-size',
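The expanded `--style` help makes explicit that the value is a seaborn style name applied before plotting. A sketch of how such a flag is typically consumed, assuming `seaborn` is installed; this is not the tool's exact plotting code:

import argparse

import matplotlib.pyplot as plt
import seaborn as sns

parser = argparse.ArgumentParser()
parser.add_argument('--style', type=str, default='whitegrid',
                    help='style of the figure, need `seaborn` package.')
args = parser.parse_args([])

sns.set_style(args.style)      # e.g. 'whitegrid', 'darkgrid'
plt.plot([0, 1], [0, 1])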