Merge branch 'main' into weight-twined-by-M

Max authored on 2023-06-20 01:08:12 +08:00, committed by GitHub
commit 0ed02778c7
7 changed files with 175 additions and 174 deletions

View File

@ -1,3 +1,3 @@
sparsity:
mode: nxm
choices: [[[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]], [[1, 8], [1, 4], [2, 4], [4, 4]]]

View File

@ -1,3 +1,3 @@
sparsity:
mode: nxm
choices: [[[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]], [[2, 4], [4, 4]]]
choices: [[[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]], [[1, 4], [2, 4], [4, 4]]]
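The two N:M sparsity configs above define, for every sparse linear layer, a list of admissible [N, M] patterns, read here as N kept weights in every group of M, so [1, 8] is the sparsest option and [4, 4] is dense. A minimal sketch of how such a file can be loaded and a random per-layer pattern sampled for the supernet; `sample_layer_configs` is an illustrative helper (PyYAML assumed), while `set_sample_config` is the call that actually appears in the evaluate() and main.py hunks below:

import random
import yaml  # PyYAML assumed

def load_nas_config(path):
    with open(path) as f:
        return yaml.safe_load(f)

def sample_layer_configs(nas_config):
    # one [N, M] pattern per sparse linear layer (one list entry per layer in the YAML above)
    return [random.choice(layer_choices)
            for layer_choices in nas_config['sparsity']['choices']]

# nas_config = load_nas_config('configs/deit_small_nxm_nas_124.yaml')
# model.module.set_sample_config(sample_layer_configs(nas_config))  # as in the evaluate() hunk below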

View File

@ -40,7 +40,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
targets = targets.gt(0.0).type(targets.dtype)
with torch.cuda.amp.autocast():
outputs = model(samples)
outputs = model(samples, return_intermediate=(args.distillation_type == 'soft_fd'))
loss = criterion(samples, outputs, targets)
loss_value = loss.item()
@ -69,19 +69,19 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
@torch.no_grad()
def evaluate(nas_config, data_loader, model, device):
def evaluate(nas_config, data_loader, model, device, args = None):
criterion = torch.nn.CrossEntropyLoss()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
# Sample the smallest subnetwork to test accuracy
smallest_config = []
for ratios in nas_config['sparsity']['choices']:
smallest_config.append(ratios[0])
# smallest_config.append([1, 3])
model.module.set_sample_config(smallest_config)
print(f'Evaluate Config: {smallest_config[0]}')
if args.nas_mode:
# Sample the smallest subnetwork to test accuracy
smallest_config = []
for ratios in nas_config['sparsity']['choices']:
smallest_config.append(ratios[0])
# smallest_config.append([1, 3])
model.module.set_sample_config(smallest_config)
# switch to evaluation mode
model.eval()
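In short, the new args.nas_mode guard means the smallest sub-network is only fixed when NAS training is active; otherwise evaluate() scores the model as currently configured. A compact restatement of that selection step (the function name is illustrative only; the body mirrors the hunk above):

def select_smallest_subnet(model, nas_config):
    # take the first (sparsest) [N, M] pattern listed for every layer,
    # e.g. [1, 4] for every layer of the deit_small_nxm_nas_124 search space above
    smallest_config = [layer_choices[0] for layer_choices in nas_config['sparsity']['choices']]
    model.module.set_sample_config(smallest_config)  # model is the DDP-wrapped supernet
    return smallest_config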

View File

@ -13,14 +13,15 @@ class DistillationLoss(torch.nn.Module):
taking a teacher model prediction and using it as additional supervision.
"""
def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
distillation_type: str, alpha: float, tau: float):
distillation_type: str, alpha: float, tau: float, gamma: float):
super().__init__()
self.base_criterion = base_criterion
self.teacher_model = teacher_model
assert distillation_type in ['none', 'soft', 'hard']
assert distillation_type in ['none', 'soft', 'soft_fd']
self.distillation_type = distillation_type
self.alpha = alpha
self.tau = tau
self.gamma = gamma
def forward(self, inputs, outputs, labels):
"""
@ -31,40 +32,48 @@ class DistillationLoss(torch.nn.Module):
in the first position and the distillation predictions as the second output
labels: the labels for the base criterion
"""
outputs_kd = None
if not isinstance(outputs, torch.Tensor):
# assume that the model outputs a tuple of [outputs, outputs_kd]
outputs, outputs_kd = outputs
student_hidden_list = None
if self.distillation_type == 'soft_fd' and not isinstance(outputs, torch.Tensor):
outputs, student_hidden_list = outputs
base_loss = self.base_criterion(outputs, labels)
if self.distillation_type == 'none':
return base_loss
if outputs_kd is None:
if self.distillation_type == 'soft_fd' and student_hidden_list is None:
raise ValueError("When knowledge distillation is enabled, the model is "
"expected to return a Tuple[Tensor, Tensor] with the output of the "
"class_token and the dist_token")
"expected to return a Tuple[Tensor, [Tensor, Tensor....]] with the output of the "
"class_token and the list of intermediate outputs")
# don't backprop throught the teacher
with torch.no_grad():
teacher_outputs = self.teacher_model(inputs)
teacher_outputs = self.teacher_model(inputs, return_intermediate = (self.distillation_type == 'soft_fd'))
teacher_hidden_list = None
if self.distillation_type == 'soft_fd' and not isinstance(teacher_outputs , torch.Tensor):
teacher_outputs, teacher_hidden_list = teacher_outputs
T = self.tau
# taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
# with slight modifications
distillation_loss = F.kl_div(
F.log_softmax(outputs / T, dim=1),
#We provide the teacher's targets in log probability because we use log_target=True
#(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
#but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
F.log_softmax(teacher_outputs / T, dim=1),
reduction='sum',
log_target=True
) * (T * T) / outputs.numel()
if self.distillation_type == 'soft':
T = self.tau
# taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
# with slight modifications
distillation_loss = F.kl_div(
F.log_softmax(outputs_kd / T, dim=1),
#We provide the teacher's targets in log probability because we use log_target=True
#(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
#but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
F.log_softmax(teacher_outputs / T, dim=1),
reduction='sum',
log_target=True
) * (T * T) / outputs_kd.numel()
#We divide by outputs_kd.numel() to have the legacy PyTorch behavior.
#But we also experiments output_kd.size(0)
#see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details
elif self.distillation_type == 'hard':
distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
return base_loss * (1 - self.alpha) + distillation_loss * self.alpha
# calculate hidden loss
layer_num = len(student_hidden_list)
hidden_loss = 0.
for student_hidden, teacher_hidden in zip(student_hidden_list, teacher_hidden_list):
hidden_loss += torch.nn.MSELoss()(student_hidden, teacher_hidden)
hidden_loss /= layer_num
loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha + self.gamma * hidden_loss
return loss
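Condensed, the new 'soft_fd' objective combines three terms: the base classification loss, a temperature-scaled KL term against the teacher's logits, and a layer-averaged MSE between student and teacher hidden states weighted by the new gamma argument. A stand-alone sketch of that combination, with plain cross-entropy standing in for whatever base criterion main() actually builds, and ignoring the 'none'/'soft' branches kept in the real class:

import torch.nn.functional as F

def soft_fd_loss(student_logits, teacher_logits, labels,
                 student_hiddens, teacher_hiddens, alpha, tau, gamma):
    base = F.cross_entropy(student_logits, labels)
    # KL at temperature tau, matching the kl_div(..., log_target=True) call above
    kd = F.kl_div(F.log_softmax(student_logits / tau, dim=1),
                  F.log_softmax(teacher_logits / tau, dim=1),
                  reduction='sum', log_target=True) * (tau * tau) / student_logits.numel()
    # layer-averaged hidden-state MSE, the term controlled by gamma
    hidden = sum(F.mse_loss(s, t)
                 for s, t in zip(student_hiddens, teacher_hiddens)) / len(student_hiddens)
    return (1 - alpha) * base + alpha * kd + gamma * hidden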


main.py (91 changed lines)
View File

@ -142,15 +142,19 @@ def get_args_parser():
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# Distillation parameters
parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
help='Name of teacher model to train (default: "regnety_160"')
parser.add_argument('--teacher-path', type=str, default='')
parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
# parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
# help='Name of teacher model to train (default: "regnety_160"')
parser.add_argument('--teacher-model', default='deit_small_patch16_224', type=str, metavar='MODEL')
parser.add_argument('--teacher-path', type=str, default=None)
# parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'soft_fd'], type=str, help="")
# parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
parser.add_argument('--distillation-alpha', default=0.0, type=float, help="")
parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
parser.add_argument('--distillation-gamma', default=0.1, type=float,
help="coefficient for hidden distillation loss, we set it to be 0.1 by aligning MiniViT")
# * Finetuning params
# parser.add_argument('--finetune', default='weights/deit_small_patch16_224-cd65a155.pth', help='finetune from checkpoint')
parser.add_argument('--finetune', default=None, help='finetune from checkpoint')
parser.add_argument('--attn-only', action='store_true')
@ -186,17 +190,24 @@ def get_args_parser():
# Sparsity Training Related Flag
# timm == 0.4.12
# python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --output_dir result_nas_1:4_150epoch_repeat
# python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --nas-config configs/deit_small_nxm_nas_124+13.yaml --output_dir twined_nas_124+13_150epoch
# python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --nas-config configs/deit_small_nxm_nas_124+13.yaml --output_dir result_nas_124+13_150epoch
# python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29500 main.py --nas-config configs/deit_small_nxm_uniform14.yaml --epochs 50 --output_dir result_sub_1:4_50epoch
# python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_uniform24.yaml --epochs 50 --output_dir result_sub_2:4_50epoch
# python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_nas_124+13.yaml --eval
# python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_uniform14.yaml --eval
# python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --nas-config configs/deit_small_nxm_nas_124+13.yaml --output_dir twined_nas_124+13_150epoch
parser.add_argument('--model', default='Sparse_deit_small_patch16_224', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--nas-config', type=str, default=None, help='configuration for supernet training')
parser.add_argument('--nas-mode', action='store_true', default=True)
parser.add_argument('--nas-config', type=str, default='configs/deit_small_nxm_nas_124.yaml', help='configuration for supernet training')
parser.add_argument('--nas-mode', action='store_true', default=False)
# parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='twined_nas_124+13_150epoch/best_checkpoint.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_nas_1:4_150epoch/checkpoint.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_sub_1:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_sub_2:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_nas_124+13_150epoch/checkpoint.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_nas_124+13_150epoch/best_checkpoint.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_1:8_100epoch/best_checkpoint.pth', help='load pretrained supernet weight')
parser.add_argument('--nas-weights', default=None, help='load pretrained supernet weight')
parser.add_argument('--wandb', action='store_true', default=True)
parser.add_argument('--wandb', action='store_true')
parser.add_argument('--output_dir', default='result',
help='path where to save, empty for no saving')
return parser
@ -295,12 +306,12 @@ def main(args):
print(f"Creating model: {args.model}")
model = create_model(
args.model,
pretrained=True,
pretrained=False,
num_classes=args.nb_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
drop_block_rate=None,
img_size=args.input_size
img_size=args.input_size,
)
if args.finetune:
@ -422,27 +433,37 @@ def main(args):
teacher_model = None
if args.distillation_type != 'none':
assert args.teacher_path, 'need to specify teacher-path when using distillation'
# assert args.teacher_path, 'need to specify teacher-path when using distillation'
print(f"Creating teacher model: {args.teacher_model}")
teacher_model = create_model(
args.teacher_model,
pretrained=False,
num_classes=args.nb_classes,
global_pool='avg',
)
if args.teacher_path.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
# teacher_model = create_model( // regnety160
# args.teacher_model,
# pretrained=False,
# num_classes=args.nb_classes,
# global_pool='avg',
# )
teacher_model = create_model( # deit-small
args.teacher_model,
pretrained=True,
num_classes=args.nb_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
drop_block_rate=None,
img_size=args.input_size
)
if args.teacher_path:
if args.teacher_path.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.teacher_path, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.teacher_path, map_location='cpu')
teacher_model.load_state_dict(checkpoint['model'])
else:
checkpoint = torch.load(args.teacher_path, map_location='cpu')
teacher_model.load_state_dict(checkpoint['model'])
teacher_model.to(device)
teacher_model.eval()
# wrap the criterion in our custom DistillationLoss, which
# just dispatches to the original criterion if args.distillation_type is 'none'
criterion = DistillationLoss(
criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau, args.distillation_gamma
)
output_dir = Path(args.output_dir)
@ -463,7 +484,7 @@ def main(args):
loss_scaler.load_state_dict(checkpoint['scaler'])
lr_scheduler.step(args.start_epoch)
if args.eval:
test_stats = evaluate(nas_config, data_loader_val, model, device)
test_stats = evaluate(nas_config, data_loader_val, model, device, args)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
return
@ -496,7 +517,7 @@ def main(args):
}, checkpoint_path)
test_stats = evaluate(nas_config, data_loader_val, model, device)
test_stats = evaluate(nas_config, data_loader_val, model, device, args)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
@ -542,3 +563,15 @@ if __name__ == '__main__':
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
"""
python -m torch.distributed.launch \
--nproc_per_node=1 \
--use_env \
--master_port 29501 \
main.py \
--nas-config configs/deit_small_nxm_nas_124+13.yaml \
--data-path /work/shadowpa0327/imagenet \
--distillation-type soft_fd
"""

View File

@ -191,7 +191,7 @@ class SparseVisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
act_layer=None, weight_init='', ):
act_layer=None, weight_init=''):
"""
Args:
img_size (int, tuple): input image size
@ -204,7 +204,6 @@ class SparseVisionTransformer(nn.Module):
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
distilled (bool): model includes a distillation token and head as in DeiT models
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
@ -215,7 +214,7 @@ class SparseVisionTransformer(nn.Module):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 2 if distilled else 1
self.num_tokens = 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
@ -224,7 +223,6 @@ class SparseVisionTransformer(nn.Module):
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
@ -237,20 +235,10 @@ class SparseVisionTransformer(nn.Module):
self.norm = norm_layer(embed_dim)
# Representation layer
if representation_size and not distilled:
self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(embed_dim, representation_size)),
('act', nn.Tanh())
]))
else:
self.pre_logits = nn.Identity()
self.pre_logits = nn.Identity()
# Classifier head(s)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = None
if distilled:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
@ -269,47 +257,43 @@ class SparseVisionTransformer(nn.Module):
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token', 'dist_token'}
return {'pos_embed', 'cls_token'}
def get_classifier(self):
if self.dist_token is None:
return self.head
else:
return self.head, self.head_dist
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.num_tokens == 2:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
def forward_features(self, x, return_intermediate = False):
intermediate_outputs = []
x = self.patch_embed(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
if self.dist_token is None:
x = torch.cat((cls_token, x), dim=1)
else:
x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
x = self.blocks(x)
for b in self.blocks:
if return_intermediate:
intermediate_outputs.append(x)
x = b(x)
x = self.norm(x)
if self.dist_token is None:
if return_intermediate:
return self.pre_logits(x[:, 0]), intermediate_outputs
else:
return self.pre_logits(x[:, 0])
else:
return x[:, 0], x[:, 1]
def forward(self, x):
x = self.forward_features(x)
if self.head_dist is not None:
x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
if self.training and not torch.jit.is_scripting():
# during inference, return the average of both classifier predictions
return x, x_dist
else:
return (x + x_dist) / 2
else:
def forward(self, x, return_intermediate = False):
if return_intermediate:
x, intermedite_outputs = self.forward_features(x, return_intermediate)
x = self.head(x)
return x
return x, intermedite_outputs
else:
x = self.forward_features(x)
x = self.head(x)
return x
def set_seperate_config(self, seperate_configs):
# using after loading pre-trained weights / before evaluating trained supernet for sparsity
@ -336,7 +320,6 @@ class SparseVisionTransformer(nn.Module):
@register_model
def Sparse_deit_base_patch16_224(pretrained=False, **kwargs):
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).

View File

@ -87,15 +87,15 @@ default_cfgs = {
'vit_base_resnet50d_224': _cfg(),
}
class LRMlpSuper(nn.Module):
class MlpSuper(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.act = act_layer()
self.drop = nn.Dropout(drop)
self.fc1 = SparseLinearSuper(in_features, hidden_features)
self.fc2 = SparseLinearSuper(hidden_features, out_features)
self.fc1 = nn.Linear(in_features, hidden_features)
self.fc2 = nn.Linear(hidden_features, out_features)
def forward(self, x):
x = self.fc1(x)
@ -105,15 +105,15 @@ class LRMlpSuper(nn.Module):
x = self.drop(x)
return x
class LRAttentionSuper(nn.Module):
class AttentionSuper(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.proj = SparseLinearSuper(dim, dim)
self.qkv = SparseLinearSuper(dim, dim * 3, bias = qkv_bias)
self.proj = nn.Linear(dim, dim)
self.qkv = nn.Linear(dim, dim * 3, bias = qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj_drop = nn.Dropout(proj_drop)
@ -137,12 +137,12 @@ class Block(nn.Module):
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = LRAttentionSuper(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, )
self.attn = AttentionSuper(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, )
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = LRMlpSuper(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.mlp = MlpSuper(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
@ -178,7 +178,7 @@ class PatchEmbed(nn.Module):
class SparseVisionTransformer(nn.Module):
class VisionTransformer(nn.Module):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
@ -189,9 +189,9 @@ class SparseVisionTransformer(nn.Module):
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
act_layer=None, weight_init='', ):
act_layer=None, weight_init=''):
"""
Args:
img_size (int, tuple): input image size
@ -204,7 +204,6 @@ class SparseVisionTransformer(nn.Module):
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
distilled (bool): model includes a distillation token and head as in DeiT models
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
@ -215,7 +214,7 @@ class SparseVisionTransformer(nn.Module):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 2 if distilled else 1
self.num_tokens = 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
@ -224,7 +223,6 @@ class SparseVisionTransformer(nn.Module):
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
@ -237,21 +235,10 @@ class SparseVisionTransformer(nn.Module):
self.norm = norm_layer(embed_dim)
# Representation layer
if representation_size and not distilled:
self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(embed_dim, representation_size)),
('act', nn.Tanh())
]))
else:
self.pre_logits = nn.Identity()
self.pre_logits = nn.Identity()
# Classifier head(s)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = None
if distilled:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
@ -269,52 +256,41 @@ class SparseVisionTransformer(nn.Module):
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token', 'dist_token'}
return {'pos_embed', 'cls_token'}
def get_classifier(self):
if self.dist_token is None:
return self.head
else:
return self.head, self.head_dist
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.num_tokens == 2:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
def forward_features(self, x, return_intermediate = False):
intermediate_outputs = []
x = self.patch_embed(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
if self.dist_token is None:
x = torch.cat((cls_token, x), dim=1)
else:
x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
x = self.blocks(x)
for b in self.blocks:
x = b(x)
if return_intermediate:
intermediate_outputs.append(x)
x = self.norm(x)
if self.dist_token is None:
if return_intermediate:
return self.pre_logits(x[:, 0]), intermediate_outputs
else:
return self.pre_logits(x[:, 0])
else:
return x[:, 0], x[:, 1]
def forward(self, x):
x = self.forward_features(x)
if self.head_dist is not None:
x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
if self.training and not torch.jit.is_scripting():
# during inference, return the average of both classifier predictions
return x, x_dist
else:
return (x + x_dist) / 2
else:
def forward(self, x, return_intermediate = False):
if return_intermediate:
x, intermedite_outputs = self.forward_features(x, return_intermediate)
x = self.head(x)
return x
return x, intermedite_outputs
else:
x = self.forward_features(x)
x = self.head(x)
return x
@ -323,7 +299,7 @@ def deit_base_patch16_224(pretrained=False, **kwargs):
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model = SparseVisionTransformer(
model = VisionTransformer(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
@ -342,7 +318,7 @@ def deit_small_patch16_224(pretrained=False, **kwargs):
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model = SparseVisionTransformer(
model = VisionTransformer(
patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
@ -360,7 +336,7 @@ def deit_tiny_patch16_224(pretrained=False, **kwargs):
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model = SparseVisionTransformer(
model = VisionTransformer(
patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
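
These deit_* entry points are what the main.py hunk instantiates as the teacher through timm's registry, with pretrained=True expected to pull the public DeiT ImageNet weights. One detail worth double-checking: this teacher variant appends each block's output to the intermediate list, whereas the sparse student file above appends each block's input, so the two hidden-state lists are offset by one block unless that is intentional. A hedged sketch of the teacher side, mirroring the main.py construction:

import torch
from timm.models import create_model

teacher = create_model('deit_small_patch16_224', pretrained=True,  # mirrors the teacher setup in main.py
                       num_classes=1000, img_size=224)
teacher.eval()
with torch.no_grad():
    t_logits, t_hiddens = teacher(torch.randn(2, 3, 224, 224), return_intermediate=True)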