Update MAML (#36)
* fix init
* fix test api
* fix test api bug
* add metarcnn fsdetview config
* update maml
* update config
* fix comments
parent 6c2144ec83
commit 7451fb7425
@@ -263,12 +263,9 @@ def test_single_task(model, support_dataloader, query_dataloader,
                 outputs['loss'].backward()
                 optimizer.step()
     else:  # methods without fine-tune stage
-        model.eval()
-        with torch.no_grad():
-            for i, data in enumerate(support_dataloader):
-                data['gt_label'] = label_wrapper(data['gt_label'],
-                                                 task_class_ids)
-                model.forward(**data, mode='support')
+        for i, data in enumerate(support_dataloader):
+            data['gt_label'] = label_wrapper(data['gt_label'], task_class_ids)
+            model.forward(**data, mode='support')
 
     model.before_forward_query()
     results_list, gt_label_list = [], []
@@ -59,7 +59,7 @@ def train_model(model,
             round_up=True,
             seed=cfg.seed,
             pin_memory=cfg.get('pin_memory', False),
-            infinite_sampler=cfg.infinite_sampler) for ds in dataset
+            use_infinite_sampler=cfg.use_infinite_sampler) for ds in dataset
     ]
 
     # put model on gpus
@@ -95,7 +95,7 @@ def train_model(model,
    else:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs
-    if cfg.infinite_sampler and cfg.runner['type'] == 'EpochBasedRunner':
+    if cfg.use_infinite_sampler and cfg.runner['type'] == 'EpochBasedRunner':
        cfg.runner['type'] = 'InfiniteEpochBasedRunner'
    runner = build_runner(
        cfg.runner,
@@ -51,7 +51,7 @@ def build_dataloader(dataset,
                      round_up=True,
                      seed=None,
                      pin_memory=False,
-                     infinite_sampler=False,
+                     use_infinite_sampler=False,
                      **kwargs):
     """Build PyTorch DataLoader.
 
@@ -73,10 +73,10 @@ def build_dataloader(dataset,
         seed (int | None): Random seed. Default:None.
         pin_memory (bool): Whether to use pin_memory for dataloader.
             Default: False.
-        infinite_sampler (bool): Whether to use infinite sampler. Noted that
-            infinite sampler will keep iterator of dataloader running
-            forever, which can avoid the overhead of worker initialization
-            between epochs. Default: False.
+        use_infinite_sampler (bool): Whether to use infinite sampler.
+            Noted that infinite sampler will keep iterator of dataloader
+            running forever, which can avoid the overhead of worker
+            initialization between epochs. Default: False.
         kwargs: any keyword argument to be used to initialize DataLoader
 
     Returns:
@@ -84,7 +84,7 @@ def build_dataloader(dataset,
     """
     rank, world_size = get_dist_info()
     if dist:
-        if infinite_sampler:
+        if use_infinite_sampler:
            sampler = DistributedInfiniteSampler(
                dataset, world_size, rank, shuffle=shuffle)
        else:
@@ -95,7 +95,7 @@ def build_dataloader(dataset,
            num_workers = workers_per_gpu
    else:
        sampler = InfiniteSampler(dataset, seed=seed, shuffle=shuffle) \
-            if infinite_sampler else None
+            if use_infinite_sampler else None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu
 
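For reference, a minimal usage sketch of build_dataloader with the renamed flag follows. Only round_up, seed, pin_memory and use_infinite_sampler appear in this diff; the import path, the toy dataset, and the per-GPU argument names are assumed from the usual mmcv-style builder and are illustrative only.

# Hedged sketch: argument names not shown in the diff are assumptions; a real
# setup would pass a dataset built from the repo's dataset configs instead of
# the toy TensorDataset used here.
import torch
from torch.utils.data import TensorDataset

from mmfewshot.classification.datasets import build_dataloader  # assumed import path

toy_dataset = TensorDataset(
    torch.randn(100, 3, 32, 32), torch.randint(0, 5, (100, )))
data_loader = build_dataloader(
    toy_dataset,
    samples_per_gpu=4,          # assumed name (mmcv convention)
    workers_per_gpu=2,          # assumed name (mmcv convention)
    dist=False,                 # assumed name (mmcv convention)
    round_up=True,
    seed=0,
    pin_memory=False,
    use_infinite_sampler=True)  # keep the iterator alive across epochs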
@@ -8,11 +8,11 @@ class ConvBlock(nn.Module):
         super(ConvBlock, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=padding)
-        self.bn = nn.BatchNorm2d(out_channels)
-        self.relu = nn.ReLU(inplace=True)
-
-        layers = [self.conv, self.bn, self.relu]
+        layers = [
+            nn.Conv2d(in_channels, out_channels, 3, padding=padding),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        ]
         if is_pooling:
             layers.append(nn.MaxPool2d(2))
         self.layers = nn.Sequential(*layers)
@@ -1,8 +1,9 @@
+import numpy as np
 import torch
 from mmcls.models.builder import CLASSIFIERS
 
 from mmfewshot.classification.datasets import label_wrapper
-from mmfewshot.classification.models.utils import clone_module, update_module
+from mmfewshot.classification.models.utils import convert_maml_module
 from .base import FewShotBaseClassifier
 
 
@@ -26,6 +27,7 @@ class MAMLClassifier(FewShotBaseClassifier):
         self.num_inner_steps = num_inner_steps
         self.inner_lr = inner_lr
         self.first_order = first_order
+        convert_maml_module(self)
 
     def forward(self,
                 img=None,
@@ -77,8 +79,8 @@ class MAMLClassifier(FewShotBaseClassifier):
         This method defines an iteration step during training, except for the
         back propagation and optimizer updating, which are done in an optimizer
         hook. Note that in some complicated cases or models, the whole process
-        including back propagation and optimizer updating are also defined in
-        this method, such as GAN.
+        including back propagation and optimizer updating are also defined
+        in this method, such as GAN.
 
         Args:
             data (dict): The output of dataloader.
@@ -138,33 +140,17 @@ class MAMLClassifier(FewShotBaseClassifier):
         Returns:
             dict[str, Tensor]: a dictionary of loss components
         """
 
-        support_img = support_data['img']
+        support_img, query_img = support_data['img'], query_data['img']
         class_ids = torch.unique(support_data['gt_label']).cpu().tolist()
+        np.random.shuffle(class_ids)
         support_label = label_wrapper(support_data['gt_label'], class_ids)
-        clone_backbone = clone_module(self.backbone)
-        clone_head = clone_module(self.head)
-        second_order = not self.first_order
-        for step in range(self.num_inner_steps):
-            feats = clone_backbone(support_img)
-            inner_loss = clone_head.forward_train(feats, support_label)['loss']
-            parameters = list(clone_backbone.parameters()) + list(
-                clone_head.parameters())
-            grads = torch.autograd.grad(
-                inner_loss,
-                parameters,
-                retain_graph=second_order,
-                create_graph=second_order)
-            for parameter, grad in zip(parameters, grads):
-                if grad is not None:
-                    parameter.update = -self.inner_lr * grad
-            update_module(clone_backbone)
-            update_module(clone_head)
-
-        query_img = query_data['img']
         query_label = label_wrapper(query_data['gt_label'], class_ids)
-        feats = clone_backbone(query_img)
-        loss = clone_head.forward_train(feats, query_label)
+        self.fast_adapt(self.num_inner_steps, support_img, support_label)
+        query_feats = self.extract_feat(query_img)
+        loss = self.head.forward_train(query_feats, query_label)
+        for weight in self.parameters():
+            weight.fast = None
         return loss
 
     def forward_support(self, img, gt_label, **kwargs):
@@ -178,8 +164,8 @@ class MAMLClassifier(FewShotBaseClassifier):
         Returns:
             dict[str, Tensor]: A dictionary of loss components
         """
-        x = self.extract_feat(img)
-        return self.head.forward_support(x, gt_label)
+        self.fast_adapt(self.meta_test_cfg['support']['num_inner_steps'], img,
+                        gt_label)
 
     def forward_query(self, img, **kwargs):
         """Forward query data in meta testing.
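forward_support above pulls the inner-step count from self.meta_test_cfg; a hypothetical config fragment with that key might look like the sketch below. Every key except support.num_inner_steps is illustrative and not taken from this diff.

# Hypothetical meta-test config fragment; only support.num_inner_steps is
# implied by forward_support above, the remaining keys are illustrative.
meta_test_cfg = dict(
    support=dict(num_inner_steps=10),  # inner-loop adaptation steps per task
    query=dict())                      # illustrative placeholder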
@@ -194,12 +180,39 @@ class MAMLClassifier(FewShotBaseClassifier):
         x = self.extract_feat(img)
         return self.head.forward_query(x)
 
+    def fast_adapt(self, num_steps, img, labels):
+        """Forward and update fast weight with input images and labels.
+
+        Args:
+            num_steps (int): The number of fast forward and update steps.
+            img (Tensor): With shape (N, C, H, W).
+            labels (Tensor): With shape (N).
+        """
+        fast_parameters = list(self.parameters())
+        for weight in self.parameters():
+            weight.fast = None
+        for step in range(num_steps):
+            feats = self.extract_feat(img)
+            inner_loss = self.head.forward_train(feats, labels)['loss']
+            grads = torch.autograd.grad(
+                inner_loss, fast_parameters, create_graph=True)
+            fast_parameters = []
+            if self.first_order:
+                grads = [g.detach() for g in grads]
+            for k, weight in enumerate(list(self.parameters())):
+                if weight.fast is None:
+                    weight.fast = weight - self.inner_lr * grads[k]
+                else:
+                    weight.fast = weight.fast - self.inner_lr * grads[k]
+                fast_parameters.append(weight.fast)
+
     def before_meta_test(self, meta_test_cfg, **kwargs):
         """Used in meta testing.
 
         This function will be called before the meta testing.
         """
         self.meta_test_cfg = meta_test_cfg
+        self.zero_grad()
 
     def before_forward_support(self, **kwargs):
         """Used in meta testing.
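The block below is a minimal, self-contained sketch of what one fast_adapt step amounts to, using convert_maml_module from this diff. The toy model, tensor shapes and learning rate are made up for illustration, and the example intentionally skips BatchNorm layers so it runs on CPU.

# Hedged sketch of one inner-loop adaptation step with fast weights.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmfewshot.classification.models.utils import convert_maml_module

# Toy backbone + head; real models come from the repo's configs.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 5))
model = convert_maml_module(model)  # swaps in the *WithFastWeight layers

support_img = torch.randn(4, 3, 32, 32)
support_label = torch.randint(0, 5, (4, ))
inner_lr = 0.01  # illustrative value

# One step of what fast_adapt does: compute grads w.r.t. the slow weights
# and store the adapted copies in `.fast` without touching the originals.
loss = F.cross_entropy(model(support_img), support_label)
grads = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
for weight, grad in zip(model.parameters(), grads):
    weight.fast = weight - inner_lr * grad

query_logits = model(support_img)  # this forward now uses the fast weights

# Reset, as forward_train does after computing the query loss.
for weight in model.parameters():
    weight.fast = None

Keeping the adapted copy in weight.fast while leaving the original parameter untouched is what lets the outer loop differentiate the query loss through the inner update without cloning the whole module, which is the clone_module approach this diff removes.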
@@ -207,6 +220,8 @@ class MAMLClassifier(FewShotBaseClassifier):
         This function will be called before model forward support data during
         meta testing.
         """
+        for weight in self.parameters():
+            weight.fast = None
         self.backbone.train()
         self.head.train()
 
@@ -1,3 +1,3 @@
-from .module_utils import clone_module, update_module
+from .maml_module import convert_maml_module
 
-__all__ = ['clone_module', 'update_module']
+__all__ = ['convert_maml_module']
@@ -0,0 +1,148 @@
+"""Modified from https://github.com/wyharveychen/CloserLookFewShot and
+https://github.com/RL-VIG/LibFewShot.
+
+This file is only used in method maml for fast adaptation.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LinearWithFastWeight(nn.Linear):
+
+    def __init__(self, in_features, out_features, bias=True):
+        super(LinearWithFastWeight, self).__init__(in_features, out_features)
+        # Lazy hack to add fast weight link
+        self.weight.fast = None
+        self.bias.fast = None
+
+    def forward(self, x):
+        if self.weight.fast is not None and self.bias.fast is not None:
+            out = F.linear(x, self.weight.fast, self.bias.fast)
+        else:
+            out = super(LinearWithFastWeight, self).forward(x)
+        return out
+
+
+class Conv2dWithFastWeight(nn.Conv2d):
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        bias=True,
+    ):
+        super(Conv2dWithFastWeight, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            bias=bias,
+        )
+        # Lazy hack to add fast weight link
+        self.weight.fast = None
+        if self.bias is not None:
+            self.bias.fast = None
+
+    def forward(self, x):
+        if self.bias is None:
+            if self.weight.fast is not None:
+                out = F.conv2d(
+                    x,
+                    self.weight.fast,
+                    None,
+                    stride=self.stride,
+                    padding=self.padding,
+                )
+            else:
+                out = super(Conv2dWithFastWeight, self).forward(x)
+        else:
+            if self.weight.fast is not None and self.bias.fast is not None:
+                out = F.conv2d(
+                    x,
+                    self.weight.fast,
+                    self.bias.fast,
+                    stride=self.stride,
+                    padding=self.padding,
+                )
+            else:
+                out = super(Conv2dWithFastWeight, self).forward(x)
+
+        return out
+
+
+class BatchNorm2dWithFastWeight(nn.BatchNorm2d):
+
+    def __init__(self, num_features):
+        super(BatchNorm2dWithFastWeight, self).__init__(num_features)
+        # Lazy hack to add fast weight link
+        self.weight.fast = None
+        self.bias.fast = None
+
+    def forward(self, x):
+        # batch_norm momentum hack: follow hack of Kate
+        # Rakelly in pytorch-maml/src/layers.py
+        running_mean = torch.zeros(x.data.size()[1]).cuda()
+        running_var = torch.ones(x.data.size()[1]).cuda()
+        if self.weight.fast is not None and self.bias.fast is not None:
+            out = F.batch_norm(
+                x,
+                running_mean,
+                running_var,
+                self.weight.fast,
+                self.bias.fast,
+                training=True,
+                momentum=1,
+            )
+        else:
+            out = F.batch_norm(
+                x,
+                running_mean,
+                running_var,
+                self.weight,
+                self.bias,
+                training=True,
+                momentum=1,
+            )
+        return out
+
+
+def convert_maml_module(module):
+    """Convert a normal model to MAML model.
+
+    Replace nn.Linear with LinearWithFastWeight, nn.Conv2d with
+    Conv2dWithFastWeight and BatchNorm2d with BatchNorm2dWithFastWeight.
+
+    Args:
+        module(nn.Module): The module to be converted.
+
+    Returns :
+        nn.Module: A MAML module.
+    """
+    converted_module = module
+    if isinstance(module, torch.nn.modules.Linear):
+        converted_module = LinearWithFastWeight(
+            module.in_features,
+            module.out_features,
+            False if module.bias is None else True,
+        )
+    elif isinstance(module, torch.nn.modules.Conv2d):
+        converted_module = Conv2dWithFastWeight(
+            module.in_channels,
+            module.out_channels,
+            module.kernel_size,
+            module.stride,
+            module.padding,
+            False if module.bias is None else True,
+        )
+    elif isinstance(module, torch.nn.modules.batchnorm.BatchNorm2d):
+        converted_module = BatchNorm2dWithFastWeight(module.num_features)
+    for name, child in module.named_children():
+        converted_module.add_module(name, convert_maml_module(child))
+    del module
+    return converted_module
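A quick, hedged check of what convert_maml_module returns; the toy module is for type inspection only, not a meaningful forward pass.

# Hedged sketch: verify the layer types and the fast-weight hooks after conversion.
import torch.nn as nn

from mmfewshot.classification.models.utils import convert_maml_module

net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.Linear(16, 4))
net = convert_maml_module(net)

print([type(m).__name__ for m in net])
# expected: ['Conv2dWithFastWeight', 'BatchNorm2dWithFastWeight', 'LinearWithFastWeight']
print(all(p.fast is None for p in net.parameters()))  # fast weights start unset -> True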
@@ -1,94 +0,0 @@
-import torch
-
-# Used in MAML
-# modified from https://github.com/learnables/learn2learn
-
-
-def clone_module(module, memo=None):
-    if memo is None:
-        memo = {}
-    if not isinstance(module, torch.nn.Module):
-        return module
-    clone = module.__new__(type(module))
-    clone.__dict__ = module.__dict__.copy()
-    clone._parameters = clone._parameters.copy()
-    clone._buffers = clone._buffers.copy()
-    clone._modules = clone._modules.copy()
-
-    # Second, re-write all parameters
-    if hasattr(clone, '_parameters'):
-        for param_key in module._parameters:
-            if module._parameters[param_key] is not None:
-                param = module._parameters[param_key]
-                param_ptr = param.data_ptr
-                if param_ptr in memo:
-                    clone._parameters[param_key] = memo[param_ptr]
-                else:
-                    cloned = param.clone()
-                    clone._parameters[param_key] = cloned
-                    memo[param_ptr] = cloned
-
-    # Third, handle the buffers if necessary
-    if hasattr(clone, '_buffers'):
-        for buffer_key in module._buffers:
-            if clone._buffers[buffer_key] is not None and \
-                    clone._buffers[buffer_key].requires_grad:
-                buff = module._buffers[buffer_key]
-                buff_ptr = buff.data_ptr
-                if buff_ptr in memo:
-                    clone._buffers[buffer_key] = memo[buff_ptr]
-                else:
-                    cloned = buff.clone()
-                    clone._buffers[buffer_key] = cloned
-                    memo[param_ptr] = cloned
-
-    # Then, recurse for each submodule
-    if hasattr(clone, '_modules'):
-        for module_key in clone._modules:
-            clone._modules[module_key] = clone_module(
-                module._modules[module_key],
-                memo=memo,
-            )
-
-    if hasattr(clone, 'flatten_parameters'):
-        clone = clone._apply(lambda x: x)
-    return clone
-
-
-def update_module(module, memo=None):
-    if memo is None:
-        memo = {}
-
-    # Update the params
-    for param_key in module._parameters:
-        p = module._parameters[param_key]
-        if p is not None and hasattr(p, 'update') and p.update is not None:
-            if p in memo:
-                module._parameters[param_key] = memo[p]
-            else:
-                updated = p + p.update
-                memo[p] = updated
-                module._parameters[param_key] = updated
-
-    # Second, handle the buffers if necessary
-    for buffer_key in module._buffers:
-        buff = module._buffers[buffer_key]
-        if buff is not None and hasattr(buff, 'update') \
-                and buff.update is not None:
-            if buff in memo:
-                module._buffers[buffer_key] = memo[buff]
-            else:
-                updated = buff + buff.update
-                memo[buff] = updated
-                module._buffers[buffer_key] = updated
-
-    # Then, recurse for each submodule
-    for module_key in module._modules:
-        module._modules[module_key] = update_module(
-            module._modules[module_key],
-            memo=memo,
-        )
-
-    if hasattr(module, 'flatten_parameters'):
-        module._apply(lambda x: x)
-    return module