mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
support loading deit weights (#538)
This commit is contained in:
parent
f253451b54
commit
1052f8d5d3
@ -325,6 +325,8 @@ class VisionTransformer(nn.Module):
|
||||
checkpoint = _load_checkpoint(pretrained, logger=logger)
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
elif 'model' in checkpoint:
|
||||
state_dict = checkpoint['model']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
|
@ -54,7 +54,6 @@ class MultiLevelNeck(nn.Module):
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == len(self.in_channels)
|
||||
print(inputs[0].shape)
|
||||
inputs = [
|
||||
lateral_conv(inputs[i])
|
||||
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user