Update export.py with v3.0 Hardswish() support (#831)
parent
4d7f222f73
commit
4fb8cb353f
|
@ -7,6 +7,7 @@ Usage:
|
|||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from models.common import Conv
|
||||
from models.experimental import attempt_load
|
||||
|
@ -32,7 +33,7 @@ if __name__ == '__main__':
|
|||
# Update model
|
||||
for k, m in model.named_modules():
|
||||
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
|
||||
if isinstance(m, Conv):
|
||||
if isinstance(m, Conv) and isinstance(m.act, nn.Hardswish):
|
||||
m.act = Hardswish() # assign activation
|
||||
# if isinstance(m, Detect):
|
||||
# m.forward = m.forward_export # assign forward (optional)
|
||||
|
|
Loading…
Reference in New Issue