Update export.py with --train mode argument (#3066)
parent
f2de1ad2aa
commit
e97d129db4
|
@ -29,6 +29,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
||||
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
|
||||
parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
|
||||
parser.add_argument('--train', action='store_true', help='model.train() mode')
|
||||
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
|
||||
parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
|
||||
opt = parser.parse_args()
|
||||
|
@ -53,6 +54,8 @@ if __name__ == '__main__':
|
|||
# Update model
|
||||
if opt.half:
|
||||
img, model = img.half(), model.half() # to FP16
|
||||
if opt.train:
|
||||
model.train() # training mode (no grid construction in Detect layer)
|
||||
for k, m in model.named_modules():
|
||||
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
|
||||
if isinstance(m, models.common.Conv): # assign export-friendly activations
|
||||
|
|
Loading…
Reference in New Issue