Remove hook `torch.nan_to_num(x)` (#8826)
* Remove hook `torch.nan_to_num(x)`. Observed erratic training behavior (green line) with the nan_to_num hook in the classifier branch. I'm going to remove it from master. * Update train.py (pull/8827/head)
parent
59578f2782
commit
f3c78a387e
2
train.py
2
train.py
|
@ -131,7 +131,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|||
freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
|
||||
for k, v in model.named_parameters():
|
||||
v.requires_grad = True # train all layers
|
||||
v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0.0
|
||||
# v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results)
|
||||
if any(x in k for x in freeze):
|
||||
LOGGER.info(f'freezing {k}')
|
||||
v.requires_grad = False
|
||||
|
|
Loading…
Reference in New Issue