Update export.py, yolo.py `sys.path.append()` (#3579)

pull/3585/head
Glenn Jocher 2021-06-10 15:35:22 +02:00 committed by GitHub
parent 095197bd4a
commit 53ed872c28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 7 deletions

View File

@ -9,13 +9,15 @@ import sys
import time import time
from pathlib import Path from pathlib import Path
sys.path.append(Path(__file__).parent.parent.absolute().__str__()) # to run '$ python *.py' files in subdirectories
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile from torch.utils.mobile_optimizer import optimize_for_mobile
import models FILE = Path(__file__).absolute()
sys.path.append(FILE.parents[1].as_posix()) # add yolov5/ to path
from models.common import Conv
from models.yolo import Detect
from models.experimental import attempt_load from models.experimental import attempt_load
from utils.activations import Hardswish, SiLU from utils.activations import Hardswish, SiLU
from utils.general import colorstr, check_img_size, check_requirements, file_size, set_logging from utils.general import colorstr, check_img_size, check_requirements, file_size, set_logging
@ -56,12 +58,12 @@ def export(weights='./yolov5s.pt', # weights path
model.train() if train else model.eval() # training mode = no Detect() layer grid construction model.train() if train else model.eval() # training mode = no Detect() layer grid construction
for k, m in model.named_modules(): for k, m in model.named_modules():
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
if isinstance(m, models.common.Conv): # assign export-friendly activations if isinstance(m, Conv): # assign export-friendly activations
if isinstance(m.act, nn.Hardswish): if isinstance(m.act, nn.Hardswish):
m.act = Hardswish() m.act = Hardswish()
elif isinstance(m.act, nn.SiLU): elif isinstance(m.act, nn.SiLU):
m.act = SiLU() m.act = SiLU()
elif isinstance(m, models.yolo.Detect): elif isinstance(m, Detect):
m.inplace = inplace m.inplace = inplace
m.onnx_dynamic = dynamic m.onnx_dynamic = dynamic
# m.forward = m.forward_export # assign forward (optional) # m.forward = m.forward_export # assign forward (optional)

View File

@ -10,8 +10,8 @@ import sys
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
sys.path.append(Path(__file__).parent.parent.absolute().__str__()) # to run '$ python *.py' files in subdirectories FILE = Path(__file__).absolute()
logger = logging.getLogger(__name__) sys.path.append(FILE.parents[1].as_posix()) # add yolov5/ to path
from models.common import * from models.common import *
from models.experimental import * from models.experimental import *
@ -25,6 +25,8 @@ try:
except ImportError: except ImportError:
thop = None thop = None
logger = logging.getLogger(__name__)
class Detect(nn.Module): class Detect(nn.Module):
stride = None # strides computed during build stride = None # strides computed during build