mmclassification/mmpretrain/models/classifiers/hugging_face.py

# Copyright (c) OpenMMLab. All rights reserved.
import re
from collections import OrderedDict
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from mmpretrain.utils import require
from .base import BaseClassifier


@MODELS.register_module()
class HuggingFaceClassifier(BaseClassifier):
    """Image classifiers for HuggingFace models.

    This class accepts all positional and keyword arguments of the API
    ``from_pretrained`` (when ``pretrained=True``) and ``from_config`` (when
    ``pretrained=False``) of `transformers.AutoModelForImageClassification`_
    and uses them to create a model from HuggingFace.

    It can load HuggingFace checkpoints directly, and the checkpoints it
    saves can also be loaded by HuggingFace directly.

    Please confirm that you have installed ``transformers`` if you want to
    use it.

    .. _transformers.AutoModelForImageClassification:
        https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification

    Args:
        model_name (str): The name of the model to use in HuggingFace.
        pretrained (bool): Whether to load the pretrained checkpoint from
            HuggingFace. Defaults to False.
        *model_args: Other positional arguments of the method
            ``from_pretrained`` or ``from_config``.
        loss (dict): Config of classification loss. Defaults to
            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
        train_cfg (dict, optional): The training setting. The acceptable
            fields are:

            - augments (List[dict]): The batch augmentation methods to use.
              More details can be found in
              :mod:`mmpretrain.model.utils.augment`.

            Defaults to None.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Defaults to
            False.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None or no specified type, it will use
            "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor`
            for more details. Defaults to None.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to None.
        **kwargs: Other keyword arguments of the method
            ``from_pretrained`` or ``from_config``.

    Examples:
        >>> import torch
        >>> from mmpretrain.models import build_classifier
        >>> cfg = dict(type='HuggingFaceClassifier', model_name='microsoft/resnet-50', pretrained=True)
        >>> model = build_classifier(cfg)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> out = model(inputs)
        >>> print(out.shape)
        torch.Size([1, 1000])
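
        A minimal sketch of the ``predict`` mode, which wraps the softmax
        scores into :class:`DataSample` objects (reusing the ``model`` and
        ``inputs`` built above):

        >>> data_samples = model(inputs, mode='predict')
        >>> data_samples[0].pred_score.shape
        torch.Size([1000])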
""" # noqa: E501

    @require('transformers')
    def __init__(self,
                 model_name,
                 pretrained=False,
                 *model_args,
                 loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                 train_cfg: Optional[dict] = None,
                 with_cp: bool = False,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[dict] = None,
                 **kwargs):
        if data_preprocessor is None:
            data_preprocessor = {}
        # The build process is in MMEngine, so we need to add scope here.
        data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor')

        if train_cfg is not None and 'augments' in train_cfg:
            # Set batch augmentations by `train_cfg`
            data_preprocessor['batch_augments'] = train_cfg

        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        from transformers import AutoConfig, AutoModelForImageClassification
        if pretrained:
            self.model = AutoModelForImageClassification.from_pretrained(
                model_name, *model_args, **kwargs)
        else:
            config = AutoConfig.from_pretrained(model_name, *model_args,
                                                **kwargs)
            self.model = AutoModelForImageClassification.from_config(config)

        if not isinstance(loss, nn.Module):
            loss = MODELS.build(loss)
        self.loss_module = loss

        self.with_cp = with_cp
        if self.with_cp:
            self.model.gradient_checkpointing_enable()

        self._register_state_dict_hook(self._remove_state_dict_prefix)
        self._register_load_state_dict_pre_hook(self._add_state_dict_prefix)

    def forward(self, inputs, data_samples=None, mode='tensor'):
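        """The unified entry for a forward process.

        The behavior depends on ``mode``:

        - ``'tensor'``: Return the classification logits of the wrapped
          HuggingFace model.
        - ``'loss'``: Return a dict of loss components via :meth:`loss`.
        - ``'predict'``: Return a list of :class:`DataSample` with prediction
          results via :meth:`predict`.
        """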
        if mode == 'tensor':
            return self.model(inputs).logits
        elif mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')

    def extract_feat(self, inputs: torch.Tensor):
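        """Extract features from the input tensor (not supported yet)."""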
        raise NotImplementedError(
            "The HuggingFaceClassifier doesn't support extract feature yet.")

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs):
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample]): The annotation data of
                every sample.
            **kwargs: Other keyword arguments of the loss module.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
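
        Examples:
            A rough sketch, reusing the ``model`` built in the class-level
            example above and a fake ground-truth label:

            >>> from mmpretrain.structures import DataSample
            >>> data_samples = [DataSample().set_gt_label(0)]
            >>> losses = model.loss(torch.rand(1, 3, 224, 224), data_samples)
            >>> sorted(losses.keys())
            ['loss']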
"""
# The part can be traced by torch.fx
cls_score = self.model(inputs).logits
# The part can not be traced by torch.fx
losses = self._get_loss(cls_score, data_samples, **kwargs)
return losses

    def _get_loss(self, cls_score: torch.Tensor,
                  data_samples: List[DataSample], **kwargs):
        """Unpack data samples and compute loss."""
        # Unpack data samples and pack targets
        if 'gt_score' in data_samples[0]:
            # Batch augmentation may convert labels to one-hot format scores.
            target = torch.stack([i.gt_score for i in data_samples])
        else:
            target = torch.cat([i.gt_label for i in data_samples])

        # compute loss
        losses = dict()
        loss = self.loss_module(
            cls_score, target, avg_factor=cls_score.size(0), **kwargs)
        losses['loss'] = loss

        return losses

    def predict(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None):
        """Predict results from a batch of inputs.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample], optional): The annotation
                data of every sample. Defaults to None.

        Returns:
            List[DataSample]: The prediction results.
        """
        # The part can be traced by torch.fx
        cls_score = self.model(inputs).logits

        # The part can not be traced by torch.fx
        predictions = self._get_predictions(cls_score, data_samples)
        return predictions

    def _get_predictions(self, cls_score, data_samples):
        """Post-process the output of the head.

        Including softmax and setting ``pred_label`` of data samples.
        """
        pred_scores = F.softmax(cls_score, dim=1)
        pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()

        if data_samples is not None:
            for data_sample, score, label in zip(data_samples, pred_scores,
                                                 pred_labels):
                data_sample.set_pred_score(score).set_pred_label(label)
        else:
            data_samples = []
            for score, label in zip(pred_scores, pred_labels):
                data_samples.append(
                    DataSample().set_pred_score(score).set_pred_label(label))

        return data_samples

    @staticmethod
    def _remove_state_dict_prefix(self, state_dict, prefix, local_metadata):
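        """State-dict hook that strips the ``model.`` sub-prefix when saving.

        This is registered through ``_register_state_dict_hook``, so the
        first argument (named ``self`` here) is the module passed in by
        PyTorch. With the prefix removed, checkpoints saved by this
        classifier keep the native HuggingFace key names.
        """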
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_key = re.sub(f'^{prefix}model.', prefix, k)
            new_state_dict[new_key] = v
        return new_state_dict

    @staticmethod
    def _add_state_dict_prefix(state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
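        """Load pre-hook that re-inserts the ``model.`` sub-prefix.

        The inverse of :meth:`_remove_state_dict_prefix`: keys of a native
        HuggingFace checkpoint are rewritten in place so that they match the
        wrapped ``self.model`` before loading.
        """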
        new_prefix = prefix + 'model.'
        for k in list(state_dict.keys()):
            new_key = re.sub(f'^{prefix}', new_prefix, k)
            state_dict[new_key] = state_dict[k]
            del state_dict[k]