classifier, export, torch seed updates

parent c5d2331897
commit 883924d9dc
models/export.py
@@ -6,6 +6,7 @@ Usage:
 import argparse
 import sys
+import time
 
 sys.path.append('./')  # to run '$ python *.py' files in subdirectories
 
 
@@ -15,7 +16,7 @@ import torch.nn as nn
 import models
 from models.experimental import attempt_load
 from utils.activations import Hardswish
-from utils.general import set_logging
+from utils.general import set_logging, check_img_size
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
@@ -26,16 +27,22 @@ if __name__ == '__main__':
     opt.img_size *= 2 if len(opt.img_size) == 1 else 1  # expand
     print(opt)
     set_logging()
+    t = time.time()
 
     # Input
     img = torch.zeros((opt.batch_size, 3, *opt.img_size))  # image size(1,3,320,192) iDetection
 
     # Load PyTorch model
     model = attempt_load(opt.weights, map_location=torch.device('cpu'))  # load FP32 model
     labels = model.names
 
+    # Checks
+    gs = int(max(model.stride))  # grid size (max stride)
+    opt.img_size = [check_img_size(x, gs) for x in opt.img_size]  # verify img_size are gs-multiples
+
     # Update model
     for k, m in model.named_modules():
-        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatability
+        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
         if isinstance(m, models.common.Conv) and isinstance(m.act, nn.Hardswish):
             m.act = Hardswish()  # assign activation
         # if isinstance(m, models.yolo.Detect):
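The new `# Checks` block rounds each requested image size to a multiple of the model's maximum stride. A minimal sketch of the rounding that check_img_size performs, assuming it rounds up to the nearest gs-multiple as its comment implies (the real helper in utils.general may differ in details):

import math

def check_img_size(img_size, s=32):
    # round img_size up to the nearest multiple of stride s
    new_size = int(math.ceil(img_size / s) * s)
    if new_size != img_size:
        print('WARNING: --img-size %g must be a multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size

print(check_img_size(640, 32))  # 640, already a multiple
print(check_img_size(633, 32))  # 640, rounded up

With gs = 32 for a typical YOLOv5 model, both 320 and 192 (the iDetection default above) pass through unchanged.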
@@ -76,7 +83,7 @@ if __name__ == '__main__':
 
         print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
         # convert model from torchscript and apply pixel scaling as per detect.py
-        model = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
+        model = ct.convert(ts, inputs=[ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
         f = opt.weights.replace('.pt', '.mlmodel')  # filename
         model.save(f)
         print('CoreML export success, saved as %s' % f)
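Renaming the CoreML input from 'images' to 'image' changes the key callers must use at prediction time, since coremltools keys inputs by name. A usage sketch, assuming a yolov5s.mlmodel produced by this script at 640x640 (coremltools prediction runs on macOS only):

import coremltools as ct
from PIL import Image

mlmodel = ct.models.MLModel('yolov5s.mlmodel')  # hypothetical output of this export
im = Image.open('bus.jpg').resize((640, 640))  # must match the exported input shape
out = mlmodel.predict({'image': im})  # key must match ct.ImageType(name='image')
print(list(out.keys()))

The scale=1 / 255.0 in the ImageType means the model itself normalizes 0-255 pixel values to 0-1, mirroring the division done in detect.py, so the caller passes a raw image.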
@@ -84,4 +91,4 @@ if __name__ == '__main__':
         print('CoreML export failure: %s' % e)
 
     # Finish
-    print('\nExport complete. Visualize with https://github.com/lutzroeder/netron.')
+    print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t))
utils/general.py
@@ -23,8 +23,7 @@ from scipy.signal import butter, filtfilt
 from tqdm import tqdm
 
 from utils.google_utils import gsutil_getsize
-from utils.torch_utils import init_seeds as init_torch_seeds
-from utils.torch_utils import is_parallel
+from utils.torch_utils import is_parallel, init_torch_seeds
 
 # Set printoptions
 torch.set_printoptions(linewidth=320, precision=5, profile='long')
@@ -56,7 +55,7 @@ def set_logging(rank=-1):
 def init_seeds(seed=0):
     random.seed(seed)
     np.random.seed(seed)
-    init_torch_seeds(seed=seed)
+    init_torch_seeds(seed)
 
 
 def get_latest_run(search_dir='./runs'):
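With the rename, init_seeds in utils.general stays the single entry point that seeds all three RNG sources in one call. A small sketch of the behavior this buys, assuming the YOLOv5 layout:

import random
import numpy as np
import torch
from utils.general import init_seeds

init_seeds(2)  # seeds Python's random, NumPy, and (via init_torch_seeds) PyTorch
a = (random.random(), np.random.rand(), torch.rand(1).item())
init_seeds(2)  # re-seeding replays the identical sequence
b = (random.random(), np.random.rand(), torch.rand(1).item())
assert a == b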
utils/torch_utils.py
@@ -8,12 +8,11 @@ import torch
 import torch.backends.cudnn as cudnn
 import torch.nn as nn
 import torch.nn.functional as F
-import torchvision.models as models
 
 logger = logging.getLogger(__name__)
 
 
-def init_seeds(seed=0):
+def init_torch_seeds(seed=0):
     torch.manual_seed(seed)
 
     # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
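Only the tradeoff comment is visible in this hunk; the lines after it toggle the cudnn flags. A sketch of the full function as it reads in YOLOv5 around this commit, hedged since the cudnn branch sits outside the diff context:

import torch
import torch.backends.cudnn as cudnn

def init_torch_seeds(seed=0):
    torch.manual_seed(seed)
    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
    if seed == 0:  # slower, more reproducible
        cudnn.deterministic = True
        cudnn.benchmark = False
    else:  # faster, less reproducible
        cudnn.deterministic = False
        cudnn.benchmark = True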
@@ -152,16 +151,15 @@ def model_info(model, verbose=False):
 
 def load_classifier(name='resnet101', n=2):
     # Loads a pretrained model reshaped to n-class output
-    model = models.__dict__[name](pretrained=True)
+    import torchvision
+    model = torchvision.models.__dict__[name](pretrained=True)
 
-    # Display model properties
-    input_size = [3, 224, 224]
-    input_space = 'RGB'
-    input_range = [0, 1]
-    mean = [0.485, 0.456, 0.406]
-    std = [0.229, 0.224, 0.225]
-    for x in ['input_size', 'input_space', 'input_range', 'mean', 'std']:
-        print(x + ' =', eval(x))
+    # ResNet model properties
+    # input_size = [3, 224, 224]
+    # input_space = 'RGB'
+    # input_range = [0, 1]
+    # mean = [0.485, 0.456, 0.406]
+    # std = [0.229, 0.224, 0.225]
 
     # Reshape output to n classes
     filters = model.fc.weight.shape[1]
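Below this hunk the function rebuilds the fc head for n classes using the filters count above. A self-contained sketch of the whole helper plus a usage call; the head re-initialization shown is an assumption based on the '# Reshape output to n classes' comment, not part of the diff:

import torch
import torch.nn as nn
import torchvision

def load_classifier(name='resnet101', n=2):
    # pretrained ImageNet backbone with its fc head re-initialized for n classes
    model = torchvision.models.__dict__[name](pretrained=True)
    filters = model.fc.weight.shape[1]  # fc input features (2048 for resnet101)
    model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
    model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
    model.fc.out_features = n
    return model

classifier = load_classifier('resnet101', n=2)
y = classifier(torch.zeros(1, 3, 224, 224))  # input_size from the comments above
assert tuple(y.shape) == (1, 2)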