mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix pooling in mnasnet, more sensible default for AMP opt level
This commit is contained in:
parent
996c77aa94
commit
e9c7961efc
@ -49,7 +49,7 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
|
||||
return x
|
||||
|
||||
|
||||
class AdaptiveAvgMaxPool2d(torch.nn.Module):
|
||||
class AdaptiveAvgMaxPool2d(nn.Module):
|
||||
def __init__(self, output_size=1):
|
||||
super(AdaptiveAvgMaxPool2d, self).__init__()
|
||||
self.output_size = output_size
|
||||
@ -58,7 +58,7 @@ class AdaptiveAvgMaxPool2d(torch.nn.Module):
|
||||
return adaptive_avgmax_pool2d(x, self.output_size)
|
||||
|
||||
|
||||
class AdaptiveCatAvgMaxPool2d(torch.nn.Module):
|
||||
class AdaptiveCatAvgMaxPool2d(nn.Module):
|
||||
def __init__(self, output_size=1):
|
||||
super(AdaptiveCatAvgMaxPool2d, self).__init__()
|
||||
self.output_size = output_size
|
||||
@ -67,7 +67,7 @@ class AdaptiveCatAvgMaxPool2d(torch.nn.Module):
|
||||
return adaptive_catavgmax_pool2d(x, self.output_size)
|
||||
|
||||
|
||||
class SelectAdaptivePool2d(torch.nn.Module):
|
||||
class SelectAdaptivePool2d(nn.Module):
|
||||
"""Selectable global pooling layer with dynamic input kernel size
|
||||
"""
|
||||
def __init__(self, output_size=1, pool_type='avg'):
|
||||
|
@ -185,7 +185,6 @@ class MnasBlock(nn.Module):
|
||||
# Pointwise projection
|
||||
x = self.conv_project(x)
|
||||
x = self.bn2(x)
|
||||
# Residual
|
||||
if self.has_residual:
|
||||
return x + residual
|
||||
else:
|
||||
@ -268,7 +267,7 @@ class MnasNet(nn.Module):
|
||||
x = self.bn1(x)
|
||||
x = self.act_fn(x)
|
||||
if pool:
|
||||
x = self.avg_pool(x)
|
||||
x = self.global_pool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
return x
|
||||
|
||||
|
5
train.py
5
train.py
@ -156,6 +156,9 @@ def main():
|
||||
global_pool=args.gp,
|
||||
checkpoint_path=args.initial_checkpoint)
|
||||
|
||||
print('Model %s created, param count: %d' %
|
||||
(args.model, sum([m.numel() for m in model.parameters()])))
|
||||
|
||||
data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)
|
||||
|
||||
# optionally resume from a checkpoint
|
||||
@ -178,7 +181,7 @@ def main():
|
||||
optimizer.load_state_dict(optimizer_state)
|
||||
|
||||
if has_apex and args.amp:
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level='O3')
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
||||
use_amp = True
|
||||
print('AMP enabled')
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user