[Fix]: fix SWA in pytorch 1.6 (#312)
parent
bc763758d8
commit
e470c3aa1b
|
@ -129,7 +129,7 @@ class StochasticWeightAverage(BaseAveragedModel):
|
||||||
"""
|
"""
|
||||||
averaged_param.add_(
|
averaged_param.add_(
|
||||||
source_param - averaged_param,
|
source_param - averaged_param,
|
||||||
alpha=1 / (steps // self.interval + 1))
|
alpha=1 / float(steps // self.interval + 1))
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
|
Loading…
Reference in New Issue