diff --git a/fastreid/layers/arc_softmax.py b/fastreid/layers/arc_softmax.py
index 485444f..f806141 100644
--- a/fastreid/layers/arc_softmax.py
+++ b/fastreid/layers/arc_softmax.py
@@ -39,12 +39,14 @@ class ArcSoftmax(nn.Module):
         sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
         cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m  # cos(target+margin)
         mask = cos_theta > cos_theta_m
-        final_target_logit = torch.where(target_logit > self.threshold, cos_theta_m, target_logit - self.mm)
+        final_target_logit = torch.where(target_logit > self.threshold,
+                                         cos_theta_m.to(target_logit),
+                                         target_logit - self.mm)
 
         hard_example = cos_theta[mask]
         with torch.no_grad():
             self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t
-        cos_theta[mask] = hard_example * (self.t + hard_example)
+        cos_theta[mask] = hard_example * (self.t + hard_example).to(hard_example.dtype)
         cos_theta.scatter_(1, targets.view(-1, 1).long(), final_target_logit)
         pred_class_logits = cos_theta * self.s
         return pred_class_logits
diff --git a/fastreid/layers/pooling.py b/fastreid/layers/pooling.py
index 505a332..ec7796f 100644
--- a/fastreid/layers/pooling.py
+++ b/fastreid/layers/pooling.py
@@ -19,7 +19,7 @@ __all__ = ["Flatten",
 
 class Flatten(nn.Module):
     def forward(self, input):
-        return input.view(input.size(0), -1)
+        return input.view(input.size(0), -1, 1, 1)
 
 
 class GeneralizedMeanPooling(nn.Module):