267 lines
6.0 KiB
Markdown
267 lines
6.0 KiB
Markdown
|
# Tutorial 3: Customize Models
|
||
|
|
||
|
### Add a new classifier
|
||
|
|
||
|
Here we show how to develop a new classifier with the following example:
|
||
|
|
||
|
#### 1. Define a new classifier
|
||
|
|
||
|
Create a new file `mmfewshot/classification/models/classifiers/my_classifier.py`.
|
||
|
|
||
|
```python
|
||
|
from mmcls.models.builder import CLASSIFIERS
|
||
|
from .base import BaseFewShotClassifier
|
||
|
|
||
|
@CLASSIFIERS.register_module()
|
||
|
class MyClassifier(BaseFewShotClassifier):
|
||
|
|
||
|
def __init__(self, arg1, arg2):
|
||
|
pass
|
||
|
|
||
|
# customize input for different mode
|
||
|
# the input should keep consistent with the dataset
|
||
|
def forward(self, img, mode='train', **kwargs):
|
||
|
if mode == 'train':
|
||
|
return self.forward_train(img=img, **kwargs)
|
||
|
elif mode == 'query':
|
||
|
return self.forward_query(img=img, **kwargs)
|
||
|
elif mode == 'support':
|
||
|
return self.forward_support(img=img, **kwargs)
|
||
|
elif mode == 'extract_feat':
|
||
|
assert img is not None
|
||
|
return self.extract_feat(img=img)
|
||
|
else:
|
||
|
raise ValueError()
|
||
|
|
||
|
# customize forward function for training data
|
||
|
def forward_train(self, img, gt_label, **kwargs):
|
||
|
pass
|
||
|
|
||
|
# customize forward function for meta testing support data
|
||
|
def forward_support(self, img, gt_label, **kwargs):
|
||
|
pass
|
||
|
|
||
|
# customize forward function for meta testing query data
|
||
|
def forward_query(self, img, **kwargs):
|
||
|
pass
|
||
|
|
||
|
# prepare meta testing
|
||
|
def before_meta_test(self, meta_test_cfg, **kwargs):
|
||
|
pass
|
||
|
|
||
|
# prepare for forwarding meta testing support images
|
||
|
def before_forward_support(self, **kwargs):
|
||
|
pass
|
||
|
|
||
|
# prepare for forwarding meta testing query images
|
||
|
def before_forward_query(self, **kwargs):
|
||
|
pass
|
||
|
```
|
||
|
|
||
|
#### 2. Import the module
|
||
|
|
||
|
You can either add the following line to `mmfewshot/classification/models/classifiers/__init__.py`
|
||
|
|
||
|
```python
|
||
|
from .my_classifier import MyClassifier
|
||
|
```
|
||
|
|
||
|
or alternatively add
|
||
|
|
||
|
```python
|
||
|
custom_imports = dict(
|
||
|
imports=['mmfewshot.classification.models.classifiers.my_classifier'],
|
||
|
allow_failed_imports=False)
|
||
|
```
|
||
|
|
||
|
to the config file to avoid modifying the original code.
|
||
|
|
||
|
#### 3. Use the classifier in your config file
|
||
|
|
||
|
```python
|
||
|
model = dict(
|
||
|
type="MyClassifier",
|
||
|
...
|
||
|
)
|
||
|
```
|
||
|
|
||
|
### Add a new backbone
|
||
|
|
||
|
Here we show how to develop a new backbone with the following example:
|
||
|
|
||
|
#### 1. Define a new backbone
|
||
|
|
||
|
Create a new file `mmfewshot/classification/models/backbones/mynet.py`.
|
||
|
|
||
|
```python
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from mmcls.models.builder import BACKBONES
|
||
|
|
||
|
@BACKBONES.register_module()
|
||
|
class MyNet(nn.Module):
|
||
|
|
||
|
def __init__(self, arg1, arg2):
|
||
|
pass
|
||
|
|
||
|
def forward(self, x): # should return a tensor
|
||
|
pass
|
||
|
```
|
||
|
|
||
|
#### 2. Import the module
|
||
|
|
||
|
You can either add the following line to `mmfewshot/classification/models/backbones/__init__.py`
|
||
|
|
||
|
```python
|
||
|
from .mynet import MyNet
|
||
|
```
|
||
|
|
||
|
or alternatively add
|
||
|
|
||
|
```python
|
||
|
custom_imports = dict(
|
||
|
imports=['mmfewshot.classification.models.backbones.mynet'],
|
||
|
allow_failed_imports=False)
|
||
|
```
|
||
|
|
||
|
to the config file to avoid modifying the original code.
|
||
|
|
||
|
#### 3. Use the backbone in your config file
|
||
|
|
||
|
```python
|
||
|
model = dict(
|
||
|
...
|
||
|
backbone=dict(
|
||
|
type='MyNet',
|
||
|
arg1=xxx,
|
||
|
arg2=xxx),
|
||
|
...
)
|
||
|
```
|
||
|
|
||
|
### Add new heads
|
||
|
|
||
|
Here we show how to develop a new head with the following example:
|
||
|
|
||
|
#### 1. Define a new head
|
||
|
|
||
|
Create a new file `mmfewshot/classification/models/heads/myhead.py`.
|
||
|
|
||
|
```python
|
||
|
from mmcls.models.builder import HEADS
|
||
|
from .base_head import BaseFewShotHead
|
||
|
|
||
|
@HEADS.register_module()
|
||
|
class MyHead(BaseFewShotHead):
|
||
|
|
||
|
def __init__(self, arg1, arg2) -> None:
|
||
|
pass
|
||
|
|
||
|
def forward_train(self, x, gt_label, **kwargs):
|
||
|
pass
|
||
|
|
||
|
def forward_support(self, x, gt_label, **kwargs):
|
||
|
pass
|
||
|
|
||
|
def forward_query(self, x, **kwargs):
|
||
|
pass
|
||
|
|
||
|
def before_forward_support(self) -> None:
|
||
|
pass
|
||
|
|
||
|
def before_forward_query(self) -> None:
|
||
|
pass
|
||
|
```
|
||
|
|
||
|
#### 2. Import the module
|
||
|
|
||
|
You can either add the following line to `mmfewshot/classification/models/heads/__init__.py`
|
||
|
|
||
|
```python
|
||
|
from .myhead import MyHead
|
||
|
```
|
||
|
|
||
|
or alternatively add
|
||
|
|
||
|
```python
|
||
|
custom_imports = dict(
|
||
|
imports=['mmfewshot.classification.models.heads.myhead'],
|
||
|
allow_failed_imports=False)
|
||
|
```
|
||
|
|
||
|
to the config file to avoid modifying the original code.
|
||
|
|
||
|
#### 3. Use the head in your config file
|
||
|
|
||
|
```python
|
||
|
model = dict(
|
||
|
...
|
||
|
head=dict(
|
||
|
type='MyHead',
|
||
|
arg1=xxx,
|
||
|
arg2=xxx),
|
||
|
...
)
|
||
|
```
|
||
|
|
||
|
### Add new loss
|
||
|
|
||
|
|
||
|
To add a new loss function, implement it in `mmfewshot/classification/models/losses/my_loss.py`.
|
||
|
The `weighted_loss` decorator enables the loss to be weighted for each element.
|
||
|
|
||
|
```python
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from ..builder import LOSSES
|
||
|
from .utils import weighted_loss
|
||
|
|
||
|
@weighted_loss
|
||
|
def my_loss(pred, target):
|
||
|
assert pred.size() == target.size() and target.numel() > 0
|
||
|
loss = torch.abs(pred - target)
|
||
|
return loss
|
||
|
|
||
|
@LOSSES.register_module()
|
||
|
class MyLoss(nn.Module):
|
||
|
|
||
|
def __init__(self, reduction='mean', loss_weight=1.0):
|
||
|
super(MyLoss, self).__init__()
|
||
|
self.reduction = reduction
|
||
|
self.loss_weight = loss_weight
|
||
|
|
||
|
def forward(self,
|
||
|
pred,
|
||
|
target,
|
||
|
weight=None,
|
||
|
avg_factor=None,
|
||
|
reduction_override=None):
|
||
|
assert reduction_override in (None, 'none', 'mean', 'sum')
|
||
|
reduction = (
|
||
|
reduction_override if reduction_override else self.reduction)
|
||
|
loss_bbox = self.loss_weight * my_loss(
|
||
|
pred, target, weight, reduction=reduction, avg_factor=avg_factor)
|
||
|
return loss_bbox
|
||
|
```
|
||
|
|
||
|
Then add it to `mmfewshot/classification/models/losses/__init__.py`.
|
||
|
|
||
|
```python
|
||
|
from .my_loss import MyLoss, my_loss
|
||
|
```
|
||
|
|
||
|
Alternatively, you can add
|
||
|
|
||
|
```python
|
||
|
custom_imports=dict(
|
||
|
imports=['mmfewshot.classification.models.losses.my_loss'])
|
||
|
```
|
||
|
|
||
|
to the config file to achieve the same goal.
|
||
|
|
||
|
To use it, set the corresponding `loss_xxx` field in the config to `MyLoss`.
|
||
|
For example, since `MyLoss` is a regression loss, it would replace the `loss_bbox` field of the head:
|
||
|
|
||
|
```python
|
||
|
loss_bbox=dict(type='MyLoss', loss_weight=1.0)
|
||
|
```
|