From 1052f8d5d32d8c6fa034a6a816b0706778f625f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= Date: Sun, 9 May 2021 11:34:18 +0800 Subject: [PATCH] support loading deit weights (#538) --- mmseg/models/backbones/vit.py | 2 ++ mmseg/models/necks/multilevel_neck.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py index 377685722..781c9c1cc 100644 --- a/mmseg/models/backbones/vit.py +++ b/mmseg/models/backbones/vit.py @@ -325,6 +325,8 @@ class VisionTransformer(nn.Module): checkpoint = _load_checkpoint(pretrained, logger=logger) if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] else: state_dict = checkpoint diff --git a/mmseg/models/necks/multilevel_neck.py b/mmseg/models/necks/multilevel_neck.py index 7e13813b1..941b82992 100644 --- a/mmseg/models/necks/multilevel_neck.py +++ b/mmseg/models/necks/multilevel_neck.py @@ -54,7 +54,6 @@ class MultiLevelNeck(nn.Module): def forward(self, inputs): assert len(inputs) == len(self.in_channels) - print(inputs[0].shape) inputs = [ lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)