Update export.py, yolo.py `sys.path.append()` (#3579)
parent
095197bd4a
commit
53ed872c28
|
@ -9,13 +9,15 @@ import sys
|
|||
import time
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(Path(__file__).parent.parent.absolute().__str__()) # to run '$ python *.py' files in subdirectories
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.mobile_optimizer import optimize_for_mobile
|
||||
|
||||
import models
|
||||
FILE = Path(__file__).absolute()
|
||||
sys.path.append(FILE.parents[1].as_posix()) # add yolov5/ to path
|
||||
|
||||
from models.common import Conv
|
||||
from models.yolo import Detect
|
||||
from models.experimental import attempt_load
|
||||
from utils.activations import Hardswish, SiLU
|
||||
from utils.general import colorstr, check_img_size, check_requirements, file_size, set_logging
|
||||
|
@ -56,12 +58,12 @@ def export(weights='./yolov5s.pt', # weights path
|
|||
model.train() if train else model.eval() # training mode = no Detect() layer grid construction
|
||||
for k, m in model.named_modules():
|
||||
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
|
||||
if isinstance(m, models.common.Conv): # assign export-friendly activations
|
||||
if isinstance(m, Conv): # assign export-friendly activations
|
||||
if isinstance(m.act, nn.Hardswish):
|
||||
m.act = Hardswish()
|
||||
elif isinstance(m.act, nn.SiLU):
|
||||
m.act = SiLU()
|
||||
elif isinstance(m, models.yolo.Detect):
|
||||
elif isinstance(m, Detect):
|
||||
m.inplace = inplace
|
||||
m.onnx_dynamic = dynamic
|
||||
# m.forward = m.forward_export # assign forward (optional)
|
||||
|
|
|
@ -10,8 +10,8 @@ import sys
|
|||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(Path(__file__).parent.parent.absolute().__str__()) # to run '$ python *.py' files in subdirectories
|
||||
logger = logging.getLogger(__name__)
|
||||
FILE = Path(__file__).absolute()
|
||||
sys.path.append(FILE.parents[1].as_posix()) # add yolov5/ to path
|
||||
|
||||
from models.common import *
|
||||
from models.experimental import *
|
||||
|
@ -25,6 +25,8 @@ try:
|
|||
except ImportError:
|
||||
thop = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Detect(nn.Module):
|
||||
stride = None # strides computed during build
|
||||
|
|
Loading…
Reference in New Issue