mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Better handling of crossvit for tests / forward_features, fix torchscript regression in my changes
This commit is contained in:
parent
702982d8af
commit
7ab2491ab7
@ -188,25 +188,22 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
|
|||||||
|
|
||||||
input_tensor = torch.randn((batch_size, *input_size))
|
input_tensor = torch.randn((batch_size, *input_size))
|
||||||
|
|
||||||
# test forward_features (always unpooled)
|
outputs = model.forward_features(input_tensor)
|
||||||
if 'crossvit' not in model_name:
|
if isinstance(outputs, (tuple, list)):
|
||||||
# FIXME remove crossvit exception
|
outputs = outputs[0]
|
||||||
outputs = model.forward_features(input_tensor)
|
assert outputs.shape[1] == model.num_features
|
||||||
if isinstance(outputs, tuple):
|
|
||||||
outputs = outputs[0]
|
|
||||||
assert outputs.shape[1] == model.num_features
|
|
||||||
|
|
||||||
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
|
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
|
||||||
model.reset_classifier(0)
|
model.reset_classifier(0)
|
||||||
outputs = model.forward(input_tensor)
|
outputs = model.forward(input_tensor)
|
||||||
if isinstance(outputs, tuple):
|
if isinstance(outputs, (tuple, list)):
|
||||||
outputs = outputs[0]
|
outputs = outputs[0]
|
||||||
assert len(outputs.shape) == 2
|
assert len(outputs.shape) == 2
|
||||||
assert outputs.shape[1] == model.num_features
|
assert outputs.shape[1] == model.num_features
|
||||||
|
|
||||||
model = create_model(model_name, pretrained=False, num_classes=0).eval()
|
model = create_model(model_name, pretrained=False, num_classes=0).eval()
|
||||||
outputs = model.forward(input_tensor)
|
outputs = model.forward(input_tensor)
|
||||||
if isinstance(outputs, tuple):
|
if isinstance(outputs, (tuple, list)):
|
||||||
outputs = outputs[0]
|
outputs = outputs[0]
|
||||||
assert len(outputs.shape) == 2
|
assert len(outputs.shape) == 2
|
||||||
assert outputs.shape[1] == model.num_features
|
assert outputs.shape[1] == model.num_features
|
||||||
|
@ -268,12 +268,9 @@ class CrossViT(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
if not isinstance(img_size, (tuple, list)):
|
self.img_size = to_2tuple(img_size)
|
||||||
img_size = to_2tuple(img_size)
|
img_scale = to_2tuple(img_scale)
|
||||||
self.img_size = img_size
|
self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
|
||||||
if not isinstance(img_scale, (tuple, list)):
|
|
||||||
img_scale = to_2tuple(img_scale)
|
|
||||||
self.img_size_scaled = [tuple([int(sj * si) for sj in img_size]) for si in img_scale]
|
|
||||||
num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
|
num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
|
||||||
self.num_branches = len(patch_size)
|
self.num_branches = len(patch_size)
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
@ -346,7 +343,7 @@ class CrossViT(nn.Module):
|
|||||||
xs = []
|
xs = []
|
||||||
for i, patch_embed in enumerate(self.patch_embed):
|
for i, patch_embed in enumerate(self.patch_embed):
|
||||||
ss = self.img_size_scaled[i]
|
ss = self.img_size_scaled[i]
|
||||||
x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic') if H != ss[0] else x
|
x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False) if H != ss[0] else x
|
||||||
tmp = patch_embed(x_)
|
tmp = patch_embed(x_)
|
||||||
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
|
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
|
||||||
cls_tokens = cls_tokens.expand(B, -1, -1)
|
cls_tokens = cls_tokens.expand(B, -1, -1)
|
||||||
@ -361,15 +358,12 @@ class CrossViT(nn.Module):
|
|||||||
|
|
||||||
# NOTE: was before branch token section, move to here to assure all branch token are before layer norm
|
# NOTE: was before branch token section, move to here to assure all branch token are before layer norm
|
||||||
xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
|
xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
|
||||||
return tuple([x[:, 0] for x in xs])
|
return [x[:, 0] for x in xs]
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
xs = self.forward_features(x)
|
xs = self.forward_features(x)
|
||||||
ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
|
ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
|
||||||
if isinstance(self.head[0], nn.Identity):
|
if not isinstance(self.head[0], nn.Identity):
|
||||||
# FIXME to pass current passthrough features tests, could use better approach
|
|
||||||
ce_logits = tuple(ce_logits)
|
|
||||||
else:
|
|
||||||
ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
|
ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
|
||||||
return ce_logits
|
return ce_logits
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user