From 6972dbafd0a113df2feb4a01391d4b27dbd4477b Mon Sep 17 00:00:00 2001 From: KaiyangZhou Date: Thu, 10 May 2018 11:46:59 +0100 Subject: [PATCH] add new model --- models/InceptionResNetV2.py | 62 +- models/ResNeXt.py | 1489 +++++++++++++++++++++++++++++++++++ models/ResNet.py | 30 +- models/SEResNet.py | 90 ++- models/__init__.py | 14 +- 5 files changed, 1647 insertions(+), 38 deletions(-) create mode 100644 models/ResNeXt.py diff --git a/models/InceptionResNetV2.py b/models/InceptionResNetV2.py index ba867ec..9045227 100644 --- a/models/InceptionResNetV2.py +++ b/models/InceptionResNetV2.py @@ -232,9 +232,38 @@ class Block8(nn.Module): out = self.relu(out) return out +def inceptionresnetv2(num_classes=1000, pretrained='imagenet'): + r"""InceptionResNetV2 model architecture from the + `"InceptionV4, Inception-ResNet..." `_ paper. + """ + if pretrained: + settings = pretrained_settings['inceptionresnetv2'][pretrained] + assert num_classes == settings['num_classes'], \ + "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) + + # both 'imagenet'&'imagenet+background' are loaded from same parameters + model = InceptionResNetV2(num_classes=1001) + model.load_state_dict(model_zoo.load_url(settings['url'])) + + if pretrained == 'imagenet': + new_last_linear = nn.Linear(1536, 1000) + new_last_linear.weight.data = model.last_linear.weight.data[1:] + new_last_linear.bias.data = model.last_linear.bias.data[1:] + model.last_linear = new_last_linear + + model.input_space = settings['input_space'] + model.input_size = settings['input_size'] + model.input_range = settings['input_range'] + + model.mean = settings['mean'] + model.std = settings['std'] + else: + model = InceptionResNetV2(num_classes=num_classes) + return model + +##################### Model Definition ######################### class InceptionResNetV2(nn.Module): - def __init__(self, num_classes, loss={'xent'}, **kwargs): super(InceptionResNetV2, self).__init__() self.loss = loss @@ 
-347,33 +376,4 @@ class InceptionResNetV2(nn.Module): elif self.loss == {'ring'}: return y, x else: - raise KeyError("Unsupported loss: {}".format(self.loss)) - -def inceptionresnetv2(num_classes=1000, pretrained='imagenet'): - r"""InceptionResNetV2 model architecture from the - `"InceptionV4, Inception-ResNet..." `_ paper. - """ - if pretrained: - settings = pretrained_settings['inceptionresnetv2'][pretrained] - assert num_classes == settings['num_classes'], \ - "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) - - # both 'imagenet'&'imagenet+background' are loaded from same parameters - model = InceptionResNetV2(num_classes=1001) - model.load_state_dict(model_zoo.load_url(settings['url'])) - - if pretrained == 'imagenet': - new_last_linear = nn.Linear(1536, 1000) - new_last_linear.weight.data = model.last_linear.weight.data[1:] - new_last_linear.bias.data = model.last_linear.bias.data[1:] - model.last_linear = new_last_linear - - model.input_space = settings['input_space'] - model.input_size = settings['input_size'] - model.input_range = settings['input_range'] - - model.mean = settings['mean'] - model.std = settings['std'] - else: - model = InceptionResNetV2(num_classes=num_classes) - return model \ No newline at end of file + raise KeyError("Unsupported loss: {}".format(self.loss)) \ No newline at end of file diff --git a/models/ResNeXt.py b/models/ResNeXt.py new file mode 100644 index 0000000..9706ad5 --- /dev/null +++ b/models/ResNeXt.py @@ -0,0 +1,1489 @@ +from __future__ import absolute_import + +import torch +import torch.nn as nn +from torch.nn import functional as F +from functools import reduce +import torch.utils.model_zoo as model_zoo + +import os + +__all__ = ['ResNeXt101_32x4d'] + +""" +Code imported from https://github.com/Cadene/pretrained-models.pytorch +""" + +class LambdaBase(nn.Sequential): + def __init__(self, fn, *args): + super(LambdaBase, self).__init__(*args) + self.lambda_func = fn + + def 
forward_prepare(self, input): + output = [] + for module in self._modules.values(): + output.append(module(input)) + return output if output else input + +class Lambda(LambdaBase): + def forward(self, input): + return self.lambda_func(self.forward_prepare(input)) + +class LambdaMap(LambdaBase): + def forward(self, input): + return list(map(self.lambda_func,self.forward_prepare(input))) + +class LambdaReduce(LambdaBase): + def forward(self, input): + return reduce(self.lambda_func,self.forward_prepare(input)) + +# JUMP TO THE END ######################################################################### +resnext101_32x4d_features = nn.Sequential( # Sequential, + nn.Conv2d(3,64,(7, 7),(2, 2),(3, 3),1,1,bias=False), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.MaxPool2d((3, 3),(2, 2),(1, 1)), + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(64,128,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(128), + nn.ReLU(), + nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(128), + nn.ReLU(), + ), + nn.Conv2d(128,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + ), + nn.Sequential( # Sequential, + nn.Conv2d(64,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + ), + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(256,128,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(128), + nn.ReLU(), + nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(128), + nn.ReLU(), + ), + nn.Conv2d(128,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # 
ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(256,128,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(128), + nn.ReLU(), + nn.Conv2d(128,128,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(128), + nn.ReLU(), + ), + nn.Conv2d(128,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + ), + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(256,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256,256,(3, 3),(2, 2),(1, 1),1,32,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + ), + nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + ), + nn.Sequential( # Sequential, + nn.Conv2d(256,512,(1, 1),(2, 2),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + ), + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + ), + nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + ), + nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + 
nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(512,256,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256,256,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + ), + nn.Conv2d(256,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + ), + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(512,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(2, 2),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + nn.Sequential( # Sequential, + nn.Conv2d(512,1024,(1, 1),(2, 2),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + 
nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + 
nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # 
Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + 
LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 
0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), 
+ nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,512,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512,512,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + ), + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(1024,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024,1024,(3, 3),(2, 2),(1, 1),1,32,bias=False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(2048), + ), + nn.Sequential( # Sequential, + nn.Conv2d(1024,2048,(1, 1),(2, 2),(0, 0),1,1,bias=False), + nn.BatchNorm2d(2048), + ), + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(2048,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(2048), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + 
nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + nn.Conv2d(2048,1024,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024,1024,(3, 3),(1, 1),(1, 1),1,32,bias=False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024,2048,(1, 1),(1, 1),(0, 0),1,1,bias=False), + nn.BatchNorm2d(2048), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x,y: x+y), # CAddTable, + nn.ReLU(), + ), + ) +) + +################################################################################# +resnext101_64x4d_features = nn.Sequential(#Sequential, + nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias = False), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.MaxPool2d((3, 3), (2, 2), (1, 1)), + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(256), + nn.ReLU(), + ), + nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(256), + ), + nn.Sequential(#Sequential, + nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(256), + ), + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(256), + nn.ReLU(), + ), + nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(256), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, 
#ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(256), + nn.ReLU(), + ), + nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(256), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + ), + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + ), + nn.Sequential(#Sequential, + nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + ), + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + 
nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(512), + nn.ReLU(), + ), + nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + ), + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + nn.Sequential(#Sequential, + nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + 
nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + 
nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), 
(1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), 
#CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + 
nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + 
nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(1024), + nn.ReLU(), + ), + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + ), + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(2048), + nn.ReLU(), + nn.Conv2d(2048, 2048, (3, 3), (2, 2), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(2048), + nn.ReLU(), + ), + nn.Conv2d(2048, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(2048), + ), + 
nn.Sequential(#Sequential, + nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(2048), + ), + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(2048, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(2048), + nn.ReLU(), + nn.Conv2d(2048, 2048, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(2048), + nn.ReLU(), + ), + nn.Conv2d(2048, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(2048), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + nn.Sequential(#Sequential, + LambdaMap(lambda x: x, #ConcatTable, + nn.Sequential(#Sequential, + nn.Sequential(#Sequential, + nn.Conv2d(2048, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(2048), + nn.ReLU(), + nn.Conv2d(2048, 2048, (3, 3), (1, 1), (1, 1), 1, 64, bias = False), + nn.BatchNorm2d(2048), + nn.ReLU(), + ), + nn.Conv2d(2048, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias = False), + nn.BatchNorm2d(2048), + ), + Lambda(lambda x: x), #Identity, + ), + LambdaReduce(lambda x, y: x + y), #CAddTable, + nn.ReLU(), + ), + ) +) + +################################################################################# + +pretrained_settings = { + 'resnext101_32x4d': { + 'imagenet': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/resnext101_32x4d-29e315fa.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + }, + 'resnext101_64x4d': { + 'imagenet': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/resnext101_64x4d-e77a0586.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 
1000 + } + } +} + +def resnext101_32x4d(num_classes=1000, pretrained='imagenet'): + """Deprecated""" + model = ResNeXt101_32x4d(num_classes=num_classes) + if pretrained is not None: + settings = pretrained_settings['resnext101_32x4d'][pretrained] + assert num_classes == settings['num_classes'], \ + "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) + model.load_state_dict(model_zoo.load_url(settings['url'])) + model.input_space = settings['input_space'] + model.input_size = settings['input_size'] + model.input_range = settings['input_range'] + model.mean = settings['mean'] + model.std = settings['std'] + return model + +def resnext101_64x4d(num_classes=1000, pretrained='imagenet'): + """Deprecated""" + model = ResNeXt101_64x4d(num_classes=num_classes) + if pretrained is not None: + settings = pretrained_settings['resnext101_64x4d'][pretrained] + assert num_classes == settings['num_classes'], \ + "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) + model.load_state_dict(model_zoo.load_url(settings['url'])) + model.input_space = settings['input_space'] + model.input_size = settings['input_size'] + model.input_range = settings['input_range'] + model.mean = settings['mean'] + model.std = settings['std'] + return model + +##################### Model Definition ######################### + +class ResNeXt101_32x4d(nn.Module): + def __init__(self, num_classes, loss={'xent'}, **kwargs): + super(ResNeXt101_32x4d, self).__init__() + self.loss = loss + self.features = resnext101_32x4d_features + self.classifier = nn.Linear(2048, num_classes) + self.feat_dim = 2048 + self.init_params() + + def init_params(self): + """Load ImageNet pretrained weights""" + settings = pretrained_settings['resnext101_32x4d']['imagenet'] + pretrained_dict = model_zoo.load_url(settings['url'], map_location=None) + model_dict = self.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + 
model_dict.update(pretrained_dict) + self.load_state_dict(model_dict) + + def forward(self, input): + x = self.features(input) + x = F.avg_pool2d(x, x.size()[2:]) + x = x.view(x.size(0), -1) + + if not self.training: + return x + + y = self.classifier(x) + + if self.loss == {'xent'}: + return y + elif self.loss == {'xent', 'htri'}: + return y, x + elif self.loss == {'cent'}: + return y, x + elif self.loss == {'ring'}: + return y, x + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + + +class ResNeXt101_64x4d(nn.Module): + """This model is not used""" + def __init__(self, num_classes, loss={'xent'}, **kwargs): + super(ResNeXt101_64x4d, self).__init__() + self.loss = loss + self.features = resnext101_64x4d_features + self.classifier = nn.Linear(2048, num_classes) + self.feat_dim = 2048 + self.init_params() + + def init_params(self): + """Load ImageNet pretrained weights""" + settings = pretrained_settings['resnext101_64x4d']['imagenet'] + pretrained_dict = model_zoo.load_url(settings['url'], map_location=None) + model_dict = self.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict) + + def forward(self, input): + x = self.features(input) + x = F.avg_pool2d(x, x.size()[2:]) + x = x.view(x.size(0), -1) + print(x.size()) + + if not self.training: + return x + + y = self.classifier(x) + print(y.size()) + + if self.loss == {'xent'}: + return y + elif self.loss == {'xent', 'htri'}: + return y, x + elif self.loss == {'cent'}: + return y, x + elif self.loss == {'ring'}: + return y, x + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) \ No newline at end of file diff --git a/models/ResNet.py b/models/ResNet.py index 457aecf..3d25775 100755 --- a/models/ResNet.py +++ b/models/ResNet.py @@ -5,7 +5,7 @@ from torch import nn from torch.nn import functional as F import torchvision -__all__ = ['ResNet50', 'ResNet50M'] +__all__ = 
['ResNet50', 'ResNet101', 'ResNet50M'] class ResNet50(nn.Module): def __init__(self, num_classes, loss={'xent'}, **kwargs): @@ -35,6 +35,34 @@ class ResNet50(nn.Module): else: raise KeyError("Unsupported loss: {}".format(self.loss)) +class ResNet101(nn.Module): + def __init__(self, num_classes, loss={'xent'}, **kwargs): + super(ResNet101, self).__init__() + self.loss = loss + resnet101 = torchvision.models.resnet101(pretrained=True) + self.base = nn.Sequential(*list(resnet101.children())[:-2]) + self.classifier = nn.Linear(2048, num_classes) + self.feat_dim = 2048 # feature dimension + + def forward(self, x): + x = self.base(x) + x = F.avg_pool2d(x, x.size()[2:]) + f = x.view(x.size(0), -1) + if not self.training: + return f + y = self.classifier(f) + + if self.loss == {'xent'}: + return y + elif self.loss == {'xent', 'htri'}: + return y, f + elif self.loss == {'cent'}: + return y, f + elif self.loss == {'ring'}: + return y, f + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + class ResNet50M(nn.Module): """ResNet50 + mid-level features. 
diff --git a/models/SEResNet.py b/models/SEResNet.py index 30f1e5c..7f39fec 100644 --- a/models/SEResNet.py +++ b/models/SEResNet.py @@ -13,7 +13,7 @@ import torchvision Code imported from https://github.com/Cadene/pretrained-models.pytorch """ -__all__ = ['SEResNet50'] +__all__ = ['SEResNet50', 'SEResNet101', 'SEResNeXt50', 'SEResNeXt101'] pretrained_settings = { 'senet154': { @@ -190,7 +190,7 @@ class SEResNeXtBottleneck(Bottleneck): def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None, base_width=4): super(SEResNeXtBottleneck, self).__init__() - width = math.floor(planes * (base_width / 64)) * groups + width = int(math.floor(planes * (base_width / 64.)) * groups) self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1) self.bn1 = nn.BatchNorm2d(width) @@ -442,6 +442,8 @@ def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'): initialize_pretrained_model(model, num_classes, settings) return model +##################### Model Definition ######################### + class SEResNet50(nn.Module): def __init__(self, num_classes, loss={'xent'}, **kwargs): super(SEResNet50, self).__init__() @@ -451,6 +453,90 @@ class SEResNet50(nn.Module): self.classifier = nn.Linear(2048, num_classes) self.feat_dim = 2048 # feature dimension + def forward(self, x): + x = self.base(x) + x = F.avg_pool2d(x, x.size()[2:]) + f = x.view(x.size(0), -1) + if not self.training: + return f + y = self.classifier(f) + + if self.loss == {'xent'}: + return y + elif self.loss == {'xent', 'htri'}: + return y, f + elif self.loss == {'cent'}: + return y, f + elif self.loss == {'ring'}: + return y, f + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + +class SEResNet101(nn.Module): + def __init__(self, num_classes, loss={'xent'}, **kwargs): + super(SEResNet101, self).__init__() + self.loss = loss + base = se_resnet101() + self.base = nn.Sequential(*list(base.children())[:-2]) + self.classifier = nn.Linear(2048, num_classes) + 
self.feat_dim = 2048 # feature dimension + + def forward(self, x): + x = self.base(x) + x = F.avg_pool2d(x, x.size()[2:]) + f = x.view(x.size(0), -1) + if not self.training: + return f + y = self.classifier(f) + + if self.loss == {'xent'}: + return y + elif self.loss == {'xent', 'htri'}: + return y, f + elif self.loss == {'cent'}: + return y, f + elif self.loss == {'ring'}: + return y, f + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + +class SEResNeXt50(nn.Module): + def __init__(self, num_classes, loss={'xent'}, **kwargs): + super(SEResNeXt50, self).__init__() + self.loss = loss + base = se_resnext50_32x4d() + self.base = nn.Sequential(*list(base.children())[:-2]) + self.classifier = nn.Linear(2048, num_classes) + self.feat_dim = 2048 # feature dimension + + def forward(self, x): + x = self.base(x) + x = F.avg_pool2d(x, x.size()[2:]) + f = x.view(x.size(0), -1) + if not self.training: + return f + y = self.classifier(f) + + if self.loss == {'xent'}: + return y + elif self.loss == {'xent', 'htri'}: + return y, f + elif self.loss == {'cent'}: + return y, f + elif self.loss == {'ring'}: + return y, f + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + +class SEResNeXt101(nn.Module): + def __init__(self, num_classes, loss={'xent'}, **kwargs): + super(SEResNeXt101, self).__init__() + self.loss = loss + base = se_resnext101_32x4d() + self.base = nn.Sequential(*list(base.children())[:-2]) + self.classifier = nn.Linear(2048, num_classes) + self.feat_dim = 2048 # feature dimension + def forward(self, x): x = self.base(x) x = F.avg_pool2d(x, x.size()[2:]) diff --git a/models/__init__.py b/models/__init__.py index 8b804b0..ec844de 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,6 +1,8 @@ from __future__ import absolute_import from .ResNet import * +from .ResNeXt import * +from .SEResNet import * from .DenseNet import * from .MuDeep import * from .HACNN import * @@ -9,26 +11,30 @@ from .MobileNet import * from .ShuffleNet 
import * from .Xception import * from .InceptionV4 import * -from .SEResNet import * from .NASNet import * from .DPN import * from .InceptionResNetV2 import * __factory = { 'resnet50': ResNet50, + 'resnet101': ResNet101, + 'seresnet50': SEResNet50, + 'seresnet101': SEResNet101, + 'seresnext50': SEResNeXt50, + 'seresnext101': SEResNeXt101, + 'resnext101': ResNeXt101_32x4d, 'resnet50m': ResNet50M, 'densenet121': DenseNet121, - 'mudeep': MuDeep, - 'hacnn': HACNN, 'squeezenet': SqueezeNet, 'mobilenet': MobileNetV2, 'shufflenet': ShuffleNet, 'xception': Xception, 'inceptionv4': InceptionV4ReID, - 'seresnet50': SEResNet50, 'nasnet': NASNetAMobile, 'dpn92': DPN, 'inceptionresnetv2': InceptionResNetV2, + 'mudeep': MuDeep, + 'hacnn': HACNN, } def get_names():