mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix backward in hgnet
This commit is contained in:
parent
6cd28bc5c2
commit
6862c9850a
@ -23,8 +23,8 @@ class LearnableAffineBlock(nn.Module):
|
|||||||
scale_value=1.0,
|
scale_value=1.0,
|
||||||
bias_value=0.0):
|
bias_value=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale = nn.Parameter(torch.tensor([scale_value]))
|
self.scale = nn.Parameter(torch.tensor([scale_value]), requires_grad=True)
|
||||||
self.bias = nn.Parameter(torch.tensor([bias_value]))
|
self.bias = nn.Parameter(torch.tensor([bias_value]), requires_grad=True)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.scale * x + self.bias
|
return self.scale * x + self.bias
|
||||||
@ -262,7 +262,7 @@ class HGBlock(nn.Module):
|
|||||||
x = torch.cat(output, dim=1)
|
x = torch.cat(output, dim=1)
|
||||||
x = self.aggregation(x)
|
x = self.aggregation(x)
|
||||||
if self.residual:
|
if self.residual:
|
||||||
x += identity
|
x = x + identity
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user