mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
fix arcSoftmax fp16 training problem
Summary: fixup fp16 training when using arcSoftmax by aligning the data type
This commit is contained in:
parent
20a01f2545
commit
fe2e46d40e
@ -39,12 +39,14 @@ class ArcSoftmax(nn.Module):
|
|||||||
sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
|
sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
|
||||||
cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin)
|
cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin)
|
||||||
mask = cos_theta > cos_theta_m
|
mask = cos_theta > cos_theta_m
|
||||||
final_target_logit = torch.where(target_logit > self.threshold, cos_theta_m, target_logit - self.mm)
|
final_target_logit = torch.where(target_logit > self.threshold,
|
||||||
|
cos_theta_m.to(target_logit),
|
||||||
|
target_logit - self.mm)
|
||||||
|
|
||||||
hard_example = cos_theta[mask]
|
hard_example = cos_theta[mask]
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t
|
self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t
|
||||||
cos_theta[mask] = hard_example * (self.t + hard_example)
|
cos_theta[mask] = hard_example * (self.t + hard_example).to(hard_example.dtype)
|
||||||
cos_theta.scatter_(1, targets.view(-1, 1).long(), final_target_logit)
|
cos_theta.scatter_(1, targets.view(-1, 1).long(), final_target_logit)
|
||||||
pred_class_logits = cos_theta * self.s
|
pred_class_logits = cos_theta * self.s
|
||||||
return pred_class_logits
|
return pred_class_logits
|
||||||
|
@ -19,7 +19,7 @@ __all__ = ["Flatten",
|
|||||||
|
|
||||||
class Flatten(nn.Module):
|
class Flatten(nn.Module):
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return input.view(input.size(0), -1)
|
return input.view(input.size(0), -1, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
class GeneralizedMeanPooling(nn.Module):
|
class GeneralizedMeanPooling(nn.Module):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user