diff --git a/mmyolo/utils/boxam_utils.py b/mmyolo/utils/boxam_utils.py index 4a46f21c..50d6c09e 100644 --- a/mmyolo/utils/boxam_utils.py +++ b/mmyolo/utils/boxam_utils.py @@ -202,8 +202,10 @@ class BoxAMDetectorWrapper(nn.Module): if self.is_need_loss: # Maybe this is a direction that can be optimized # self.detector.init_weights() - - self.detector.bbox_head.head_module.training = True + if hasattr(self.detector.bbox_head, 'head_module'): + self.detector.bbox_head.head_module.training = True + else: + self.detector.bbox_head.training = True if hasattr(self.detector.bbox_head, 'featmap_sizes'): # Prevent the model algorithm error when calculating loss self.detector.bbox_head.featmap_sizes = None @@ -219,7 +221,10 @@ class BoxAMDetectorWrapper(nn.Module): return [loss] else: - self.detector.bbox_head.head_module.training = False + if hasattr(self.detector.bbox_head, 'head_module'): + self.detector.bbox_head.head_module.training = False + else: + self.detector.bbox_head.training = False with torch.no_grad(): results = self.detector.test_step(self.input_data) return results