diff --git a/models/export.py b/models/export.py index 00e3b2a4f..e870deca8 100644 --- a/models/export.py +++ b/models/export.py @@ -7,6 +7,7 @@ Usage: import argparse import torch +import torch.nn as nn from models.common import Conv from models.experimental import attempt_load @@ -32,7 +33,7 @@ if __name__ == '__main__': # Update model for k, m in model.named_modules(): m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability - if isinstance(m, Conv): + if isinstance(m, Conv) and isinstance(m.act, nn.Hardswish): m.act = Hardswish() # assign activation # if isinstance(m, Detect): # m.forward = m.forward_export # assign forward (optional)