Update MAML (#36)

* fix init

* fix test api

fix test api bug

* add metarcnn fsdetview config

* update maml

* update config

* fix comments
Linyiqi 2021-10-11 18:41:17 +08:00 committed by GitHub
parent 6c2144ec83
commit 7451fb7425
8 changed files with 211 additions and 145 deletions


@@ -263,12 +263,9 @@ def test_single_task(model, support_dataloader, query_dataloader,
outputs['loss'].backward()
optimizer.step()
else: # methods without fine-tune stage
model.eval()
with torch.no_grad():
for i, data in enumerate(support_dataloader):
data['gt_label'] = label_wrapper(data['gt_label'],
task_class_ids)
model.forward(**data, mode='support')
for i, data in enumerate(support_dataloader):
data['gt_label'] = label_wrapper(data['gt_label'], task_class_ids)
model.forward(**data, mode='support')
model.before_forward_query()
results_list, gt_label_list = [], []


@@ -59,7 +59,7 @@ def train_model(model,
round_up=True,
seed=cfg.seed,
pin_memory=cfg.get('pin_memory', False),
infinite_sampler=cfg.infinite_sampler) for ds in dataset
use_infinite_sampler=cfg.use_infinite_sampler) for ds in dataset
]
# put model on gpus
@@ -95,7 +95,7 @@ def train_model(model,
else:
if 'total_epochs' in cfg:
assert cfg.total_epochs == cfg.runner.max_epochs
if cfg.infinite_sampler and cfg.runner['type'] == 'EpochBasedRunner':
if cfg.use_infinite_sampler and cfg.runner['type'] == 'EpochBasedRunner':
cfg.runner['type'] = 'InfiniteEpochBasedRunner'
runner = build_runner(
cfg.runner,
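
A hedged sketch of how the renamed flag is consumed by train_model in the hunk above. Only `use_infinite_sampler` and the EpochBasedRunner -> InfiniteEpochBasedRunner switch come from this diff; the concrete config values are illustrative assumptions.

# Illustrative config fragment (values assumed, not part of this commit).
use_infinite_sampler = True
runner = dict(type='EpochBasedRunner', max_epochs=200)
# When cfg.use_infinite_sampler is True and the runner type is 'EpochBasedRunner',
# train_model rewrites the type to 'InfiniteEpochBasedRunner', so training keeps one
# persistent dataloader iterator instead of re-creating workers every epoch.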


@@ -51,7 +51,7 @@ def build_dataloader(dataset,
round_up=True,
seed=None,
pin_memory=False,
infinite_sampler=False,
use_infinite_sampler=False,
**kwargs):
"""Build PyTorch DataLoader.
@@ -73,10 +73,10 @@ def build_dataloader(dataset,
seed (int | None): Random seed. Default:None.
pin_memory (bool): Whether to use pin_memory for dataloader.
Default: False.
infinite_sampler (bool): Whether to use infinite sampler. Noted that
infinite sampler will keep iterator of dataloader running
forever, which can avoid the overhead of worker initialization
between epochs. Default: False.
use_infinite_sampler (bool): Whether to use the infinite sampler.
Note that the infinite sampler keeps the dataloader iterator
running forever, which avoids the overhead of worker
initialization between epochs. Default: False.
kwargs: any keyword argument to be used to initialize DataLoader
Returns:
@@ -84,7 +84,7 @@ def build_dataloader(dataset,
"""
rank, world_size = get_dist_info()
if dist:
if infinite_sampler:
if use_infinite_sampler:
sampler = DistributedInfiniteSampler(
dataset, world_size, rank, shuffle=shuffle)
else:
@@ -95,7 +95,7 @@ def build_dataloader(dataset,
num_workers = workers_per_gpu
else:
sampler = InfiniteSampler(dataset, seed=seed, shuffle=shuffle) \
if infinite_sampler else None
if use_infinite_sampler else None
batch_size = num_gpus * samples_per_gpu
num_workers = num_gpus * workers_per_gpu
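
A short usage sketch of the renamed argument. `use_infinite_sampler`, `dist`, `shuffle`, `seed`, `pin_memory`, `samples_per_gpu` and `workers_per_gpu` all appear in this hunk; the import path, the toy dataset and the concrete values are assumptions.

# Hedged usage sketch for build_dataloader with the renamed flag.
import torch
from torch.utils.data import TensorDataset
from mmfewshot.classification.datasets import build_dataloader  # import path assumed

# Toy dataset stand-in; the repo normally passes a dataset built by its own builders.
dataset = TensorDataset(torch.randn(8, 3, 32, 32), torch.randint(0, 5, (8,)))

loader = build_dataloader(
    dataset,
    samples_per_gpu=4,
    workers_per_gpu=2,
    dist=False,
    shuffle=True,
    seed=0,
    pin_memory=False,
    use_infinite_sampler=True)  # previously `infinite_sampler`
# With use_infinite_sampler=True a (Distributed)InfiniteSampler is built, so the
# dataloader iterator never ends and workers are not re-created between epochs.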


@@ -8,11 +8,11 @@ class ConvBlock(nn.Module):
super(ConvBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=padding)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
layers = [self.conv, self.bn, self.relu]
layers = [
nn.Conv2d(in_channels, out_channels, 3, padding=padding),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
]
if is_pooling:
layers.append(nn.MaxPool2d(2))
self.layers = nn.Sequential(*layers)


@@ -1,8 +1,9 @@
import numpy as np
import torch
from mmcls.models.builder import CLASSIFIERS
from mmfewshot.classification.datasets import label_wrapper
from mmfewshot.classification.models.utils import clone_module, update_module
from mmfewshot.classification.models.utils import convert_maml_module
from .base import FewShotBaseClassifier
@@ -26,6 +27,7 @@ class MAMLClassifier(FewShotBaseClassifier):
self.num_inner_steps = num_inner_steps
self.inner_lr = inner_lr
self.first_order = first_order
convert_maml_module(self)
def forward(self,
img=None,
@@ -77,8 +79,8 @@ class MAMLClassifier(FewShotBaseClassifier):
This method defines an iteration step during training, except for the
back propagation and optimizer updating, which are done in an optimizer
hook. Note that in some complicated cases or models, the whole process
including back propagation and optimizer updating are also defined in
this method, such as GAN.
including back propagation and optimizer updating are also defined
in this method, such as GAN.
Args:
data (dict): The output of dataloader.
@@ -138,33 +140,17 @@ class MAMLClassifier(FewShotBaseClassifier):
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
support_img = support_data['img']
support_img, query_img = support_data['img'], query_data['img']
class_ids = torch.unique(support_data['gt_label']).cpu().tolist()
np.random.shuffle(class_ids)
support_label = label_wrapper(support_data['gt_label'], class_ids)
clone_backbone = clone_module(self.backbone)
clone_head = clone_module(self.head)
second_order = not self.first_order
for step in range(self.num_inner_steps):
feats = clone_backbone(support_img)
inner_loss = clone_head.forward_train(feats, support_label)['loss']
parameters = list(clone_backbone.parameters()) + list(
clone_head.parameters())
grads = torch.autograd.grad(
inner_loss,
parameters,
retain_graph=second_order,
create_graph=second_order)
for parameter, grad in zip(parameters, grads):
if grad is not None:
parameter.update = -self.inner_lr * grad
update_module(clone_backbone)
update_module(clone_head)
query_img = query_data['img']
query_label = label_wrapper(query_data['gt_label'], class_ids)
feats = clone_backbone(query_img)
loss = clone_head.forward_train(feats, query_label)
self.fast_adapt(self.num_inner_steps, support_img, support_label)
query_feats = self.extract_feat(query_img)
loss = self.head.forward_train(query_feats, query_label)
for weight in self.parameters():
weight.fast = None
return loss
def forward_support(self, img, gt_label, **kwargs):
@@ -178,8 +164,8 @@ class MAMLClassifier(FewShotBaseClassifier):
Returns:
dict[str, Tensor]: A dictionary of loss components
"""
x = self.extract_feat(img)
return self.head.forward_support(x, gt_label)
self.fast_adapt(self.meta_test_cfg['support']['num_inner_steps'], img,
gt_label)
def forward_query(self, img, **kwargs):
"""Forward query data in meta testing.
@@ -194,12 +180,39 @@ class MAMLClassifier(FewShotBaseClassifier):
x = self.extract_feat(img)
return self.head.forward_query(x)
def fast_adapt(self, num_steps, img, labels):
"""Forward and update fast weight with input images and labels.
Args:
num_steps (int): The number of fast forward and update steps.
img (Tensor): With shape (N, C, H, W).
labels (Tensor): With shape (N).
"""
fast_parameters = list(self.parameters())
for weight in self.parameters():
weight.fast = None
for step in range(num_steps):
feats = self.extract_feat(img)
inner_loss = self.head.forward_train(feats, labels)['loss']
grads = torch.autograd.grad(
inner_loss, fast_parameters, create_graph=True)
fast_parameters = []
if self.first_order:
grads = [g.detach() for g in grads]
for k, weight in enumerate(list(self.parameters())):
if weight.fast is None:
weight.fast = weight - self.inner_lr * grads[k]
else:
weight.fast = weight.fast - self.inner_lr * grads[k]
fast_parameters.append(weight.fast)
def before_meta_test(self, meta_test_cfg, **kwargs):
"""Used in meta testing.
This function will be called before the meta testing.
"""
self.meta_test_cfg = meta_test_cfg
self.zero_grad()
def before_forward_support(self, **kwargs):
"""Used in meta testing.
@@ -207,6 +220,8 @@ class MAMLClassifier(FewShotBaseClassifier):
This function will be called before model forward support data during
meta testing.
"""
for weight in self.parameters():
weight.fast = None
self.backbone.train()
self.head.train()
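
A minimal, self-contained sketch of the fast-weight update that fast_adapt implements, written with plain tensors instead of the classifier API; every name below is illustrative, and only the update rule and the create_graph / first-order handling mirror the code above.

import torch
import torch.nn.functional as F

# Toy slow (meta) parameters of a single linear layer.
w = torch.randn(2, 3, requires_grad=True)
b = torch.zeros(2, requires_grad=True)
inner_lr, num_inner_steps, first_order = 0.01, 2, False

support_x, support_y = torch.randn(5, 3), torch.randint(0, 2, (5,))
query_x, query_y = torch.randn(5, 3), torch.randint(0, 2, (5,))

# Inner loop: gradients are taken w.r.t. the current fast weights with
# create_graph=True so the outer update can backpropagate through the
# adaptation steps; first-order MAML simply detaches them.
fast_w, fast_b = w, b
for _ in range(num_inner_steps):
    inner_loss = F.cross_entropy(F.linear(support_x, fast_w, fast_b), support_y)
    gw, gb = torch.autograd.grad(inner_loss, (fast_w, fast_b), create_graph=True)
    if first_order:
        gw, gb = gw.detach(), gb.detach()
    fast_w = fast_w - inner_lr * gw   # functional update, graph is kept
    fast_b = fast_b - inner_lr * gb

# Outer step: the query loss is computed with the adapted fast weights and its
# gradient flows back to the slow weights w and b.
outer_loss = F.cross_entropy(F.linear(query_x, fast_w, fast_b), query_y)
outer_loss.backward()   # w.grad and b.grad now hold the meta-gradient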


@@ -1,3 +1,3 @@
from .module_utils import clone_module, update_module
from .maml_module import convert_maml_module
__all__ = ['clone_module', 'update_module']
__all__ = ['convert_maml_module']


@@ -0,0 +1,148 @@
"""Modified from https://github.com/wyharveychen/CloserLookFewShot and
https://github.com/RL-VIG/LibFewShot.
This file is only used by the MAML method for fast adaptation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class LinearWithFastWeight(nn.Linear):
def __init__(self, in_features, out_features, bias=True):
super(LinearWithFastWeight, self).__init__(in_features, out_features)
# Lazy hack to add fast weight link
self.weight.fast = None
self.bias.fast = None
def forward(self, x):
if self.weight.fast is not None and self.bias.fast is not None:
out = F.linear(x, self.weight.fast, self.bias.fast)
else:
out = super(LinearWithFastWeight, self).forward(x)
return out
class Conv2dWithFastWeight(nn.Conv2d):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
bias=True,
):
super(Conv2dWithFastWeight, self).__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
# Lazy hack to add fast weight link
self.weight.fast = None
if self.bias is not None:
self.bias.fast = None
def forward(self, x):
if self.bias is None:
if self.weight.fast is not None:
out = F.conv2d(
x,
self.weight.fast,
None,
stride=self.stride,
padding=self.padding,
)
else:
out = super(Conv2dWithFastWeight, self).forward(x)
else:
if self.weight.fast is not None and self.bias.fast is not None:
out = F.conv2d(
x,
self.weight.fast,
self.bias.fast,
stride=self.stride,
padding=self.padding,
)
else:
out = super(Conv2dWithFastWeight, self).forward(x)
return out
class BatchNorm2dWithFastWeight(nn.BatchNorm2d):
def __init__(self, num_features):
super(BatchNorm2dWithFastWeight, self).__init__(num_features)
# Lazy hack to add fast weight link
self.weight.fast = None
self.bias.fast = None
def forward(self, x):
# batch_norm momentum hack: follow hack of Kate
# Rakelly in pytorch-maml/src/layers.py
running_mean = torch.zeros(x.data.size()[1]).cuda()
running_var = torch.ones(x.data.size()[1]).cuda()
if self.weight.fast is not None and self.bias.fast is not None:
out = F.batch_norm(
x,
running_mean,
running_var,
self.weight.fast,
self.bias.fast,
training=True,
momentum=1,
)
else:
out = F.batch_norm(
x,
running_mean,
running_var,
self.weight,
self.bias,
training=True,
momentum=1,
)
return out
def convert_maml_module(module):
"""Convert a normal model to MAML model.
Replace nn.Linear with LinearWithFastWeight, nn.Conv2d with
Conv2dWithFastWeight and BatchNorm2d with BatchNorm2dWithFastWeight.
Args:
module(nn.Module): The module to be converted.
Returns :
nn.Module: A MAML module.
"""
converted_module = module
if isinstance(module, torch.nn.modules.Linear):
converted_module = LinearWithFastWeight(
module.in_features,
module.out_features,
False if module.bias is None else True,
)
elif isinstance(module, torch.nn.modules.Conv2d):
converted_module = Conv2dWithFastWeight(
module.in_channels,
module.out_channels,
module.kernel_size,
module.stride,
module.padding,
False if module.bias is None else True,
)
elif isinstance(module, torch.nn.modules.batchnorm.BatchNorm2d):
converted_module = BatchNorm2dWithFastWeight(module.num_features)
for name, child in module.named_children():
converted_module.add_module(name, convert_maml_module(child))
del module
return converted_module
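
A short usage sketch for the module added above. convert_maml_module, its import path (shown in the maml.py and __init__ hunks) and the fast-weight behaviour come from this diff; the toy backbone is an assumption, and the sketch needs a CUDA device because BatchNorm2dWithFastWeight allocates its running statistics with .cuda().

import torch
import torch.nn as nn
from mmfewshot.classification.models.utils import convert_maml_module

# Toy backbone: Conv2d/BatchNorm2d/Linear children are swapped for their
# *WithFastWeight counterparts by convert_maml_module.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 5),
)
model = convert_maml_module(model).cuda()
x = torch.randn(2, 3, 32, 32).cuda()

out_slow = model(x)   # every .fast weight is None -> ordinary forward

# Setting .fast on a parameter makes the next forward use it instead of the
# slow weight, leaving the slow weight itself untouched.
linear = model[5]
linear.weight.fast = linear.weight - 0.1 * torch.ones_like(linear.weight)
linear.bias.fast = linear.bias.clone()
out_fast = model(x)   # the final Linear now runs with its fast weight

# Reset to the slow weights between tasks, as MAMLClassifier does.
for p in model.parameters():
    p.fast = None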


@@ -1,94 +0,0 @@
import torch
# Used in MAML
# modified from https://github.com/learnables/learn2learn
def clone_module(module, memo=None):
if memo is None:
memo = {}
if not isinstance(module, torch.nn.Module):
return module
clone = module.__new__(type(module))
clone.__dict__ = module.__dict__.copy()
clone._parameters = clone._parameters.copy()
clone._buffers = clone._buffers.copy()
clone._modules = clone._modules.copy()
# Second, re-write all parameters
if hasattr(clone, '_parameters'):
for param_key in module._parameters:
if module._parameters[param_key] is not None:
param = module._parameters[param_key]
param_ptr = param.data_ptr
if param_ptr in memo:
clone._parameters[param_key] = memo[param_ptr]
else:
cloned = param.clone()
clone._parameters[param_key] = cloned
memo[param_ptr] = cloned
# Third, handle the buffers if necessary
if hasattr(clone, '_buffers'):
for buffer_key in module._buffers:
if clone._buffers[buffer_key] is not None and \
clone._buffers[buffer_key].requires_grad:
buff = module._buffers[buffer_key]
buff_ptr = buff.data_ptr
if buff_ptr in memo:
clone._buffers[buffer_key] = memo[buff_ptr]
else:
cloned = buff.clone()
clone._buffers[buffer_key] = cloned
memo[param_ptr] = cloned
# Then, recurse for each submodule
if hasattr(clone, '_modules'):
for module_key in clone._modules:
clone._modules[module_key] = clone_module(
module._modules[module_key],
memo=memo,
)
if hasattr(clone, 'flatten_parameters'):
clone = clone._apply(lambda x: x)
return clone
def update_module(module, memo=None):
if memo is None:
memo = {}
# Update the params
for param_key in module._parameters:
p = module._parameters[param_key]
if p is not None and hasattr(p, 'update') and p.update is not None:
if p in memo:
module._parameters[param_key] = memo[p]
else:
updated = p + p.update
memo[p] = updated
module._parameters[param_key] = updated
# Second, handle the buffers if necessary
for buffer_key in module._buffers:
buff = module._buffers[buffer_key]
if buff is not None and hasattr(buff, 'update') \
and buff.update is not None:
if buff in memo:
module._buffers[buffer_key] = memo[buff]
else:
updated = buff + buff.update
memo[buff] = updated
module._buffers[buffer_key] = updated
# Then, recurse for each submodule
for module_key in module._modules:
module._modules[module_key] = update_module(
module._modules[module_key],
memo=memo,
)
if hasattr(module, 'flatten_parameters'):
module._apply(lambda x: x)
return module
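
For reference, a minimal sketch of how the two helpers removed in this hunk were driven by the old MAMLClassifier.forward_train shown earlier in the diff; the toy layer, loss and data are illustrative stand-ins.

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
x, y = torch.randn(6, 4), torch.randint(0, 2, (6,))
inner_lr, second_order = 0.01, True

clone = clone_module(net)                 # differentiable copy that keeps the graph
inner_loss = criterion(clone(x), y)
params = list(clone.parameters())
grads = torch.autograd.grad(
    inner_loss, params, retain_graph=second_order, create_graph=second_order)
for p, g in zip(params, grads):
    p.update = -inner_lr * g              # stash the step on each parameter
update_module(clone)                      # applies p + p.update in place
# The adapted clone was then run on the query images to produce the meta loss;
# the new fast-weight modules above replace this whole mechanism.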