Merge branch 'main' into dev
commit 58a2243d99

@@ -21,7 +21,7 @@ Instruction tuning large language models (LLMs) using machine-generated instruct
 According to the license of LLaMA, we cannot provide the merged checkpoint directly. Please use the below
 script to download and get the merged checkpoint.
 
-```baseh
+```shell
 python tools/model_converters/llava-delta2mmpre.py huggyllama/llama-7b liuhaotian/LLaVA-Lightning-7B-delta-v1-1 ./LLaVA-Lightning-7B-delta-v1-1.pth
 ```
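
For intuition, here is a minimal sketch of what a delta-to-full merge conceptually does. The real logic lives in `tools/model_converters/llava-delta2mmpre.py`; the file names and the key-handling rule below are assumptions for illustration only.

```python
import torch

# Hypothetical sketch of a delta merge: add delta weights onto the base
# LLaMA weights, keeping keys that exist only in the delta checkpoint.
base = torch.load('llama-7b.pth', map_location='cpu')
delta = torch.load('llava-delta.pth', map_location='cpu')

merged = {}
for key, value in delta.items():
    merged[key] = value + base[key] if key in base else value

torch.save(merged, 'LLaVA-Lightning-7B-delta-v1-1.pth')
```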

@@ -7,7 +7,7 @@
 - Support inference of more **multi-modal** algorithms, such as **LLaVA**, **MiniGPT-4**, **Otter**, etc.
 - Support around **10 multi-modal datasets**!
 - Add **iTPN**, **SparK** self-supervised learning algorithms.
-- Provide examples of [New Config](./mmpretrain/configs/) and [DeepSpeed/FSDP](./configs/mae/benchmarks/).
+- Provide examples of [New Config](https://github.com/open-mmlab/mmpretrain/tree/main/mmpretrain/configs/) and [DeepSpeed/FSDP](https://github.com/open-mmlab/mmpretrain/tree/main/configs/mae/benchmarks/).
 
 ### New Features

@@ -1,10 +1,10 @@
-## Shape Bias Tool Usage
+# Shape Bias Tool Usage
 
 Shape bias measures how much a model relies on shapes, compared to textures, to sense the semantics in images. For more details,
 we refer interested readers to this [paper](https://arxiv.org/abs/2106.07411). MMPretrain provides an off-the-shelf toolbox to
 obtain the shape bias of a classification model. You can follow the steps below:
 
-### Prepare the dataset
+## Prepare the dataset
 
 First you should download the [cue-conflict](https://github.com/bethgelab/model-vs-human/releases/download/v0.1/cue-conflict.tar.gz) dataset to the `data` folder,
 and then unzip it. After that, your `data` folder should have the following structure:

@@ -18,7 +18,7 @@ data
 | |── truck
 ```
 
-### Modify the config for classification
+## Modify the config for classification
 
 We run the shape-bias tool on a ViT-base model with masked autoencoder pretraining. Its config file is `configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py`, and its checkpoint is downloaded from [this link](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.pth). Replace the original `test_pipeline`, `test_dataloader` and `test_evaluator` with the following configurations:
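
The replacement block itself is elided from this diff (the next hunk resumes inside it, at `test_evaluator = dict(`). Below is a rough, hypothetical sketch of the kind of configuration the sentence refers to; every field is an assumption except `csv_dir` and `model_name`, which the note in the next hunk mentions.

```python
# Hypothetical sketch only: the real block is elided from this diff.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=256, edge='short'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackInputs'),
]

test_dataloader = dict(
    batch_size=32,
    dataset=dict(
        type='CustomDataset',
        data_root='data/cue-conflict',  # the dataset unpacked above
        pipeline=test_pipeline,
    ),
)

test_evaluator = dict(
    type='ShapeBiasMetric',          # assumed metric name
    csv_dir='work_dirs/shape_bias',  # set to your output folder
    model_name='mae',                # set to your model's name
)
```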

@@ -55,7 +55,7 @@ test_evaluator = dict(
 
 Please note you should make custom modifications to the `csv_dir` and `model_name` above. We rename the modified config file to `vit-base-p16_8xb128-coslr-100e_in1k_shape-bias.py` and place it in the folder `configs/mae/benchmarks/`.
 
-### Run inference with your model using the modified config file
+## Run inference with your model using the modified config file
 
 Then you should run inference with your model on the `cue-conflict` dataset using the modified config file.

@@ -77,7 +77,7 @@ bash tools/dist_test.sh configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in
 After that, you should obtain a csv file in the `csv_dir` folder, named `cue-conflict_model-name_session-1.csv`. Besides this file, you should also download these [csv files](https://github.com/bethgelab/model-vs-human/tree/master/raw-data/cue-conflict) to the
 `csv_dir`.
 
-### Plot shape bias
+## Plot shape bias
 
 Then we can start to plot the shape bias:

@@ -1169,7 +1169,7 @@ class GaussianBlur(BaseAugTransform):
 
         img = results['img']
         pil_img = Image.fromarray(img)
-        pil_img.filter(ImageFilter.GaussianBlur(radius=radius))
+        pil_img = pil_img.filter(ImageFilter.GaussianBlur(radius=radius))
         results['img'] = np.array(pil_img, dtype=img.dtype)
 
         return results
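
The fix matters because `PIL.Image.filter` returns a new image rather than modifying the receiver in place, so the old line computed the blur and then threw it away. A minimal standalone illustration:

```python
import numpy as np
from PIL import Image, ImageFilter

img = Image.fromarray(np.full((8, 8, 3), 128, dtype=np.uint8))
# filter() is side-effect free: the return value must be kept.
blurred = img.filter(ImageFilter.GaussianBlur(radius=1.0))
assert blurred is not img
```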

@@ -3,6 +3,7 @@ import os.path as osp
 from typing import Dict, List, Optional, Sequence
 
 import torch
+from mmengine.device import get_device
 from mmengine.dist import get_rank, get_world_size, is_distributed
 from mmengine.hooks import Hook
 from mmengine.logging import MMLogger

@@ -97,11 +98,13 @@ class SwAVHook(Hook):
         if self.queue_length > 0 \
                 and runner.epoch >= self.epoch_queue_starts \
                 and self.queue is None:
             self.queue = torch.zeros(
                 len(self.crops_for_assign),
                 self.queue_length // runner.world_size,
                 self.feat_dim,
-            ).cuda()
+                device=get_device(),
+            )
 
         # set the boolean type of use_the_queue
         get_ori_model(runner.model).head.loss_module.queue = self.queue
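
Passing the `device` keyword allocates the queue directly on whatever backend `get_device()` reports (e.g. `'cuda'`, `'mps'`, or `'cpu'`) instead of hard-requiring CUDA. A small sketch of the idiom, with made-up shapes:

```python
import torch
from mmengine.device import get_device

queue = torch.zeros(2, 16, 128, device=get_device())
print(queue.device)  # cuda:0, mps:0, or cpu depending on the machine
```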

@@ -3,6 +3,7 @@ from typing import List, Optional, Union
 
 import torch
 import torch.nn as nn
+from mmengine.device import get_device
 from mmengine.model import BaseModule
 
 from mmpretrain.registry import MODELS

@@ -43,8 +44,7 @@ class iTPNClipHead(BaseModule):
             target (torch.Tensor): Target generated by target_generator.
             mask (torch.Tensor): Generated mask for pretraining.
         """
-
-        mask = mask.to(torch.device('cuda'), non_blocking=True)
+        mask = mask.to(get_device(), non_blocking=True)
         mask = mask.flatten(1).to(torch.bool)
         target = target[mask]
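
For reference, `target[mask]` with a boolean mask over the batch and patch dimensions gathers only the masked patch targets into a flat tensor. A toy illustration with made-up sizes:

```python
import torch

target = torch.randn(2, 4, 8)  # (batch, patches, channels)
mask = torch.tensor([[1, 0, 1, 0],
                     [0, 1, 0, 0]]).to(torch.bool)
selected = target[mask]        # shape (3, 8): only masked patches remain
print(selected.shape)
```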

@@ -23,7 +23,7 @@ class Flamingo(BaseModel):
         zeroshot_prompt (str): Prompt used for zero-shot inference.
             Defaults to '<image>Output:'.
         shot_prompt_tmpl (str): Prompt used for few-shot inference.
-            Defaults to '<image>Output:{caption}<|endofchunk|>'.
+            Defaults to ``<image>Output:{caption}<|endofchunk|>``.
         final_prompt_tmpl (str): Final part of prompt used for inference.
             Defaults to '<image>Output:'.
         generation_cfg (dict): The extra generation config, accept the keyword

@@ -36,7 +36,7 @@ class MiniGPT4(BaseModel):
         raw_prompts (list): Prompts for training. Defaults to None.
         max_txt_len (int): Max token length while doing tokenization. Defaults
             to 32.
-        end_sym (str): Ended symbol of the sequence. Defaults to '\n'.
+        end_sym (str): Ended symbol of the sequence. Defaults to '\\n'.
         generation_cfg (dict): The config of text generation. Defaults to
             dict().
         data_preprocessor (:obj:`BaseDataPreprocessor`): Used for

@@ -20,8 +20,8 @@ class Otter(Flamingo):
         zeroshot_prompt (str): Prompt used for zero-shot inference.
             Defaults to an.
         shot_prompt_tmpl (str): Prompt used for few-shot inference.
-            Defaults to '<image>User:Please describe the image.
-            GPT:<answer>{caption}<|endofchunk|>'.
+            Defaults to ``<image>User:Please describe the image.
+            GPT:<answer>{caption}<|endofchunk|>``.
         final_prompt_tmpl (str): Final part of prompt used for inference.
             Defaults to '<image>User:Please describe the image. GPT:<answer>'.
         generation_cfg (dict): The extra generation config, accept the keyword

@@ -20,25 +20,26 @@ mmpretrain.utils.progress.disable_progress_bar = True
 
 logger = MMLogger('mmpretrain', logger_name='mmpre')
 if torch.cuda.is_available():
-    gpus = [
+    devices = [
         torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
     ]
-    logger.info(f'Available GPUs: {len(gpus)}')
+    logger.info(f'Available GPUs: {len(devices)}')
+elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+    devices = [torch.device('mps')]
+    logger.info('Available MPS.')
 else:
-    gpus = None
-    logger.info('No available GPU.')
+    devices = [torch.device('cpu')]
+    logger.info('Available CPU.')
 
 
 def get_free_device():
-    if gpus is None:
-        return torch.device('cpu')
     if hasattr(torch.cuda, 'mem_get_info'):
-        free = [torch.cuda.mem_get_info(gpu)[0] for gpu in gpus]
+        free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices]
         select = max(zip(free, range(len(free))))[1]
     else:
         import random
-        select = random.randint(0, len(gpus) - 1)
-    return gpus[select]
+        select = random.randint(0, len(devices) - 1)
+    return devices[select]
 
 
 class InferencerCache:
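
The `max(zip(free, range(len(free))))[1]` idiom returns the index of the device with the most free bytes, with ties breaking toward the higher index. A worked example with made-up numbers:

```python
free = [3_000_000_000, 9_000_000_000, 5_000_000_000]
select = max(zip(free, range(len(free))))[1]  # compares free bytes first
assert select == 1  # device 1 reports the most free memory
```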

@@ -1285,9 +1285,10 @@ class TestGaussianBlur(TestCase):
 
     def test_transform(self):
         transform_func = 'PIL.ImageFilter.GaussianBlur'
+        from PIL.ImageFilter import GaussianBlur
 
         # test params inputs
-        with patch(transform_func, autospec=True) as mock:
+        with patch(transform_func, wraps=GaussianBlur) as mock:
             cfg = {
                 **self.DEFAULT_ARGS,
                 'radius': 0.5,

@@ -1297,7 +1298,7 @@ class TestGaussianBlur(TestCase):
         mock.assert_called_once_with(radius=0.5)
 
         # test prob
-        with patch(transform_func, autospec=True) as mock:
+        with patch(transform_func, wraps=GaussianBlur) as mock:
             cfg = {
                 **self.DEFAULT_ARGS,
                 'radius': 0.5,

@@ -1307,7 +1308,7 @@ class TestGaussianBlur(TestCase):
         mock.assert_not_called()
 
         # test magnitude_range
-        with patch(transform_func, autospec=True) as mock:
+        with patch(transform_func, wraps=GaussianBlur) as mock:
             cfg = {
                 **self.DEFAULT_ARGS,
                 'magnitude_range': (0.1, 2),
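
Replacing `autospec=True` with `wraps=GaussianBlur` lets the mock record calls while still delegating to the real filter, so the transform under test receives a genuine `ImageFilter` object instead of a `MagicMock`. A standalone sketch of the pattern:

```python
import PIL.ImageFilter
from PIL.ImageFilter import GaussianBlur
from unittest.mock import patch

with patch('PIL.ImageFilter.GaussianBlur', wraps=GaussianBlur) as mock:
    # The call goes through the mock but still builds a real filter.
    blur = PIL.ImageFilter.GaussianBlur(radius=0.5)
    mock.assert_called_once_with(radius=0.5)
assert isinstance(blur, GaussianBlur)
```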

@@ -5,6 +5,7 @@ from unittest import TestCase
 
 import torch
 import torch.nn as nn
+from mmengine.device import get_device
 from mmengine.logging import MMLogger
 from mmengine.model import BaseModule
 from mmengine.optim import OptimWrapper

@@ -79,7 +80,7 @@ class TestDenseCLHook(TestCase):
         self.temp_dir.cleanup()
 
     def test_densecl_hook(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         dummy_dataset = DummyDataset()
         toy_model = ToyModel().to(device)
         densecl_hook = DenseCLHook(start_iters=1)
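
`get_device()` centralizes the backend probe these tests previously inlined: it returns a device string such as `'cuda'`, `'mps'`, or `'cpu'`, depending on what the running build of PyTorch can see. The same one-line change recurs in the EMA, SimSiam, SwAV, and SwitchRecipe hook tests below. A minimal usage sketch:

```python
import torch
from mmengine.device import get_device

device = get_device()            # e.g. 'cuda', 'mps', or 'cpu'
x = torch.ones(2, 2).to(device)
print(x.device)
```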

@@ -8,6 +8,7 @@ from unittest.mock import ANY, MagicMock, call
 
 import torch
 import torch.nn as nn
+from mmengine.device import get_device
 from mmengine.evaluator import Evaluator
 from mmengine.logging import MMLogger
 from mmengine.model import BaseModel

@@ -70,7 +71,7 @@ class TestEMAHook(TestCase):
         self.temp_dir.cleanup()
 
     def test_load_state_dict(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         model = SimpleModel().to(device)
         ema_hook = EMAHook()
         runner = Runner(

@@ -95,7 +96,7 @@ class TestEMAHook(TestCase):
 
     def test_evaluate_on_ema(self):
 
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         model = SimpleModel().to(device)
 
         # Test validate on ema model

@@ -5,6 +5,7 @@ from unittest import TestCase
 
 import torch
 import torch.nn as nn
+from mmengine.device import get_device
 from mmengine.logging import MMLogger
 from mmengine.model import BaseModule
 from mmengine.runner import Runner

@@ -79,7 +80,7 @@ class TestSimSiamHook(TestCase):
         self.temp_dir.cleanup()
 
     def test_simsiam_hook(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         dummy_dataset = DummyDataset()
         toy_model = ToyModel().to(device)
         simsiam_hook = SimSiamHook(

@@ -5,6 +5,7 @@ from unittest import TestCase
 
 import torch
 import torch.nn as nn
+from mmengine.device import get_device
 from mmengine.logging import MMLogger
 from mmengine.model import BaseModule
 from mmengine.optim import OptimWrapper

@@ -86,7 +87,7 @@ class TestSwAVHook(TestCase):
         self.temp_dir.cleanup()
 
     def test_swav_hook(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         dummy_dataset = DummyDataset()
         toy_model = ToyModel().to(device)
         swav_hook = SwAVHook(

@@ -10,6 +10,7 @@ import torch
 import torch.nn as nn
 from mmcv.transforms import Compose
 from mmengine.dataset import BaseDataset, ConcatDataset, RepeatDataset
+from mmengine.device import get_device
 from mmengine.logging import MMLogger
 from mmengine.model import BaseDataPreprocessor, BaseModel
 from mmengine.optim import OptimWrapper

@@ -130,7 +131,7 @@ class TestSwitchRecipeHook(TestCase):
         self.assertIsNone(hook.schedule[1]['batch_augments'])
 
     def test_do_switch(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         model = SimpleModel().to(device)
 
         loss = CrossEntropyLoss(use_soft=True)

@@ -205,7 +206,7 @@ class TestSwitchRecipeHook(TestCase):
         # runner.train()
 
     def test_resume(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         model = SimpleModel().to(device)
 
         loss = CrossEntropyLoss(use_soft=True)

@@ -275,7 +276,7 @@ class TestSwitchRecipeHook(TestCase):
             logs.output)
 
     def test_switch_train_pipeline(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         model = SimpleModel().to(device)
 
         runner = Runner(

@@ -324,7 +325,7 @@ class TestSwitchRecipeHook(TestCase):
             pipeline)
 
     def test_switch_loss(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        device = get_device()
         model = SimpleModel().to(device)
 
         runner = Runner(