# Tutorial 3: Customize Models
### Add a new classifier
Here we show how to develop a new classifier using the following example.
#### 1. Define a new classifier
Create a new file `mmfewshot/classification/models/classifiers/my_classifier.py`.
```python
from mmcls.models.builder import CLASSIFIERS

from .base import BaseFewShotClassifier


@CLASSIFIERS.register_module()
class MyClassifier(BaseFewShotClassifier):

    def __init__(self, arg1, arg2):
        pass

    # customize input for different mode
    # the input should keep consistent with the dataset
    def forward(self, img, mode='train', **kwargs):
        if mode == 'train':
            return self.forward_train(img=img, **kwargs)
        elif mode == 'query':
            return self.forward_query(img=img, **kwargs)
        elif mode == 'support':
            return self.forward_support(img=img, **kwargs)
        elif mode == 'extract_feat':
            assert img is not None
            return self.extract_feat(img=img)
        else:
            raise ValueError(f'invalid mode: {mode}')

    # customize forward function for training data
    def forward_train(self, img, gt_label, **kwargs):
        pass

    # customize forward function for meta testing support data
    def forward_support(self, img, gt_label, **kwargs):
        pass

    # customize forward function for meta testing query data
    def forward_query(self, img):
        pass

    # prepare meta testing
    def before_meta_test(self, meta_test_cfg, **kwargs):
        pass

    # prepare forward meta testing support images
    def before_forward_support(self, **kwargs):
        pass

    # prepare forward meta testing query images
    def before_forward_query(self, **kwargs):
        pass
```
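The `mode` argument decides which branch of `forward` is taken, so training and the meta-testing pipeline can share one entry point. Below is a minimal usage sketch, assuming the skeleton above has been filled in; the tensor shapes and the `arg1`/`arg2` values are placeholders for illustration only.
```python
import torch

# hypothetical instantiation; arg1/arg2 stand in for real constructor arguments
model = MyClassifier(arg1=None, arg2=None)

img = torch.randn(4, 3, 84, 84)        # dummy image batch
gt_label = torch.randint(0, 5, (4,))   # dummy labels

losses = model(img, mode='train', gt_label=gt_label)  # training step
model(img, mode='support', gt_label=gt_label)          # feed meta-test support data
results = model(img, mode='query')                     # predict on meta-test query data
feats = model(img, mode='extract_feat')                # plain feature extraction
```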
#### 2. Import the module
You can either add the following line to `mmfewshot/classification/models/classifiers/__init__.py`
```python
from .my_classifier import MyClassifier
```
or alternatively add
```python
custom_imports = dict(
    imports=['mmfewshot.classification.models.classifiers.my_classifier'],
    allow_failed_imports=False)
```
to the config file to avoid modifying the original code.
#### 3. Use the classifier in your config file
```python
model = dict(
    type='MyClassifier',
    ...
)
```
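For reference, the same config dict can also be used to build the classifier programmatically through the registry it was registered in. This is only a sketch; `arg1`/`arg2` are placeholders for the real constructor arguments.
```python
from mmcls.models.builder import CLASSIFIERS

cfg = dict(type='MyClassifier', arg1=None, arg2=None)  # mirrors the config file entry
model = CLASSIFIERS.build(cfg)                          # instance of MyClassifier
```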
### Add a new backbone
Here we show how to develop a new backbone using the following example.
#### 1. Define a new backbone
Create a new file `mmfewshot/classification/models/backbones/mynet.py`.
```python
import torch.nn as nn

from mmcls.models.builder import BACKBONES


@BACKBONES.register_module()
class MyNet(nn.Module):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tensor
        pass
```
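For a more concrete picture, a minimal working backbone could look like the sketch below. `TinyNet`, its channel sizes, and its layer choices are made up for illustration; any `nn.Module` that returns a feature tensor will do.
```python
import torch.nn as nn
from mmcls.models.builder import BACKBONES


@BACKBONES.register_module()
class TinyNet(nn.Module):
    """Toy backbone: two conv blocks followed by global average pooling."""

    def __init__(self, in_channels=3, out_channels=64):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        # returns a (N, out_channels) feature tensor
        return self.pool(self.features(x)).flatten(1)
```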
#### 2. Import the module
You can either add the following line to `mmfewshot/classification/models/backbones/__init__.py`
```python
from .mynet import MyNet
```
or alternatively add
```python
custom_imports = dict(
    imports=['mmfewshot.classification.models.backbones.mynet'],
    allow_failed_imports=False)
```
to the config file to avoid modifying the original code.
#### 3. Use the backbone in your config file
```python
model = dict(
    ...
    backbone=dict(
        type='MyNet',
        arg1=xxx,
        arg2=xxx),
    ...
)
```
### Add new heads
Here we show how to develop a new head using the following example.
#### 1. Define a new head
Create a new file `mmfewshot/classification/models/heads/myhead.py`.
```python
from mmcls.models.builder import HEADS

from .base_head import BaseFewShotHead


@HEADS.register_module()
class MyHead(BaseFewShotHead):

    def __init__(self, arg1, arg2) -> None:
        pass

    # customize forward function for training data
    def forward_train(self, x, gt_label, **kwargs):
        pass

    # customize forward function for meta testing support data
    def forward_support(self, x, gt_label, **kwargs):
        pass

    # customize forward function for meta testing query data
    def forward_query(self, x, **kwargs):
        pass

    # prepare forward meta testing support data
    def before_forward_support(self) -> None:
        pass

    # prepare forward meta testing query data
    def before_forward_query(self) -> None:
        pass
```
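As a rough illustration of how the meta-testing hooks fit together, here is a sketch of a nearest-prototype head. It is only a sketch under assumptions: `ProtoHead` is a made-up name, `BaseFewShotHead` is assumed to be constructible without extra arguments, and the built-in heads in mmfewshot post-process query predictions differently, so consult them before copying this.
```python
import torch
import torch.nn.functional as F
from mmcls.models.builder import HEADS

from .base_head import BaseFewShotHead


@HEADS.register_module()
class ProtoHead(BaseFewShotHead):
    """Toy head: class prototypes are the mean of the support features."""

    def __init__(self):
        super().__init__()  # assumes the base class has usable defaults
        self.support_feats, self.support_labels = [], []
        self.prototypes = None

    def forward_train(self, x, gt_label, **kwargs):
        # a real head would compute and return a dict of training losses here
        raise NotImplementedError

    def forward_support(self, x, gt_label, **kwargs):
        # accumulate support features and labels of the current episode
        self.support_feats.append(x)
        self.support_labels.append(gt_label)

    def forward_query(self, x, **kwargs):
        # score query features by (negative) distance to each class prototype
        return F.softmax(-torch.cdist(x, self.prototypes), dim=1)

    def before_forward_support(self) -> None:
        # reset the episode state before new support data arrives
        self.support_feats, self.support_labels = [], []
        self.prototypes = None

    def before_forward_query(self) -> None:
        # build one prototype per class from the accumulated support features
        feats = torch.cat(self.support_feats)
        labels = torch.cat(self.support_labels)
        self.prototypes = torch.stack(
            [feats[labels == c].mean(0) for c in torch.unique(labels)])
```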
#### 2. Import the module
You can either add the following line to `mmfewshot/classification/models/heads/__init__.py`
```python
from .myhead import MyHead
```
or alternatively add
```python
custom_imports = dict(
    imports=['mmfewshot.classification.models.heads.myhead'],
    allow_failed_imports=False)
```
to the config file to avoid modifying the original code.
#### 3. Use the head in your config file
```python
model = dict(
    ...
    head=dict(
        type='MyHead',
        arg1=xxx,
        arg2=xxx),
    ...
)
```
### Add new loss
To add a new loss function, the users need to implement it in `mmfewshot/classification/models/losses/my_loss.py`.
The decorator `weighted_loss` enables the loss to be weighted for each element.
```python
import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss


@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss


@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox
```
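A quick sanity check on dummy tensors (the shapes are arbitrary), exercising both the plain call and the element-wise weighted variant:
```python
import torch

loss_fn = MyLoss(reduction='mean', loss_weight=1.0)

pred = torch.rand(4, 10)
target = torch.rand(4, 10)
weight = torch.ones(4, 10)

print(loss_fn(pred, target))          # mean L1 loss over all elements
print(loss_fn(pred, target, weight))  # same loss with element-wise weights
```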
Then the users need to add it to `mmfewshot/classification/models/losses/__init__.py`.
```python
from .my_loss import MyLoss, my_loss
```
Alternatively, you can add
```python
custom_imports = dict(
    imports=['mmfewshot.classification.models.losses.my_loss'])
```
to the config file and achieve the same goal.
To use it, modify the `loss_xxx` field.
Since MyLoss is for regression, you need to modify the `loss_bbox` field in the head.
```python
loss_bbox=dict(type='MyLoss', loss_weight=1.0))
```