mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Work around casting issue with combination of native torch AMP and torchscript for Linear layers
parent 5f4b6076d8
commit 460eba7f24
@@ -6,6 +6,7 @@ from torch import nn as nn
 from torch.nn import functional as F
 
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .linear import Linear
 
 
 def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
@@ -21,7 +22,8 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
     elif use_conv:
         fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
     else:
-        fc = nn.Linear(num_pooled_features, num_classes, bias=True)
+        # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue
+        fc = Linear(num_pooled_features, num_classes, bias=True)
     return global_pool, fc
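For context, a minimal usage sketch of create_classifier after this change. The import path is an assumption based on the relative imports above (the file header for this hunk is not shown in the diff); the feature/class counts are illustrative only.

from timm.models.layers.classifier import create_classifier  # assumed module path, not shown in the diff

# With use_conv=False the returned fc is now the timm Linear wrapper rather than
# nn.Linear, so a torchscripted head keeps working under native AMP.
global_pool, fc = create_classifier(num_features=2048, num_classes=1000, pool_type='avg')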
timm/models/layers/linear.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+""" Linear layer (alternate definition)
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn as nn
+
+
+class Linear(nn.Linear):
+    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
+
+    Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
+    weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
+    """
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if torch.jit.is_scripting():
+            return F.linear(input, self.weight.to(dtype=input.dtype), self.bias.to(dtype=input.dtype))
+        else:
+            return F.linear(input, self.weight, self.bias)
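A minimal sketch of the situation this wrapper handles (illustrative, not part of the commit): when native AMP and torchscript are combined, a half-precision activation can reach a scripted linear whose weight and bias are still fp32, and torch.addmm then fails on the dtype mismatch. The is_scripting() branch above casts weight and bias to the input dtype instead.

import torch
from timm.models.layers.linear import Linear  # module added by this commit

# Assumed usage, not from the commit: feed an fp16 activation to a scripted Linear,
# as can happen downstream of autocasted ops. The scripted forward casts the fp32
# weight/bias to the input dtype, so F.linear (addmm) sees matching types.
if torch.cuda.is_available():
    fc = torch.jit.script(Linear(512, 10)).cuda()
    x = torch.randn(8, 512, device='cuda', dtype=torch.float16)
    out = fc(x)
    print(out.dtype)  # torch.float16; a plain nn.Linear would raise a dtype mismatch here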
train.py (1 line changed)
@@ -367,7 +367,6 @@ def main():
     if args.torchscript:
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
-        # FIXME I ran into a bug w/ AMP + torchscript + Linear layers
         model = torch.jit.script(model)
 
     optimizer = create_optimizer(args, model)
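With the FIXME removed, --torchscript and native AMP are meant to work together. A rough, self-contained sketch of that combination (illustrative only; this is not the actual train.py code, and the model/shapes are made up):

import torch
import torch.nn as nn
from timm.models.layers.linear import Linear  # AMP/torchscript-safe wrapper from this commit

# Illustrative native-AMP step with a scripted model, the combination the FIXME warned about.
model = torch.jit.script(nn.Sequential(Linear(32, 32), nn.ReLU(), Linear(32, 4)))
if torch.cuda.is_available():
    model = model.cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()
    x = torch.randn(16, 32, device='cuda')
    target = torch.randint(0, 4, (16,), device='cuda')

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = nn.functional.cross_entropy(model(x), target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()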