output_featuremaps_only -> return_featuremaps
parent
0224a5d433
commit
a20f8a66bd
|
@ -324,11 +324,11 @@ class Engine(object):
|
|||
|
||||
# forward to get convolutional feature maps
|
||||
try:
|
||||
outputs = self.model(imgs, output_featuremaps_only=True)
|
||||
outputs = self.model(imgs, return_featuremaps=True)
|
||||
except TypeError:
|
||||
raise TypeError('forward() got unexpected keyword argument "output_featuremaps_only". ' \
|
||||
'Please add output_featuremaps_only as an input argument to forward(). When ' \
|
||||
'output_featuremaps_only=True, return feature maps only.')
|
||||
raise TypeError('forward() got unexpected keyword argument "return_featuremaps". ' \
|
||||
'Please add return_featuremaps as an input argument to forward(). When ' \
|
||||
'return_featuremaps=True, return feature maps only.')
|
||||
|
||||
if outputs.dim() != 4:
|
||||
raise ValueError('The model output is supposed to have ' \
|
||||
|
|
|
@ -292,9 +292,9 @@ class OSNet(nn.Module):
|
|||
x = self.conv5(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, output_featuremaps_only=False):
|
||||
def forward(self, x, return_featuremaps=False):
|
||||
x = self.featuremaps(x)
|
||||
if output_featuremaps_only:
|
||||
if return_featuremaps:
|
||||
return x
|
||||
v = self.global_avgpool(x)
|
||||
v = v.view(v.size(0), -1)
|
||||
|
|
Loading…
Reference in New Issue