mmclassification/mmpretrain/apis/utils.py

[Feature] Support multiple multi-modal algorithms and inferencers. (#1561) * [Feat] Migrate blip caption to mmpretrain. (#50) * Migrate blip caption to mmpretrain * minor fix * support train * [Feature] Support OFA caption task. (#51) * [Feature] Support OFA caption task. * Remove duplicated files. * [Feature] Support OFA vqa task. (#58) * [Feature] Support OFA vqa task. * Fix lint. * [Feat] Add BLIP retrieval to mmpretrain. (#55) * init * minor fix for train * fix according to comments * refactor * Update Blip retrieval. (#62) * [Feature] Support OFA visual grounding task. (#59) * [Feature] Support OFA visual grounding task. * minor add TODO --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * [Feat] Add flamingos coco caption and vqa. (#60) * first init * init flamingo coco * add vqa * minor fix * remove unnecessary modules * Update config * Use `ApplyToList`. --------- Co-authored-by: mzr1996 <mzr1996@163.com> * [Feature]: BLIP2 coco retrieval (#53) * [Feature]: Add blip2 retriever * [Feature]: Add blip2 all modules * [Feature]: Refine model * [Feature]: x1 * [Feature]: Runnable coco ret * [Feature]: Runnable version * [Feature]: Fix lint * [Fix]: Fix lint * [Feature]: Use 364 img size * [Feature]: Refactor blip2 * [Fix]: Fix lint * refactor files * minor fix * minor fix --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * Remove * fix blip caption inputs (#68) * [Feat] Add BLIP NLVR support. 
(#67) * first init * init flamingo coco * add vqa * add nlvr * refactor nlvr * minor fix * minor fix * Update dataset --------- Co-authored-by: mzr1996 <mzr1996@163.com> * [Feature]: BLIP2 Caption (#70) * [Feature]: Add language model * [Feature]: blip2 caption forward * [Feature]: Reproduce the results * [Feature]: Refactor caption * refine config --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * [Feat] Migrate BLIP VQA to mmpretrain (#69) * reformat * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * refactor code --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * Update RefCOCO dataset * [Fix] fix lint * [Feature] Implement inference APIs for multi-modal tasks. (#65) * [Feature] Implement inference APIs for multi-modal tasks. * [Project] Add gradio demo. * [Improve] Update requirements * Update flamingo * Update blip * Add NLVR inferencer * Update flamingo * Update hugging face model register * Update ofa vqa * Update BLIP-vqa (#71) * Update blip-vqa docstring (#72) * Refine flamingo docstring (#73) * [Feature]: BLIP2 VQA (#61) * [Feature]: VQA forward * [Feature]: Reproduce accuracy * [Fix]: Fix lint * [Fix]: Add blank line * minor fix --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * [Feature]: BLIP2 docstring (#74) * [Feature]: Add caption docstring * [Feature]: Add docstring to blip2 vqa * [Feature]: Add docstring to retrieval * Update BLIP-2 metafile and README (#75) * [Feature]: Add readme and docstring * Update blip2 results --------- Co-authored-by: mzr1996 <mzr1996@163.com> * [Feature] BLIP Visual Grounding on MMPretrain Branch (#66) * blip grounding merge with mmpretrain * remove commit * blip grounding test and inference api * refcoco dataset * refcoco dataset refine config * rebasing * gitignore * rebasing * minor edit * minor edit * Update blip-vqa docstring (#72) * rebasing * Revert "minor edit" This reverts 
commit 639cec757c215e654625ed0979319e60f0be9044. * blip grounding final * precommit * refine config * refine config * Update blip visual grounding --------- Co-authored-by: Yiqin Wang 王逸钦 <wyq1217@outlook.com> Co-authored-by: mzr1996 <mzr1996@163.com> * Update visual grounding metric * Update OFA docstring, README and metafiles. (#76) * [Docs] Update installation docs and gradio demo docs. (#77) * Update OFA name * Update Visual Grounding Visualizer * Integrate accelerate support * Fix imports. * Fix timm backbone * Update imports * Update README * Update circle ci * Update flamingo config * Add gradio demo README * [Feature]: Add scienceqa (#1571) * [Feature]: Add scienceqa * [Feature]: Change param name * Update docs * Update video --------- Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com> Co-authored-by: yingfhu <yingfhu@gmail.com> Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Co-authored-by: Yiqin Wang 王逸钦 <wyq1217@outlook.com> Co-authored-by: Rongjie Li <limo97@163.com>
2023-05-19 16:50:04 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
import os
from collections import defaultdict
from contextlib import contextmanager
from itertools import chain
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn

from mmpretrain.utils import require


@require('torch>=1.9.0', 'https://pytorch.org/get-started/locally/')
@require('accelerate')
def dispatch_model(
    model,
    device_map: Union[str, dict],
    max_memory: Optional[dict] = None,
    no_split_module_classes: Optional[List[str]] = None,
    offload_folder: Optional[str] = None,
    offload_buffers: bool = False,
    preload_module_classes: Optional[List[str]] = None,
):
    """Split and dispatch a model across devices.

    The function depends on the `accelerate` package. Refer to
    https://huggingface.co/docs/accelerate/main/en/usage_guides/big_modeling

    Args:
        model (torch.nn.Module): The model to dispatch.
        device_map (str | dict): A map that specifies where each
            submodule should go. It doesn't need to be refined to each
            parameter/buffer name; once a given module name is inside, every
            submodule of it will be sent to the same device. You can pass
            one of 'auto', 'balanced', 'balanced_low_0' or 'sequential'
            (e.g. `device_map="auto"`) to generate the device map
            automatically.
        max_memory (dict | None): A dictionary mapping device identifiers to
            maximum memory. Will default to the maximum memory available for
            each GPU and the available CPU RAM if unset. Defaults to None.
        no_split_module_classes (List[str] | None): A list of layer class names
            that should never be split across devices (for instance any layer
            that has a residual connection). If None, try to get the settings
            from the model class. Defaults to None.
        offload_folder (str | None): If the `device_map` contains any value
            `"disk"`, the folder where the weights will be offloaded.
            Defaults to None.
        offload_buffers (bool): In the layers that are offloaded on the CPU
            or the hard drive, whether or not to offload the buffers as
            well as the parameters. Defaults to False.
        preload_module_classes (List[str] | None): A list of classes whose
            instances should load all their weights (even in the submodules) at
            the beginning of the forward. This should only be used for classes
            that have submodules which are registered but not called directly
            during the forward, for instance if a `dense` linear layer is
            registered, but at forward, `dense.weight` and `dense.bias` are
            used in some operations instead of calling `dense` directly.
            Defaults to None.

    Returns:
        torch.nn.Module: The dispatched model.
    """
    # The local import shadows this function's name with accelerate's
    # `dispatch_model`, which is the one called below.
    from accelerate import dispatch_model, infer_auto_device_map

    # Check valid device_map string.
    valid_map_option = ['auto', 'balanced', 'balanced_low_0', 'sequential']
    if isinstance(device_map, str) and device_map not in valid_map_option:
        raise ValueError('If passing a string for `device_map`, please choose '
                         f'from {valid_map_option}.')

    # Generate device map automatically
    if isinstance(device_map, str):
        if no_split_module_classes is None:
            no_split_module_classes = getattr(model, '_no_split_modules', None)
        if no_split_module_classes is None:
            raise ValueError(f'{model.__class__.__name__} does not support '
                             f"`device_map='{device_map}'` yet.")

        if device_map != 'sequential':
            from accelerate.utils import get_balanced_memory
            max_memory = get_balanced_memory(
                model,
                max_memory=max_memory,
                no_split_module_classes=no_split_module_classes,
                dtype=None,
                low_zero=(device_map == 'balanced_low_0'),
            )
            # Shrink the budget of GPU 0 by 10% to leave headroom on it.
            # CPU-only setups have no key 0, so the check avoids a KeyError.
            if 0 in max_memory:
                max_memory[0] *= 0.9
        device_map = infer_auto_device_map(
            model,
            max_memory=max_memory,
            no_split_module_classes=no_split_module_classes,
            dtype=None,
        )

    if 'disk' in device_map.values():
        if offload_folder is None:
            raise ValueError(
                'The current `device_map` had weights offloaded to the disk. '
                'Please provide an `offload_folder` for them.')
        os.makedirs(offload_folder, exist_ok=True)

    # Use the first accelerator in the map as the main device, else CPU.
    main_device = next(
        (d for d in device_map.values() if d not in ['cpu', 'disk']), 'cpu')

    model = dispatch_model(
        model,
        device_map=device_map,
        main_device=main_device,
        offload_dir=offload_folder,
        offload_buffers=offload_buffers,
        preload_module_classes=preload_module_classes,
    )
    if hasattr(model, 'data_preprocessor'):
        model.data_preprocessor._device = torch.device(main_device)
    return model
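The bookkeeping around the `accelerate` calls in `dispatch_model` (validating a string `device_map`, detecting disk offload, and choosing `main_device`) is plain dictionary logic. The minimal sketch below re-implements only that logic, assuming a device map whose values are GPU indices or the strings `'cpu'` and `'disk'`, so it can be tried without GPUs or `accelerate`. The helper names `check_device_map`, `needs_offload_folder` and `pick_main_device` are hypothetical and not part of mmpretrain.

```python
from typing import Dict, Union

# Device identifiers as used in accelerate-style device maps: GPU indices
# (int) or the strings 'cpu' and 'disk'.
Device = Union[int, str]
DeviceMap = Dict[str, Device]

VALID_MAP_OPTIONS = ['auto', 'balanced', 'balanced_low_0', 'sequential']


def check_device_map(device_map: Union[str, DeviceMap]) -> None:
    """Reject unknown string options; dict maps pass through unchecked."""
    if isinstance(device_map, str) and device_map not in VALID_MAP_OPTIONS:
        raise ValueError('If passing a string for `device_map`, please choose '
                         f'from {VALID_MAP_OPTIONS}.')


def needs_offload_folder(device_map: DeviceMap) -> bool:
    """Weights mapped to 'disk' need a folder to be offloaded to."""
    return 'disk' in device_map.values()


def pick_main_device(device_map: DeviceMap) -> Device:
    """Return the first device that is neither 'cpu' nor 'disk'.

    Falls back to 'cpu' when every submodule lives on the CPU or the disk.
    """
    return next(
        (d for d in device_map.values() if d not in ['cpu', 'disk']), 'cpu')


device_map = {'backbone': 'cpu', 'neck': 0, 'head': 1, 'extra': 'disk'}
check_device_map(device_map)
print(pick_main_device(device_map))      # 0
print(needs_offload_folder(device_map))  # True
```

Because Python dictionaries keep insertion order, the main device is the first accelerator in the order the submodules were listed, not necessarily the lowest GPU index.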


@contextmanager
def init_empty_weights(include_buffers: bool = False):
    """A context manager under which models are initialized with all parameters
    on the meta device.

    With this context manager, we can create an empty model. This is useful
    when merely initializing the model would exceed the available RAM.

    Besides moving the parameters to the meta device, this context manager
    also prevents checkpoints from being loaded by
    `mmengine.runner.load_checkpoint` and
    `transformers.PreTrainedModel.from_pretrained`.

    Modified from https://github.com/huggingface/accelerate

    Args:
        include_buffers (bool): Whether to put all buffers on the meta device
            during initialization. Defaults to False.
    """
    device = torch.device('meta')

    # move parameter and buffer to meta device
    old_register_parameter = nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = nn.Module.register_buffer
        # See https://github.com/huggingface/accelerate/pull/699
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ['empty', 'zeros', 'ones', 'full']
        }

    def register_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            # Re-create the parameter (keeping its class and attributes)
            # from a copy that lives on the meta device.
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            module._parameters[name] = param_cls(
                module._parameters[name].to(device), **kwargs)

    def register_buffer(module, name, buffer, *args, **kwargs):
        old_register_buffer(module, name, buffer, *args, **kwargs)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    def patch_tensor_constructor(fn):

        def wrapper(*args, **kwargs):
            # Force tensor constructors to allocate on the meta device.
            kwargs['device'] = device
            return fn(*args, **kwargs)

        return wrapper

    # Patch load_checkpoint
    import mmengine.runner.checkpoint as mmengine_load
    old_load_checkpoint = mmengine_load.load_checkpoint

    def patch_load_checkpoint(*args, **kwargs):
        return {}

    # Patch transformers from pretrained
    try:
        from transformers import PreTrainedModel
        from transformers.models.auto.auto_factory import (AutoConfig,
                                                           _BaseAutoModelClass)
        with_transformers = True
    except ImportError:
        with_transformers = False

    @classmethod
    def patch_auto_model(cls, pretrained_model_name_or_path, *model_args,
                         **kwargs):
        # Build the model from its config only, without loading weights.
        cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path,
                                         *model_args, **kwargs)
        return cls.from_config(cfg)

    @classmethod
    def patch_pretrained_model(cls, pretrained_model_name_or_path, *model_args,
                               **kwargs):
        # Build the model from its config only, without loading weights.
        cfg = cls.config_class.from_pretrained(pretrained_model_name_or_path,
                                               *model_args, **kwargs)
        return cls(cfg)

    if with_transformers:
        # Save the raw classmethod descriptors instead of the bound methods
        # returned by plain attribute access. Restoring a bound method would
        # make every subclass receive the base class as `cls`.
        from inspect import getattr_static
        old_pretrained_model = getattr_static(PreTrainedModel,
                                              'from_pretrained')
        old_auto_model = getattr_static(_BaseAutoModelClass, 'from_pretrained')

    try:
        nn.Module.register_parameter = register_parameter
        mmengine_load.load_checkpoint = patch_load_checkpoint
        if with_transformers:
            PreTrainedModel.from_pretrained = patch_pretrained_model
            _BaseAutoModelClass.from_pretrained = patch_auto_model
        if include_buffers:
            nn.Module.register_buffer = register_buffer
            for func in tensor_constructors_to_patch.keys():
                tensor_constructor = patch_tensor_constructor(
                    getattr(torch, func))
                setattr(torch, func, tensor_constructor)
        yield
    finally:
        # Always restore the original functions, even if an error occurred.
        nn.Module.register_parameter = old_register_parameter
        mmengine_load.load_checkpoint = old_load_checkpoint
        if with_transformers:
            PreTrainedModel.from_pretrained = old_pretrained_model
            _BaseAutoModelClass.from_pretrained = old_auto_model
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer
            for func, ori in tensor_constructors_to_patch.items():
                setattr(torch, func, ori)
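`init_empty_weights` works by temporarily swapping attributes (`nn.Module.register_parameter`, `load_checkpoint`, `from_pretrained`, the torch tensor constructors) and putting the originals back in a `finally` clause. The sketch below isolates that patch-and-restore pattern on a hypothetical `Loader` class, assuming nothing beyond the standard library; `Loader`, `SubLoader`, `skip_loading` and `patched_attr` are illustrative names only. It also shows one pitfall of the pattern: reading a classmethod with ordinary attribute access yields a method already bound to that class, so writing it back would make subclasses receive the base class as `cls`. `inspect.getattr_static` returns the raw `classmethod` object instead.

```python
import inspect
from contextlib import contextmanager


class Loader:
    """Hypothetical stand-in for a model base class with a loader."""

    @classmethod
    def from_pretrained(cls, name):
        return f'{cls.__name__} loaded {name}'


class SubLoader(Loader):
    pass


@classmethod
def skip_loading(cls, name):
    # Replacement that builds an "empty" result instead of reading weights.
    return f'{cls.__name__} skipped {name}'


@contextmanager
def patched_attr(owner, attr, replacement):
    """Temporarily set ``owner.<attr>`` and restore it even on errors."""
    # getattr_static returns the raw descriptor (here a classmethod object)
    # rather than a method already bound to ``owner``.
    original = inspect.getattr_static(owner, attr)
    setattr(owner, attr, replacement)
    try:
        yield
    finally:
        setattr(owner, attr, original)


with patched_attr(Loader, 'from_pretrained', skip_loading):
    print(SubLoader.from_pretrained('ckpt'))  # SubLoader skipped ckpt
print(SubLoader.from_pretrained('ckpt'))      # SubLoader loaded ckpt
```

Patching the base class affects every subclass at once, which is why a single swap of `from_pretrained` on a base class is enough to cover all models derived from it.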
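

def compute_module_sizes(
        model: nn.Module,
        dtype: Union[str, torch.dtype, None] = None,
        special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None):
    """Compute the size of each submodule of a given model.

    Args:
        model (torch.nn.Module): The model to inspect.
        dtype (str | torch.dtype | None): If given, every parameter and
            buffer is counted with at most the byte size of this dtype;
            tensors whose elements are already smaller keep their own size.
            Defaults to None, which uses each tensor's own element size.
        special_dtypes (dict | None): A mapping from parameter or buffer
            names to dtypes that override ``dtype`` for those tensors.
            Defaults to None.

    Returns:
        defaultdict: A mapping from each name prefix to its size in bytes.
        The empty string ``''`` holds the size of the whole model, and every
        submodule as well as every individual parameter or buffer has its
        own entry.
    """

    def get_dtype(dtype):
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        if dtype is not None:
            # `torch.float32` etc. are instances of `torch.dtype`, not classes.
            assert isinstance(dtype, torch.dtype)
        return dtype

    def dtype_bytes(dtype: torch.dtype):
        if dtype is torch.bool:
            return 1
        if dtype.is_floating_point:
            return torch.finfo(dtype).bits / 8
        else:
            return torch.iinfo(dtype).bits / 8

    if dtype is not None:
        dtype = get_dtype(dtype)
        dtype_size = dtype_bytes(dtype)

    if special_dtypes is not None:
        # Convert string entries (e.g. 'float16') before sizing them.
        special_dtypes = {
            key: dtype_bytes(get_dtype(key_dtype))
            for key, key_dtype in special_dtypes.items()
        }

    module_sizes = defaultdict(int)
    for name, tensor in chain(
            model.named_parameters(recurse=True),
            model.named_buffers(recurse=True)):
        if special_dtypes is not None and name in special_dtypes:
            size = tensor.numel() * special_dtypes[name]
        elif dtype is None:
            size = tensor.numel() * tensor.element_size()
        else:
            size = tensor.numel() * min(dtype_size, tensor.element_size())
        # Add the size to every prefix of the dotted name, from the root
        # module ('') down to the tensor itself.
        name_parts = name.split('.')
        for idx in range(len(name_parts) + 1):
            module_sizes['.'.join(name_parts[:idx])] += size

    return module_sizes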
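The core of `compute_module_sizes` is prefix aggregation: the byte size of each parameter or buffer is added to every dotted prefix of its name, so a submodule's total includes all of its children. The sketch below reproduces that loop on plain `(name, size_in_bytes)` pairs so it runs without PyTorch, together with the capped per-tensor size used when a target `dtype` is given. `aggregate_sizes` and `capped_bytes` are hypothetical names, and the example sizes assume float32 tensors (4 bytes per element).

```python
from collections import defaultdict
from typing import Dict, Iterable, Optional, Tuple


def capped_bytes(numel: int, element_size: int,
                 target_size: Optional[int] = None) -> int:
    """Bytes for one tensor, never more than ``target_size`` per element."""
    if target_size is None:
        return numel * element_size
    return numel * min(target_size, element_size)


def aggregate_sizes(
        named_sizes: Iterable[Tuple[str, int]]) -> Dict[str, int]:
    """Add each size to every dotted prefix of its name.

    'backbone.conv.weight' contributes to '', 'backbone', 'backbone.conv'
    and 'backbone.conv.weight'.
    """
    module_sizes = defaultdict(int)
    for name, size in named_sizes:
        name_parts = name.split('.')
        for idx in range(len(name_parts) + 1):
            module_sizes['.'.join(name_parts[:idx])] += size
    return module_sizes


named_sizes = [
    ('backbone.conv.weight', capped_bytes(16 * 3 * 3 * 3, 4)),  # 1728 bytes
    ('backbone.conv.bias', capped_bytes(16, 4)),                # 64 bytes
    ('head.fc.weight', capped_bytes(10 * 16, 4)),               # 640 bytes
]
sizes = aggregate_sizes(named_sizes)
print(sizes['backbone'])  # 1792
print(sizes[''])          # 2432
```

Counting each tensor once per prefix keeps the cost linear in the total number of name components, and the resulting flat mapping is the kind of per-module size table that device-map planning works from.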