mirror of https://github.com/open-mmlab/mmyolo.git
fix gardcam error using in mmdet (#779)
parent
a3b6ae1d65
commit
1e5372e993
|
@ -202,8 +202,10 @@ class BoxAMDetectorWrapper(nn.Module):
|
||||||
if self.is_need_loss:
|
if self.is_need_loss:
|
||||||
# Maybe this is a direction that can be optimized
|
# Maybe this is a direction that can be optimized
|
||||||
# self.detector.init_weights()
|
# self.detector.init_weights()
|
||||||
|
if hasattr(self.detector.bbox_head, 'head_module'):
|
||||||
self.detector.bbox_head.head_module.training = True
|
self.detector.bbox_head.head_module.training = True
|
||||||
|
else:
|
||||||
|
self.detector.bbox_head.training = True
|
||||||
if hasattr(self.detector.bbox_head, 'featmap_sizes'):
|
if hasattr(self.detector.bbox_head, 'featmap_sizes'):
|
||||||
# Prevent the model algorithm error when calculating loss
|
# Prevent the model algorithm error when calculating loss
|
||||||
self.detector.bbox_head.featmap_sizes = None
|
self.detector.bbox_head.featmap_sizes = None
|
||||||
|
@ -219,7 +221,10 @@ class BoxAMDetectorWrapper(nn.Module):
|
||||||
|
|
||||||
return [loss]
|
return [loss]
|
||||||
else:
|
else:
|
||||||
|
if hasattr(self.detector.bbox_head, 'head_module'):
|
||||||
self.detector.bbox_head.head_module.training = False
|
self.detector.bbox_head.head_module.training = False
|
||||||
|
else:
|
||||||
|
self.detector.bbox_head.training = False
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
results = self.detector.test_step(self.input_data)
|
results = self.detector.test_step(self.input_data)
|
||||||
return results
|
return results
|
||||||
|
|
Loading…
Reference in New Issue