fix backward in hgnet

This commit is contained in:
SeeFun 2023-12-27 16:49:37 +08:00 committed by GitHub
parent 6cd28bc5c2
commit 6862c9850a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,8 +23,8 @@ class LearnableAffineBlock(nn.Module):
scale_value=1.0,
bias_value=0.0):
super().__init__()
self.scale = nn.Parameter(torch.tensor([scale_value]))
self.bias = nn.Parameter(torch.tensor([bias_value]))
self.scale = nn.Parameter(torch.tensor([scale_value]), requires_grad=True)
self.bias = nn.Parameter(torch.tensor([bias_value]), requires_grad=True)
def forward(self, x):
return self.scale * x + self.bias
@ -262,7 +262,7 @@ class HGBlock(nn.Module):
x = torch.cat(output, dim=1)
x = self.aggregation(x)
if self.residual:
x += identity
x = x + identity
return x