diff --git a/modeling/baseline.py b/modeling/baseline.py index 606c459..e59fa2e 100644 --- a/modeling/baseline.py +++ b/modeling/baseline.py @@ -171,6 +171,6 @@ class Baseline(nn.Module): def load_param(self, trained_path): param_dict = torch.load(trained_path) for i in param_dict: - # if 'classifier' in i: - # continue + if 'classifier' in i: + continue self.state_dict()[i].copy_(param_dict[i])