# Source: mmrazor/tests/test_runners/test_quantization_loop.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import shutil
import tempfile
from unittest import TestCase
import torch
import torch.nn as nn
from mmengine.config import Config, ConfigDict
from mmengine.evaluator import BaseMetric
from mmengine.hooks import Hook
from mmengine.logging import MMLogger
from mmengine.model import BaseModel
from mmengine.optim import OptimWrapper
from mmengine.registry import DATASETS, HOOKS, METRICS, MODELS, OPTIM_WRAPPERS
from mmengine.runner import Runner
from torch.nn.intrinsic.qat import ConvBnReLU2d
from torch.utils.data import Dataset
from mmrazor import digit_version
from mmrazor.engine import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop,
QATValLoop)
try:
from torch.ao.nn.quantized import FloatFunctional, FXFloatFunctional
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.fake_quantize import FakeQuantizeBase
from torch.ao.quantization.fx import prepare
from torch.ao.quantization.qconfig_mapping import \
get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import _fuse_fx
except ImportError:
from mmrazor.utils import get_placeholder
QConfigMapping = get_placeholder('torch>=1.13')
FakeQuantizeBase = get_placeholder('torch>=1.13')
prepare = get_placeholder('torch>=1.13')
_fuse_fx = get_placeholder('torch>=1.13')
get_default_qconfig_mapping = get_placeholder('torch>=1.13')
FloatFunctional = get_placeholder('torch>=1.13')
FXFloatFunctional = get_placeholder('torch>=1.13')
class ToyDataset(Dataset):
    """Tiny fixed-size dataset of random images with all-ones labels."""

    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 3, 4, 4)
    label = torch.ones(12)

    @property
    def metainfo(self):
        """Return the (empty) dataset meta information."""
        return self.METAINFO

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = dict(inputs=self.data[index], data_sample=self.label[index])
        return sample
class MMArchitectureQuant(BaseModel):
    """Quantization-aware wrapper around :class:`ToyModel`.

    Exposes the ``calibrate_step`` / ``sync_qparams`` hooks that the
    quantization loops under test expect to find on the model.
    """

    def __init__(self, data_preprocessor=None):
        super().__init__(data_preprocessor=data_preprocessor)
        self.architecture = ToyModel()

    def calibrate_step(self, data):
        """Run one forward pass on preprocessed data (PTQ calibration)."""
        processed = self.data_preprocessor(data, False)
        return self.architecture(**processed)

    def sync_qparams(self, src_mode):
        """No-op; a real implementation would sync qparams between modes."""
        pass

    def forward(self, inputs, data_sample, mode='tensor'):
        """Delegate straight to the wrapped architecture."""
        return self.architecture(inputs, data_sample, mode)
