add more distillation models
parent 5a15c16581
commit 6cae5aafa1
@@ -42,20 +42,25 @@ class ResNet50_vd_distill_MobileNetV3_large_x1_0(fluid.dygraph.Layer):
     def forward(self, input):
         teacher_label = self.teacher(input)
         teacher_label.stop_gradient = True

         student_label = self.student(input)

         return teacher_label, student_label


-class ResNeXt101_32x16d_wsl_distill_ResNet50_vd():
-    def net(self, input, class_dim=1000):
-        # student
-        student = ResNet50_vd()
-        out_student = student.net(input, class_dim=class_dim)
-        # teacher
-        teacher = ResNeXt101_32x16d_wsl()
-        out_teacher = teacher.net(input, class_dim=class_dim)
-        out_teacher.stop_gradient = True
-
-        return out_teacher, out_student
+class ResNeXt101_32x16d_wsl_distill_ResNet50_vd(fluid.dygraph.Layer):
+    def __init__(self, class_dim=1000, **args):
+        super(ResNeXt101_32x16d_wsl_distill_ResNet50_vd, self).__init__()
+
+        self.teacher = ResNeXt101_32x16d_wsl(class_dim=class_dim, **args)
+        self.student = ResNet50_vd(class_dim=class_dim, **args)
+
+    def forward(self, input):
+        teacher_label = self.teacher(input)
+        teacher_label.stop_gradient = True
+
+        student_label = self.student(input)
+
+        return teacher_label, student_label
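The distillation wrappers now return the pair (teacher_label, student_label) from a single forward pass, with gradients blocked on the teacher branch. A minimal usage sketch in dygraph mode follows; the import path and the soft-label cross-entropy loss are assumptions made for illustration, not taken from this commit:

import numpy as np
import paddle.fluid as fluid

# Hypothetical import path; the actual module layout in the repo may differ.
from ppcls.modeling.architectures.distillation_models import (
    ResNet50_vd_distill_MobileNetV3_large_x1_0)

with fluid.dygraph.guard():
    model = ResNet50_vd_distill_MobileNetV3_large_x1_0(class_dim=1000)
    x = fluid.dygraph.to_variable(
        np.random.rand(8, 3, 224, 224).astype("float32"))

    # One forward pass yields both outputs; the teacher output carries
    # stop_gradient, so only the student branch receives gradients.
    teacher_label, student_label = model(x)

    # Soft-label distillation loss: the student matches the teacher's
    # softened output distribution (a common choice, not necessarily the
    # loss used elsewhere in this repository).
    loss = fluid.layers.cross_entropy(
        input=fluid.layers.softmax(student_label),
        label=fluid.layers.softmax(teacher_label),
        soft_label=True)
    avg_loss = fluid.layers.mean(loss)
    avg_loss.backward()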
@@ -58,7 +58,6 @@ def load_dygraph_pretrain(
     model_dict = model.state_dict()
     for key in model_dict.keys():
         weight_name = model_dict[key].name
-        print("dyg key: {}, weight_name: {}".format(key, weight_name))
         if weight_name in pre_state_dict.keys():
             print('Load weight: {}, shape: {}'.format(
                 weight_name, pre_state_dict[weight_name].shape))
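This hunk drops a per-key debug print from load_dygraph_pretrain; the remaining loop matches pretrained static-graph parameters to dygraph parameters by their underlying weight names. A rough sketch of that matching pattern, assuming the weights were saved in a static-graph format readable by fluid.load_program_state (the real helper in this repo may differ in details):

import paddle.fluid as fluid

def load_dygraph_pretrain_sketch(model, path):
    # Load static-graph parameters as a dict of parameter name -> numpy array.
    pre_state_dict = fluid.load_program_state(path)

    param_state_dict = {}
    model_dict = model.state_dict()
    for key in model_dict.keys():
        # Dygraph state-dict keys (e.g. "conv.weight") differ from the
        # parameter names used when the weights were saved, so match on
        # the parameter's .name attribute instead.
        weight_name = model_dict[key].name
        if weight_name in pre_state_dict:
            param_state_dict[key] = pre_state_dict[weight_name]
        else:
            # Keep the randomly initialized value when no pretrained
            # weight with that name exists.
            param_state_dict[key] = model_dict[key].numpy()
    model.set_dict(param_state_dict)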