Squeezenet reshape outputs fix (#10222)
@AyushExel Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/9492/head^2
parent
40bb8030f8
commit
72cad39854
utils
|
@ -82,7 +82,7 @@ def reshape_classifier_output(model, n=1000):
|
|||
elif nn.Conv2d in types:
|
||||
i = types.index(nn.Conv2d) # nn.Conv2d index
|
||||
if m[i].out_channels != n:
|
||||
m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias)
|
||||
m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
Loading…
Reference in New Issue