mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Disable use of timm nn.Linear wrapper since AMP autocast + torchscript use appears fixed
This commit is contained in:
parent
58ffa2bfb7
commit
834a9ec721
@ -6,7 +6,6 @@ from torch import nn as nn
|
|||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||||
from .linear import Linear
|
|
||||||
|
|
||||||
|
|
||||||
def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
|
def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
|
||||||
@ -26,8 +25,7 @@ def _create_fc(num_features, num_classes, use_conv=False):
|
|||||||
elif use_conv:
|
elif use_conv:
|
||||||
fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
|
fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
|
||||||
else:
|
else:
|
||||||
# NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue
|
fc = nn.Linear(num_features, num_classes, bias=True)
|
||||||
fc = Linear(num_features, num_classes, bias=True)
|
|
||||||
return fc
|
return fc
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user