support loading deit weights (#538)

This commit is contained in:
谢昕辰 2021-05-09 11:34:18 +08:00 committed by GitHub
parent f253451b54
commit 1052f8d5d3
2 changed files with 2 additions and 1 deletions

View File

@ -325,6 +325,8 @@ class VisionTransformer(nn.Module):
checkpoint = _load_checkpoint(pretrained, logger=logger)
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint

View File

@ -54,7 +54,6 @@ class MultiLevelNeck(nn.Module):
def forward(self, inputs):
assert len(inputs) == len(self.in_channels)
print(inputs[0].shape)
inputs = [
lateral_conv(inputs[i])
for i, lateral_conv in enumerate(self.lateral_convs)