class ToyModel(BaseModel):
    """Toy model with a single fused Conv-BN-ReLU QAT module.

    ``ConvBnReLU2d`` is used so QAT-specific behavior (BN freezing,
    fake-quant observers) is visible to the loops under test.
    """

    def __init__(self, data_preprocessor=None):
        super().__init__(data_preprocessor=data_preprocessor)
        # The global ('') entry of the default qconfig mapping configures
        # fake-quant/observer construction for the fused module.
        qconfig = get_default_qconfig_mapping().to_dict()['']
        self.architecture = nn.Sequential(
            ConvBnReLU2d(3, 3, 1, qconfig=qconfig))

    def forward(self, inputs, data_sample, mode='tensor'):
        """Run the toy architecture.

        Args:
            inputs: Batch tensor, or a list of tensors that gets stacked.
            data_sample: Label tensor, or a list of tensors that gets stacked.
            mode: One of ``'tensor'``, ``'loss'`` or ``'predict'``.

        Returns:
            Raw outputs for ``'tensor'``/``'predict'``, or ``dict(loss=...)``
            for ``'loss'``.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        if isinstance(inputs, list):
            inputs = torch.stack(inputs)
        if isinstance(data_sample, list):
            data_sample = torch.stack(data_sample)
        outputs = self.architecture(inputs)
        if mode == 'tensor':
            return outputs
        elif mode == 'loss':
            # Dummy loss that merely ties the outputs to the labels.
            loss = data_sample.sum() - outputs.sum()
            return dict(loss=loss)
        elif mode == 'predict':
            return outputs
        # Previously an unknown mode fell through and returned None; fail
        # loudly instead so misconfigured loops surface immediately.
        raise ValueError(f'Unsupported forward mode: {mode!r}')
class ToyOptimWrapper(OptimWrapper):
    """Plain ``OptimWrapper`` subclass; exists only to exercise the
    ``OPTIM_WRAPPERS`` registry round-trip, behavior is unchanged."""
    ...
class ToyMetric1(BaseMetric):
    """Dummy metric that always reports ``acc=1``."""

    def __init__(self, collect_device='cpu', dummy_metrics=None):
        super().__init__(collect_device=collect_device)
        self.dummy_metrics = dummy_metrics

    def process(self, data_batch, predictions):
        """Record a constant per-batch result."""
        self.results.append(dict(acc=1))

    def compute_metrics(self, results):
        """Collapse all recorded results into the constant metric."""
        return dict(acc=1)
# Baseline Runner config shared by every TestCase below; each test deep-copies
# it and overrides the loop-specific ``train_cfg``/``val_cfg``/``test_cfg``.
# The 'type' strings refer to the toy classes registered in each setUp().
DEFAULT_CFG = ConfigDict(
    model=dict(type='MMArchitectureQuant'),
    train_dataloader=dict(
        dataset=dict(type='ToyDataset'),
        sampler=dict(type='DefaultSampler', shuffle=True),
        batch_size=3,
        num_workers=0),
    val_dataloader=dict(
        dataset=dict(type='ToyDataset'),
        sampler=dict(type='DefaultSampler', shuffle=False),
        batch_size=3,
        num_workers=0),
    test_dataloader=dict(
        dataset=dict(type='ToyDataset'),
        sampler=dict(type='DefaultSampler', shuffle=False),
        batch_size=3,
        num_workers=0),
    optim_wrapper=dict(
        type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
    val_evaluator=dict(type='ToyMetric1'),
    test_evaluator=dict(type='ToyMetric1'),
    train_cfg=dict(),
    val_cfg=dict(),
    test_cfg=dict(),
    custom_hooks=[],
    data_preprocessor=None,
    # launcher='none' keeps everything single-process, so the nccl dist
    # backend below is configured but never actually initialized.
    launcher='none',
    env_cfg=dict(dist_cfg=dict(backend='nccl')),
)
class TestQATEpochBasedLoop(TestCase):
    """Integration tests for ``mmrazor.QATEpochBasedLoop``."""

    def setUp(self):
        # The torch.ao quantization APIs used by these loops need >= 1.13.
        if digit_version(torch.__version__) < digit_version('1.13.0'):
            self.skipTest('version of torch < 1.13.0')
        self.temp_dir = tempfile.mkdtemp()
        # Register the toy components so Runner can build them from config.
        MODELS.register_module(module=MMArchitectureQuant, force=True)
        DATASETS.register_module(module=ToyDataset, force=True)
        METRICS.register_module(module=ToyMetric1, force=True)
        OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True)
        default_cfg = copy.deepcopy(DEFAULT_CFG)
        default_cfg = Config(default_cfg)
        default_cfg.work_dir = self.temp_dir
        default_cfg.train_cfg = ConfigDict(
            type='mmrazor.QATEpochBasedLoop',
            max_epochs=4,
            val_begin=1,
            val_interval=1,
            disable_observer_begin=-1,  # -1: never disable observers
            freeze_bn_begin=-1,  # -1: never freeze BN statistics
            dynamic_intervals=None)
        self.default_cfg = default_cfg

    def tearDown(self):
        # Undo the registrations so sibling test modules are unaffected.
        MODELS.module_dict.pop('MMArchitectureQuant')
        DATASETS.module_dict.pop('ToyDataset')
        METRICS.module_dict.pop('ToyMetric1')
        OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper')
        logging.shutdown()
        MMLogger._instance_dict.clear()
        shutil.rmtree(self.temp_dir)

    def test_init(self):
        # Runner construction should yield our QAT training loop.
        cfg = copy.deepcopy(self.default_cfg)
        cfg.experiment_name = 'test_init_qat_train_loop'
        runner = Runner(**cfg)
        self.assertIsInstance(runner, Runner)
        self.assertIsInstance(runner.train_loop, QATEpochBasedLoop)

    def test_run_epoch(self):
        # Plain end-to-end training run.
        cfg = copy.deepcopy(self.default_cfg)
        cfg.experiment_name = 'test_train'
        runner = Runner.from_cfg(cfg)
        runner.train()

        # Hook that asserts BN is frozen in fused QAT modules once the
        # configured epoch is reached.
        @HOOKS.register_module(force=True)
        class TestFreezeBNHook(Hook):

            def __init__(self, freeze_bn_begin):
                self.freeze_bn_begin = freeze_bn_begin

            def after_train_epoch(self, runner):
                def check_bn_stats(mod):
                    if isinstance(mod, ConvBnReLU2d):
                        assert mod.freeze_bn
                        assert not mod.bn.training
                if runner.train_loop._epoch + 1 >= self.freeze_bn_begin:
                    runner.model.apply(check_bn_stats)

        cfg = copy.deepcopy(self.default_cfg)
        cfg.experiment_name = 'test_freeze_bn'
        cfg.custom_hooks = [
            dict(type='TestFreezeBNHook', priority=50, freeze_bn_begin=1)
        ]
        cfg.train_cfg.freeze_bn_begin = 1
        runner = Runner.from_cfg(cfg)
        runner.train()

        # Hook that asserts fake-quant is disabled after the configured
        # epoch.
        # NOTE(review): this checks ``fake_quant_enabled`` rather than
        # ``observer_enabled`` despite the 'disable observer' name — verify
        # that is the intended semantics of disable_observer_begin.
        @HOOKS.register_module(force=True)
        class TestDisableObserverHook(Hook):

            def __init__(self, disable_observer_begin):
                self.disable_observer_begin = disable_observer_begin

            def after_train_epoch(self, runner):
                def check_observer_stats(mod):
                    if isinstance(mod, FakeQuantizeBase):
                        assert mod.fake_quant_enabled[0] == 0
                if runner.train_loop._epoch + 1 >= self.disable_observer_begin:
                    runner.model.apply(check_observer_stats)

        cfg = copy.deepcopy(self.default_cfg)
        cfg.experiment_name = 'test_disable_observer'
        cfg.custom_hooks = [
            dict(
                type='TestDisableObserverHook',
                priority=50,
                disable_observer_begin=1)
        ]
        cfg.train_cfg.disable_observer_begin = 1
        runner = Runner.from_cfg(cfg)
        runner.train()
class TestLSQEpochBasedLoop(TestCase):
    """Integration tests for ``mmrazor.LSQEpochBasedLoop``."""

    def setUp(self):
        # LSQ relies on torch.ao quantization APIs introduced in 1.13.
        if digit_version(torch.__version__) < digit_version('1.13.0'):
            self.skipTest('version of torch < 1.13.0')
        self.temp_dir = tempfile.mkdtemp()
        # Make the toy components visible to Runner via the registries.
        for registry, module in ((MODELS, MMArchitectureQuant),
                                 (DATASETS, ToyDataset),
                                 (METRICS, ToyMetric1),
                                 (OPTIM_WRAPPERS, ToyOptimWrapper)):
            registry.register_module(module=module, force=True)
        cfg = Config(copy.deepcopy(DEFAULT_CFG))
        cfg.work_dir = self.temp_dir
        cfg.train_cfg = ConfigDict(
            type='mmrazor.LSQEpochBasedLoop',
            max_epochs=4,
            val_begin=1,
            val_interval=1,
            freeze_bn_begin=-1,
            dynamic_intervals=None)
        self.default_cfg = cfg

    def tearDown(self):
        # Remove toy registrations and clean up loggers / work dir.
        for registry, name in ((MODELS, 'MMArchitectureQuant'),
                               (DATASETS, 'ToyDataset'),
                               (METRICS, 'ToyMetric1'),
                               (OPTIM_WRAPPERS, 'ToyOptimWrapper')):
            registry.module_dict.pop(name)
        logging.shutdown()
        MMLogger._instance_dict.clear()
        shutil.rmtree(self.temp_dir)

    def test_init(self):
        cfg = copy.deepcopy(self.default_cfg)
        cfg.experiment_name = 'test_init_lsq_train_loop'
        runner = Runner(**cfg)
        self.assertIsInstance(runner, Runner)
        self.assertIsInstance(runner.train_loop, LSQEpochBasedLoop)

    def test_run_epoch(self):
        # Plain end-to-end training run.
        cfg = copy.deepcopy(self.default_cfg)
        cfg.experiment_name = 'test_train'
        Runner.from_cfg(cfg).train()

        # Hook asserting that BN is frozen in fused QAT modules once the
        # configured epoch is reached.
        @HOOKS.register_module(force=True)
        class TestFreezeBNHook(Hook):

            def __init__(self, freeze_bn_begin):
                self.freeze_bn_begin = freeze_bn_begin

            def after_train_epoch(self, runner):
                if runner.train_loop._epoch + 1 < self.freeze_bn_begin:
                    return

                def _assert_frozen(mod):
                    if isinstance(mod, ConvBnReLU2d):
                        assert mod.freeze_bn
                        assert not mod.bn.training

                runner.model.apply(_assert_frozen)

        cfg = copy.deepcopy(self.default_cfg)
        cfg.experiment_name = 'test_freeze_bn'
        cfg.custom_hooks = [
            dict(type='TestFreezeBNHook', priority=50, freeze_bn_begin=1)
        ]
        cfg.train_cfg.freeze_bn_begin = 1
        Runner.from_cfg(cfg).train()
class TestQATValLoop(TestCase):
    """Integration tests for ``mmrazor.QATValLoop``."""

    def setUp(self):
        # QAT validation depends on torch.ao APIs introduced in 1.13.
        if digit_version(torch.__version__) < digit_version('1.13.0'):
            self.skipTest('version of torch < 1.13.0')
        self.temp_dir = tempfile.mkdtemp()
        # Make the toy components visible to Runner via the registries.
        for registry, module in ((MODELS, MMArchitectureQuant),
                                 (DATASETS, ToyDataset),
                                 (METRICS, ToyMetric1),
                                 (OPTIM_WRAPPERS, ToyOptimWrapper)):
            registry.register_module(module=module, force=True)
        cfg = Config(copy.deepcopy(DEFAULT_CFG))
        cfg.work_dir = self.temp_dir
        cfg.val_cfg = ConfigDict(type='mmrazor.QATValLoop')
        self.default_cfg = cfg

    def tearDown(self):
        # Remove toy registrations and clean up loggers / work dir.
        for registry, name in ((MODELS, 'MMArchitectureQuant'),
                               (DATASETS, 'ToyDataset'),
                               (METRICS, 'ToyMetric1'),
                               (OPTIM_WRAPPERS, 'ToyOptimWrapper')):
            registry.module_dict.pop(name)
        logging.shutdown()
        MMLogger._instance_dict.clear()
        shutil.rmtree(self.temp_dir)

    def test_init(self):
        cfg = copy.deepcopy(self.default_cfg)
        cfg.experiment_name = 'test_init_qat_val_loop'
        runner = Runner(**cfg)
        self.assertIsInstance(runner, Runner)
        self.assertIsInstance(runner.val_loop, QATValLoop)

    def test_run(self):
        cfg = copy.deepcopy(self.default_cfg)
        cfg.experiment_name = 'test_qat_val'
        # Only validation is exercised: strip every train/test-side entry.
        for key in ('train_dataloader', 'train_cfg', 'optim_wrapper',
                    'test_dataloader', 'test_cfg', 'test_evaluator'):
            cfg.pop(key)
        Runner.from_cfg(cfg).val()
class TestPTQLoop(TestCase):
    """Integration tests for ``mmrazor.PTQLoop``."""

    def setUp(self):
        # PTQ depends on torch.ao APIs introduced in 1.13.
        if digit_version(torch.__version__) < digit_version('1.13.0'):
            self.skipTest('version of torch < 1.13.0')
        self.temp_dir = tempfile.mkdtemp()
        # Make the toy components visible to Runner via the registries.
        for registry, module in ((MODELS, MMArchitectureQuant),
                                 (DATASETS, ToyDataset),
                                 (METRICS, ToyMetric1),
                                 (OPTIM_WRAPPERS, ToyOptimWrapper)):
            registry.register_module(module=module, force=True)
        cfg = Config(copy.deepcopy(DEFAULT_CFG))
        cfg.work_dir = self.temp_dir
        # save_checkpoint in PTQLoop requires a train_dataloader to exist.
        cfg.train_cfg = ConfigDict(by_epoch=True, max_epochs=3)
        cfg.test_cfg = ConfigDict(
            type='mmrazor.PTQLoop',
            calibrate_dataloader=cfg.train_dataloader,
            calibrate_steps=32)
        self.default_cfg = cfg

    def tearDown(self):
        # Remove toy registrations and clean up loggers / work dir.
        for registry, name in ((MODELS, 'MMArchitectureQuant'),
                               (DATASETS, 'ToyDataset'),
                               (METRICS, 'ToyMetric1'),
                               (OPTIM_WRAPPERS, 'ToyOptimWrapper')):
            registry.module_dict.pop(name)
        logging.shutdown()
        MMLogger._instance_dict.clear()
        shutil.rmtree(self.temp_dir)

    def test_init(self):
        cfg = copy.deepcopy(self.default_cfg)
        cfg.experiment_name = 'test_init_ptq_loop'
        runner = Runner(**cfg)
        self.assertIsInstance(runner, Runner)
        self.assertIsInstance(runner.test_loop, PTQLoop)

    def test_run(self):
        cfg = copy.deepcopy(self.default_cfg)
        cfg.experiment_name = 'test_ptq_run'
        Runner.from_cfg(cfg).test()