[Improve] Update Otter and LLaVA docs and config. (#1653)

parent dbef2b41c6
commit 7d850dfadd
````diff
@@ -34,6 +34,7 @@ from mmpretrain import get_model, inference_model
 model = get_model('llava-7b-v1_caption', pretrained='MERGED_CHECKPOINT_PATH', device='cuda')
 out = inference_model(model, 'demo/cat-dog.png')
 print(out)
+# {'pred_caption': 'In the image, there are two cats sitting on a blanket.'}
 ```
 
 **Test Command**
````
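For readers skimming the diff: the LLaVA snippet above returns a plain dict, so the caption can be pulled out by key. A minimal sketch of consuming it, assuming the merged checkpoint has already been produced (`MERGED_CHECKPOINT_PATH` is a placeholder, not a real path) and a CUDA device is available:

```python
from mmpretrain import get_model, inference_model

# Placeholder path: the LLaVA weights must be merged into a full
# checkpoint beforehand; substitute your local path here.
model = get_model('llava-7b-v1_caption',
                  pretrained='MERGED_CHECKPOINT_PATH',
                  device='cuda')

out = inference_model(model, 'demo/cat-dog.png')
caption = out['pred_caption']  # inference_model returns a dict for this task
print(caption)
```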
````diff
@@ -22,9 +22,10 @@ Large language models (LLMs) have demonstrated significant universal capabilitie
 import torch
 from mmpretrain import get_model, inference_model
 
-model = get_model('otter-9b_3rdparty_caption', pretrained=True, device='cuda')
+model = get_model('otter-9b_3rdparty_caption', pretrained=True, device='cuda', generation_cfg=dict(max_new_tokens=50))
 out = inference_model(model, 'demo/cat-dog.png')
 print(out)
+# {'pred_caption': 'The image features two adorable small puppies sitting next to each other on the grass. One puppy is brown and white, while the other is tan and white. They appear to be relaxing outdoors, enjoying each other'}
 ```
 
 **Test Command**
````
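The new `generation_cfg` argument threads generation options (here `max_new_tokens`) through to the underlying language model; note that the sample caption above stops mid-sentence, presumably because it hits the 50-token cap. A minimal sketch of varying the budget, assuming `generation_cfg` forwards HuggingFace-style `generate()` keywords (an assumption based on this diff, not a documented contract):

```python
from mmpretrain import get_model, inference_model

# Assumption: generation_cfg forwards generate()-style kwargs, so a larger
# max_new_tokens should let the caption finish its final sentence.
model = get_model('otter-9b_3rdparty_caption',
                  pretrained=True,
                  device='cuda',
                  generation_cfg=dict(max_new_tokens=128))
out = inference_model(model, 'demo/cat-dog.png')
print(out['pred_caption'])
```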
```diff
@@ -43,17 +44,17 @@ python tools/test.py configs/otter/otter-9b_caption.py https://download.openmmla
 
 ### Image Caption on COCO
 
-| Model | Pretrain | Params (M) | BLEU-4 | CIDER | Config | Download |
-| :---------------------------- | :----------: | :--------: | :------: | :------: | :---------------------------: | :------------------------------------------------------------------------------------------------------: |
-| `otter-9b_3rdparty_caption`\* | From scratch | 8220.45 | Upcoming | Upcoming | [config](otter-9b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) |
+| Model | Params (M) | BLEU-4 | CIDER | Config | Download |
+| :---------------------------- | :--------: | :------: | :------: | :---------------------------: | :------------------------------------------------------------------------------------------------------: |
+| `otter-9b_3rdparty_caption`\* | 8220.45 | Upcoming | Upcoming | [config](otter-9b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) |
 
 *Models with * are converted from the [official repo](https://github.com/Luodian/Otter/tree/main). The config files of these models are only for inference. We haven't reproduced the training results.*
 
 ### Visual Question Answering on VQAv2
 
-| Model | Pretrain | Params (M) | Accuracy | Config | Download |
-| :------------------------ | :----------: | :--------: | :------: | :-----------------------: | :------------------------------------------------------------------------------------------------------: |
-| `otter-9b_3rdparty_vqa`\* | From scratch | 8220.45 | Upcoming | [config](otter-9b_vqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) |
+| Model | Params (M) | Accuracy | Config | Download |
+| :------------------------ | :--------: | :------: | :-----------------------: | :------------------------------------------------------------------------------------------------------: |
+| `otter-9b_3rdparty_vqa`\* | 8220.45 | Upcoming | [config](otter-9b_vqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) |
 
 *Models with * are converted from the [official repo](https://github.com/Luodian/Otter/tree/main). The config files of these models are only for inference. We haven't reproduced the training results.*
```
```diff
@@ -65,14 +65,10 @@ val_dataloader = dict(
     batch_size=8,
     num_workers=8,
     dataset=dict(
-        type='FlamingoEvalCOCOCaption',
+        type='COCOCaption',
         data_root='data/coco',
-        ann_file='annotations/captions_train2014.json',
-        data_prefix=dict(img_path='train2014'),
+        ann_file='annotations/coco_karpathy_val.json',
         pipeline=test_pipeline,
-        num_shots=0,
-        num_support_examples=2048,
-        num_query_examples=5000,
     ),
     sampler=dict(type='DefaultSampler', shuffle=False),
     persistent_workers=True,
```
```diff
@@ -80,7 +76,7 @@ val_dataloader = dict(
 
 val_evaluator = dict(
     type='COCOCaption',
-    ann_file='data/coco/annotations/captions_train2014.json')
+    ann_file='data/coco/annotations/coco_karpathy_val_gt.json')
 
 # If you want standard test, please manually configure the test dataset
 test_dataloader = val_dataloader
```
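Read together, the two config hunks above move evaluation from the raw `captions_train2014.json` annotations to the COCO Karpathy validation split, and drop the Flamingo-specific few-shot sampling options (`num_shots`, `num_support_examples`, `num_query_examples`) along with the `FlamingoEvalCOCOCaption` wrapper. Assembled from the diff (and assuming `test_pipeline` is defined earlier in the config file, as in the original), the resulting section reads:

```python
val_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(
        type='COCOCaption',
        data_root='data/coco',
        # Karpathy validation split instead of the raw train2014 annotations.
        ann_file='annotations/coco_karpathy_val.json',
        pipeline=test_pipeline,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)

val_evaluator = dict(
    type='COCOCaption',
    # Ground-truth file matching the Karpathy validation split above.
    ann_file='data/coco/annotations/coco_karpathy_val_gt.json')

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
```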
```diff
@@ -129,8 +129,9 @@ class Llava(BaseModel):
         mode: str = 'loss',
     ):
         """The unified entry for a forward process in both training and test.
-        The method should accept only one mode "loss":
 
+        - "predict": Forward and return the predictions, which are fully
+          processed to a list of :obj:`DataSample`.
         - "loss": Forward and return a dict of losses according to the given
           inputs and data samples.
 
```
```diff
@@ -150,10 +151,10 @@ class Llava(BaseModel):
             - If ``mode="loss"``, return a dict of tensor.
         """
 
-        if mode == 'loss':
-            return self.loss(images, data_samples)
-        elif mode == 'predict':
+        if mode == 'predict':
             return self.predict(images, data_samples)
+        elif mode == 'loss':
+            raise NotImplementedError
         else:
             raise RuntimeError(f'Invalid mode "{mode}".')
 
```
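The control flow now routes `mode='predict'` to `predict()` and explicitly rejects `mode='loss'` with `NotImplementedError` (only inference is wired up for this model). A minimal sketch of what callers see after this change; `model`, `images`, and `data_samples` are hypothetical stand-ins for an instantiated `Llava` model, a preprocessed image batch, and a list of `DataSample`:

```python
# Hypothetical inputs: model (Llava), images (torch.Tensor batch),
# data_samples (list of DataSample) prepared by the data preprocessor.
preds = model(images, data_samples, mode='predict')  # list of DataSample

try:
    model(images, data_samples, mode='loss')
except NotImplementedError:
    # Training is deliberately unsupported here; only prediction works.
    pass

# Any other mode string falls through to: RuntimeError('Invalid mode "..."')
```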
```diff
@@ -10,13 +10,15 @@ from ..flamingo.flamingo import ExtendModule, Flamingo, PerceiverResampler
 
 @MODELS.register_module()
 class Otter(Flamingo):
-    """The Open Flamingo model for multiple tasks.
+    """The Otter model for multiple tasks.
 
     Args:
         vision_encoder (dict): The config of the vision encoder.
         lang_encoder (dict): The config of the language encoder.
         tokenizer (dict): The tokenizer to encode the text.
         task (str): The task to perform prediction.
+        zeroshot_prompt (str): Prompt used for zero-shot inference.
+            Defaults to an empty string.
         shot_prompt_tmpl (str): Prompt used for few-shot inference.
             Defaults to '<image>User:Please describe the image.
             GPT:<answer>{caption}<|endofchunk|>'.
```
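The docstring now documents both prompt knobs. For intuition, a toy sketch of how a few-shot prompt could be assembled from the documented `shot_prompt_tmpl`; the `build_prompt` helper below is illustrative only and is not part of the model's actual code:

```python
# Illustrative only: fill the documented template once per support example,
# then append an unanswered query turn for the image to be described.
shot_prompt_tmpl = ('<image>User:Please describe the image. '
                    'GPT:<answer>{caption}<|endofchunk|>')

def build_prompt(support_captions, zeroshot_prompt=''):
    shots = ''.join(shot_prompt_tmpl.format(caption=c) for c in support_captions)
    query = '<image>User:Please describe the image. GPT:<answer>'
    return zeroshot_prompt + shots + query

print(build_prompt(['two cats sitting on a blanket']))
```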
```diff
@@ -69,7 +71,7 @@ class Otter(Flamingo):
 
         # init tokenizer
         self.tokenizer = TOKENIZER.build(tokenizer)
-        # add Flamingo special tokens to the tokenizer
+        # add Otter special tokens to the tokenizer
         self.tokenizer.add_special_tokens({
             'additional_special_tokens':
             ['<|endofchunk|>', '<image>', '<answer>']
```
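For reference, `add_special_tokens` registers the three Otter control tokens so each one tokenizes to a single id instead of being split into subwords. A standalone sketch with a plain HuggingFace tokenizer, assuming `transformers` is installed; the checkpoint name is only an example, since the real config builds its tokenizer via `TOKENIZER.build(tokenizer)`:

```python
from transformers import AutoTokenizer

# Example checkpoint, not necessarily the one used by the Otter config.
tokenizer = AutoTokenizer.from_pretrained('luodian/llama-7b-hf')
tokenizer.add_special_tokens({
    'additional_special_tokens': ['<|endofchunk|>', '<image>', '<answer>']
})

# Each control token now maps to exactly one vocabulary id.
print(tokenizer.convert_tokens_to_ids(['<|endofchunk|>', '<image>', '<answer>']))
```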