Squeezenet reshape outputs fix ()

@AyushExel

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/9492/head^2
Glenn Jocher 2022-11-19 16:44:56 +01:00 committed by GitHub
parent 40bb8030f8
commit 72cad39854
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -82,7 +82,7 @@ def reshape_classifier_output(model, n=1000):
elif nn.Conv2d in types:
i = types.index(nn.Conv2d) # nn.Conv2d index
if m[i].out_channels != n:
m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias)
m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
@contextmanager