Update MAML (#36)
* fix init
* fix test api
* fix test api bug
* add metarcnn fsdetview config
* update maml
* update config
* fix comments
parent 6c2144ec83
commit 7451fb7425
@@ -263,12 +263,9 @@ def test_single_task(model, support_dataloader, query_dataloader,
             outputs['loss'].backward()
             optimizer.step()
     else:  # methods without fine-tune stage
-        model.eval()
-        with torch.no_grad():
-            for i, data in enumerate(support_dataloader):
-                data['gt_label'] = label_wrapper(data['gt_label'],
-                                                 task_class_ids)
-                model.forward(**data, mode='support')
+        for i, data in enumerate(support_dataloader):
+            data['gt_label'] = label_wrapper(data['gt_label'], task_class_ids)
+            model.forward(**data, mode='support')
 
     model.before_forward_query()
     results_list, gt_label_list = [], []
@@ -59,7 +59,7 @@ def train_model(model,
             round_up=True,
             seed=cfg.seed,
             pin_memory=cfg.get('pin_memory', False),
-            infinite_sampler=cfg.infinite_sampler) for ds in dataset
+            use_infinite_sampler=cfg.use_infinite_sampler) for ds in dataset
     ]
 
     # put model on gpus
@@ -95,7 +95,7 @@ def train_model(model,
     else:
         if 'total_epochs' in cfg:
             assert cfg.total_epochs == cfg.runner.max_epochs
-    if cfg.infinite_sampler and cfg.runner['type'] == 'EpochBasedRunner':
+    if cfg.use_infinite_sampler and cfg.runner['type'] == 'EpochBasedRunner':
         cfg.runner['type'] = 'InfiniteEpochBasedRunner'
     runner = build_runner(
         cfg.runner,
@@ -51,7 +51,7 @@ def build_dataloader(dataset,
                      round_up=True,
                      seed=None,
                      pin_memory=False,
-                     infinite_sampler=False,
+                     use_infinite_sampler=False,
                      **kwargs):
     """Build PyTorch DataLoader.
 
@@ -73,10 +73,10 @@ def build_dataloader(dataset,
         seed (int | None): Random seed. Default: None.
         pin_memory (bool): Whether to use pin_memory for dataloader.
             Default: False.
-        infinite_sampler (bool): Whether to use infinite sampler. Noted that
-            infinite sampler will keep iterator of dataloader running
-            forever, which can avoid the overhead of worker initialization
-            between epochs. Default: False.
+        use_infinite_sampler (bool): Whether to use infinite sampler.
+            Noted that infinite sampler will keep iterator of dataloader
+            running forever, which can avoid the overhead of worker
+            initialization between epochs. Default: False.
         kwargs: any keyword argument to be used to initialize DataLoader
 
     Returns:
@@ -84,7 +84,7 @@ def build_dataloader(dataset,
     """
     rank, world_size = get_dist_info()
     if dist:
-        if infinite_sampler:
+        if use_infinite_sampler:
             sampler = DistributedInfiniteSampler(
                 dataset, world_size, rank, shuffle=shuffle)
         else:
@@ -95,7 +95,7 @@ def build_dataloader(dataset,
         num_workers = workers_per_gpu
     else:
         sampler = InfiniteSampler(dataset, seed=seed, shuffle=shuffle) \
-            if infinite_sampler else None
+            if use_infinite_sampler else None
         batch_size = num_gpus * samples_per_gpu
         num_workers = num_gpus * workers_per_gpu
 
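Note: a minimal sketch of how the renamed flag is meant to be used at a call site. Only `use_infinite_sampler` and its meaning come from the diff above; the import path, the dataset variable and the other argument values are illustrative assumptions.

    # Hypothetical call site mirroring the build_dataloader call in train_model().
    from mmfewshot.classification.datasets import build_dataloader  # import path assumed

    loader = build_dataloader(
        train_dataset,          # assumed to be an already-built dataset
        samples_per_gpu=4,
        workers_per_gpu=2,
        dist=False,
        seed=0,
        pin_memory=False,
        # Renamed from `infinite_sampler` in this PR: keeps the dataloader
        # iterator alive across epochs to avoid re-initializing workers.
        use_infinite_sampler=True)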
@@ -8,11 +8,11 @@ class ConvBlock(nn.Module):
         super(ConvBlock, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=padding)
-        self.bn = nn.BatchNorm2d(out_channels)
-        self.relu = nn.ReLU(inplace=True)
-
-        layers = [self.conv, self.bn, self.relu]
+        layers = [
+            nn.Conv2d(in_channels, out_channels, 3, padding=padding),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        ]
         if is_pooling:
             layers.append(nn.MaxPool2d(2))
         self.layers = nn.Sequential(*layers)
@@ -1,8 +1,9 @@
+import numpy as np
 import torch
 from mmcls.models.builder import CLASSIFIERS
 
 from mmfewshot.classification.datasets import label_wrapper
-from mmfewshot.classification.models.utils import clone_module, update_module
+from mmfewshot.classification.models.utils import convert_maml_module
 from .base import FewShotBaseClassifier
 
 
@@ -26,6 +27,7 @@ class MAMLClassifier(FewShotBaseClassifier):
         self.num_inner_steps = num_inner_steps
         self.inner_lr = inner_lr
         self.first_order = first_order
+        convert_maml_module(self)
 
     def forward(self,
                 img=None,
@@ -77,8 +79,8 @@ class MAMLClassifier(FewShotBaseClassifier):
         This method defines an iteration step during training, except for the
         back propagation and optimizer updating, which are done in an optimizer
         hook. Note that in some complicated cases or models, the whole process
-        including back propagation and optimizer updating are also defined in
-        this method, such as GAN.
+        including back propagation and optimizer updating are also defined
+        in this method, such as GAN.
 
         Args:
             data (dict): The output of dataloader.
@@ -138,33 +140,17 @@ class MAMLClassifier(FewShotBaseClassifier):
         Returns:
             dict[str, Tensor]: a dictionary of loss components
         """
-        support_img = support_data['img']
+        support_img, query_img = support_data['img'], query_data['img']
         class_ids = torch.unique(support_data['gt_label']).cpu().tolist()
+        np.random.shuffle(class_ids)
         support_label = label_wrapper(support_data['gt_label'], class_ids)
-        clone_backbone = clone_module(self.backbone)
-        clone_head = clone_module(self.head)
-        second_order = not self.first_order
-        for step in range(self.num_inner_steps):
-            feats = clone_backbone(support_img)
-            inner_loss = clone_head.forward_train(feats, support_label)['loss']
-            parameters = list(clone_backbone.parameters()) + list(
-                clone_head.parameters())
-            grads = torch.autograd.grad(
-                inner_loss,
-                parameters,
-                retain_graph=second_order,
-                create_graph=second_order)
-            for parameter, grad in zip(parameters, grads):
-                if grad is not None:
-                    parameter.update = -self.inner_lr * grad
-            update_module(clone_backbone)
-            update_module(clone_head)
-
-        query_img = query_data['img']
         query_label = label_wrapper(query_data['gt_label'], class_ids)
-        feats = clone_backbone(query_img)
-        loss = clone_head.forward_train(feats, query_label)
+        self.fast_adapt(self.num_inner_steps, support_img, support_label)
+        query_feats = self.extract_feat(query_img)
+        loss = self.head.forward_train(query_feats, query_label)
+        for weight in self.parameters():
+            weight.fast = None
         return loss
 
     def forward_support(self, img, gt_label, **kwargs):
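Note: with this change the training step adapts fast weights on the support set, computes the outer loss on the query set, and then clears the fast weights. A rough sketch of the surrounding outer-loop update (illustrative only; per the docstring above, back propagation and the optimizer step are actually driven by the optimizer hook, and the variable names here are assumptions):

    # Illustrative outer loop, not part of the diff.
    outputs = model.forward_train(support_data, query_data)  # dict of loss components
    optimizer.zero_grad()
    outputs['loss'].backward()   # meta-gradient flows through the fast weights
    optimizer.step()             # second-order unless first_order=True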
@@ -178,8 +164,8 @@ class MAMLClassifier(FewShotBaseClassifier):
         Returns:
             dict[str, Tensor]: A dictionary of loss components
         """
-        x = self.extract_feat(img)
-        return self.head.forward_support(x, gt_label)
+        self.fast_adapt(self.meta_test_cfg['support']['num_inner_steps'], img,
+                        gt_label)
 
     def forward_query(self, img, **kwargs):
         """Forward query data in meta testing.
@@ -194,12 +180,39 @@ class MAMLClassifier(FewShotBaseClassifier):
         x = self.extract_feat(img)
         return self.head.forward_query(x)
 
+    def fast_adapt(self, num_steps, img, labels):
+        """Forward and update fast weight with input images and labels.
+
+        Args:
+            num_steps (int): The number of fast forward and update steps.
+            img (Tensor): With shape (N, C, H, W).
+            labels (Tensor): With shape (N).
+        """
+        fast_parameters = list(self.parameters())
+        for weight in self.parameters():
+            weight.fast = None
+        for step in range(num_steps):
+            feats = self.extract_feat(img)
+            inner_loss = self.head.forward_train(feats, labels)['loss']
+            grads = torch.autograd.grad(
+                inner_loss, fast_parameters, create_graph=True)
+            fast_parameters = []
+            if self.first_order:
+                grads = [g.detach() for g in grads]
+            for k, weight in enumerate(list(self.parameters())):
+                if weight.fast is None:
+                    weight.fast = weight - self.inner_lr * grads[k]
+                else:
+                    weight.fast = weight.fast - self.inner_lr * grads[k]
+                fast_parameters.append(weight.fast)
+
     def before_meta_test(self, meta_test_cfg, **kwargs):
         """Used in meta testing.
 
         This function will be called before the meta testing.
         """
         self.meta_test_cfg = meta_test_cfg
+        self.zero_grad()
 
     def before_forward_support(self, **kwargs):
         """Used in meta testing.
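Note: the same fast-weight mechanism is reused at meta-test time: fast weights are cleared before the support stage, fast_adapt() rebuilds them from the support set, and the query forward then runs with the adapted weights. A minimal sketch of one meta-test episode, following the calling convention in test_single_task() above (the batch variables are assumed to exist and to have their labels already wrapped; mode='query' is assumed by symmetry with mode='support'):

    # Hypothetical meta-test episode.
    model.before_forward_support()                   # resets every weight.fast to None
    model.forward(**support_batch, mode='support')   # calls fast_adapt() internally
    model.before_forward_query()
    result = model.forward(**query_batch, mode='query')  # uses the fast weights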
@@ -207,6 +220,8 @@ class MAMLClassifier(FewShotBaseClassifier):
         This function will be called before model forward support data during
         meta testing.
         """
+        for weight in self.parameters():
+            weight.fast = None
         self.backbone.train()
         self.head.train()
 
@@ -1,3 +1,3 @@
-from .module_utils import clone_module, update_module
+from .maml_module import convert_maml_module
 
-__all__ = ['clone_module', 'update_module']
+__all__ = ['convert_maml_module']
@@ -0,0 +1,148 @@
+"""Modified from https://github.com/wyharveychen/CloserLookFewShot and
+https://github.com/RL-VIG/LibFewShot.
+
+This file is only used in method maml for fast adaptation.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LinearWithFastWeight(nn.Linear):
+
+    def __init__(self, in_features, out_features, bias=True):
+        super(LinearWithFastWeight, self).__init__(in_features, out_features)
+        # Lazy hack to add fast weight link
+        self.weight.fast = None
+        self.bias.fast = None
+
+    def forward(self, x):
+        if self.weight.fast is not None and self.bias.fast is not None:
+            out = F.linear(x, self.weight.fast, self.bias.fast)
+        else:
+            out = super(LinearWithFastWeight, self).forward(x)
+        return out
+
+
+class Conv2dWithFastWeight(nn.Conv2d):
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        bias=True,
+    ):
+        super(Conv2dWithFastWeight, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            bias=bias,
+        )
+        # Lazy hack to add fast weight link
+        self.weight.fast = None
+        if self.bias is not None:
+            self.bias.fast = None
+
+    def forward(self, x):
+        if self.bias is None:
+            if self.weight.fast is not None:
+                out = F.conv2d(
+                    x,
+                    self.weight.fast,
+                    None,
+                    stride=self.stride,
+                    padding=self.padding,
+                )
+            else:
+                out = super(Conv2dWithFastWeight, self).forward(x)
+        else:
+            if self.weight.fast is not None and self.bias.fast is not None:
+                out = F.conv2d(
+                    x,
+                    self.weight.fast,
+                    self.bias.fast,
+                    stride=self.stride,
+                    padding=self.padding,
+                )
+            else:
+                out = super(Conv2dWithFastWeight, self).forward(x)
+        return out
+
+
+class BatchNorm2dWithFastWeight(nn.BatchNorm2d):
+
+    def __init__(self, num_features):
+        super(BatchNorm2dWithFastWeight, self).__init__(num_features)
+        # Lazy hack to add fast weight link
+        self.weight.fast = None
+        self.bias.fast = None
+
+    def forward(self, x):
+        # batch_norm momentum hack: follow hack of Kate
+        # Rakelly in pytorch-maml/src/layers.py
+        running_mean = torch.zeros(x.data.size()[1]).cuda()
+        running_var = torch.ones(x.data.size()[1]).cuda()
+        if self.weight.fast is not None and self.bias.fast is not None:
+            out = F.batch_norm(
+                x,
+                running_mean,
+                running_var,
+                self.weight.fast,
+                self.bias.fast,
+                training=True,
+                momentum=1,
+            )
+        else:
+            out = F.batch_norm(
+                x,
+                running_mean,
+                running_var,
+                self.weight,
+                self.bias,
+                training=True,
+                momentum=1,
+            )
+        return out
+
+
+def convert_maml_module(module):
+    """Convert a normal model to MAML model.
+
+    Replace nn.Linear with LinearWithFastWeight, nn.Conv2d with
+    Conv2dWithFastWeight and BatchNorm2d with BatchNorm2dWithFastWeight.
+
+    Args:
+        module(nn.Module): The module to be converted.
+
+    Returns:
+        nn.Module: A MAML module.
+    """
+    converted_module = module
+    if isinstance(module, torch.nn.modules.Linear):
+        converted_module = LinearWithFastWeight(
+            module.in_features,
+            module.out_features,
+            False if module.bias is None else True,
+        )
+    elif isinstance(module, torch.nn.modules.Conv2d):
+        converted_module = Conv2dWithFastWeight(
+            module.in_channels,
+            module.out_channels,
+            module.kernel_size,
+            module.stride,
+            module.padding,
+            False if module.bias is None else True,
+        )
+    elif isinstance(module, torch.nn.modules.batchnorm.BatchNorm2d):
+        converted_module = BatchNorm2dWithFastWeight(module.num_features)
+    for name, child in module.named_children():
+        converted_module.add_module(name, convert_maml_module(child))
+    del module
+    return converted_module
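Note: a small self-contained sketch of what convert_maml_module() is expected to do to an ordinary model. The toy model below is hypothetical; the import follows the one added to maml.py above.

    import torch.nn as nn
    from mmfewshot.classification.models.utils import convert_maml_module

    toy = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(inplace=True),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 5),
    )
    toy = convert_maml_module(toy)
    # Conv2d/BatchNorm2d/Linear are replaced by their *WithFastWeight versions
    # (freshly constructed from the originals' hyper-parameters), so every
    # parameter now carries a `fast` attribute that defaults to None.
    assert all(p.fast is None for p in toy.parameters())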
@@ -1,94 +0,0 @@
-import torch
-
-# Used in MAML
-# modified from https://github.com/learnables/learn2learn
-
-
-def clone_module(module, memo=None):
-    if memo is None:
-        memo = {}
-    if not isinstance(module, torch.nn.Module):
-        return module
-    clone = module.__new__(type(module))
-    clone.__dict__ = module.__dict__.copy()
-    clone._parameters = clone._parameters.copy()
-    clone._buffers = clone._buffers.copy()
-    clone._modules = clone._modules.copy()
-
-    # Second, re-write all parameters
-    if hasattr(clone, '_parameters'):
-        for param_key in module._parameters:
-            if module._parameters[param_key] is not None:
-                param = module._parameters[param_key]
-                param_ptr = param.data_ptr
-                if param_ptr in memo:
-                    clone._parameters[param_key] = memo[param_ptr]
-                else:
-                    cloned = param.clone()
-                    clone._parameters[param_key] = cloned
-                    memo[param_ptr] = cloned
-
-    # Third, handle the buffers if necessary
-    if hasattr(clone, '_buffers'):
-        for buffer_key in module._buffers:
-            if clone._buffers[buffer_key] is not None and \
-                    clone._buffers[buffer_key].requires_grad:
-                buff = module._buffers[buffer_key]
-                buff_ptr = buff.data_ptr
-                if buff_ptr in memo:
-                    clone._buffers[buffer_key] = memo[buff_ptr]
-                else:
-                    cloned = buff.clone()
-                    clone._buffers[buffer_key] = cloned
-                    memo[param_ptr] = cloned
-
-    # Then, recurse for each submodule
-    if hasattr(clone, '_modules'):
-        for module_key in clone._modules:
-            clone._modules[module_key] = clone_module(
-                module._modules[module_key],
-                memo=memo,
-            )
-
-    if hasattr(clone, 'flatten_parameters'):
-        clone = clone._apply(lambda x: x)
-    return clone
-
-
-def update_module(module, memo=None):
-    if memo is None:
-        memo = {}
-
-    # Update the params
-    for param_key in module._parameters:
-        p = module._parameters[param_key]
-        if p is not None and hasattr(p, 'update') and p.update is not None:
-            if p in memo:
-                module._parameters[param_key] = memo[p]
-            else:
-                updated = p + p.update
-                memo[p] = updated
-                module._parameters[param_key] = updated
-
-    # Second, handle the buffers if necessary
-    for buffer_key in module._buffers:
-        buff = module._buffers[buffer_key]
-        if buff is not None and hasattr(buff, 'update') \
-                and buff.update is not None:
-            if buff in memo:
-                module._buffers[buffer_key] = memo[buff]
-            else:
-                updated = buff + buff.update
-                memo[buff] = updated
-                module._buffers[buffer_key] = updated
-
-    # Then, recurse for each submodule
-    for module_key in module._modules:
-        module._modules[module_key] = update_module(
-            module._modules[module_key],
-            memo=memo,
-        )
-
-    if hasattr(module, 'flatten_parameters'):
-        module._apply(lambda x: x)
-    return module