Merge remote-tracking branch 'upstream/master'
commit
0c4b4b8817
|
@ -18,7 +18,7 @@ def detect(save_img=False):
|
|||
|
||||
# Load model
|
||||
google_utils.attempt_download(weights)
|
||||
model = torch.load(weights, map_location=device)['model']
|
||||
model = torch.load(weights, map_location=device)['model'].float() # load to FP32
|
||||
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning
|
||||
# model.fuse()
|
||||
model.to(device).eval()
|
||||
|
|
|
@ -32,8 +32,8 @@ def create(name, pretrained, channels, classes):
|
|||
if pretrained:
|
||||
ckpt = '%s.pt' % name # checkpoint filename
|
||||
google_utils.attempt_download(ckpt) # download if not found locally
|
||||
state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].state_dict()
|
||||
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].numel() == v.numel()} # filter
|
||||
state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32
|
||||
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
|
||||
model.load_state_dict(state_dict, strict=False) # load
|
||||
return model
|
||||
|
||||
|
|
18
test.py
18
test.py
|
@ -23,6 +23,7 @@ def test(data,
|
|||
verbose=False):
|
||||
# Initialize/load model and set device
|
||||
if model is None:
|
||||
training = False
|
||||
device = torch_utils.select_device(opt.device, batch_size=batch_size)
|
||||
half = device.type != 'cpu' # half precision only supported on CUDA
|
||||
|
||||
|
@ -32,9 +33,9 @@ def test(data,
|
|||
|
||||
# Load model
|
||||
google_utils.attempt_download(weights)
|
||||
model = torch.load(weights, map_location=device)['model']
|
||||
model = torch.load(weights, map_location=device)['model'].float() # load to FP32
|
||||
torch_utils.model_info(model)
|
||||
# model.fuse()
|
||||
model.fuse()
|
||||
model.to(device)
|
||||
if half:
|
||||
model.half() # to FP16
|
||||
|
@ -42,11 +43,12 @@ def test(data,
|
|||
if device.type != 'cpu' and torch.cuda.device_count() > 1:
|
||||
model = nn.DataParallel(model)
|
||||
|
||||
training = False
|
||||
else: # called by train.py
|
||||
device = next(model.parameters()).device # get model device
|
||||
half = False
|
||||
training = True
|
||||
device = next(model.parameters()).device # get model device
|
||||
half = device.type != 'cpu' # half precision only supported on CUDA
|
||||
if half:
|
||||
model.half() # to FP16
|
||||
|
||||
# Configure
|
||||
model.eval()
|
||||
|
@ -69,7 +71,7 @@ def test(data,
|
|||
batch_size,
|
||||
rect=True, # rectangular inference
|
||||
single_cls=opt.single_cls, # single class mode
|
||||
pad=0.0 if fast else 0.5) # padding
|
||||
pad=0.5) # padding
|
||||
batch_size = min(batch_size, len(dataset))
|
||||
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
|
||||
dataloader = DataLoader(dataset,
|
||||
|
@ -102,7 +104,7 @@ def test(data,
|
|||
|
||||
# Compute loss
|
||||
if training: # if model has loss hyperparameters
|
||||
loss += compute_loss(train_out, targets, model)[1][:3] # GIoU, obj, cls
|
||||
loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # GIoU, obj, cls
|
||||
|
||||
# Run NMS
|
||||
t = torch_utils.time_synchronized()
|
||||
|
@ -255,7 +257,7 @@ if __name__ == '__main__':
|
|||
opt = parser.parse_args()
|
||||
opt.img_size = check_img_size(opt.img_size)
|
||||
opt.save_json = opt.save_json or opt.data.endswith('coco.yaml')
|
||||
opt.data = glob.glob('./**/' + opt.data, recursive=True)[0] # find file
|
||||
opt.data = check_file(opt.data) # check file
|
||||
print(opt)
|
||||
|
||||
# task = 'val', 'test', 'study'
|
||||
|
|
10
train.py
10
train.py
|
@ -112,8 +112,8 @@ def train(hyp):
|
|||
|
||||
# load model
|
||||
try:
|
||||
ckpt['model'] = \
|
||||
{k: v for k, v in ckpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()}
|
||||
ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
|
||||
if model.state_dict()[k].shape == v.shape} # to FP32, filter
|
||||
model.load_state_dict(ckpt['model'], strict=False)
|
||||
except KeyError as e:
|
||||
s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
|
||||
|
@ -363,6 +363,7 @@ def train(hyp):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
check_git_status()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--epochs', type=int, default=300)
|
||||
parser.add_argument('--batch-size', type=int, default=16)
|
||||
|
@ -384,12 +385,11 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
||||
opt = parser.parse_args()
|
||||
opt.weights = last if opt.resume else opt.weights
|
||||
opt.cfg = glob.glob('./**/' + opt.cfg, recursive=True)[0] # find file
|
||||
opt.data = glob.glob('./**/' + opt.data, recursive=True)[0] # find file
|
||||
opt.cfg = check_file(opt.cfg) # check file
|
||||
opt.data = check_file(opt.data) # check file
|
||||
print(opt)
|
||||
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
|
||||
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
|
||||
# check_git_status()
|
||||
if device.type == 'cpu':
|
||||
mixed_precision = False
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
|
|
|
@ -25,10 +25,15 @@ def attempt_download(weights):
|
|||
if file in d:
|
||||
r = gdrive_download(id=d[file], name=weights)
|
||||
|
||||
# Error check
|
||||
if not (r == 0 and os.path.exists(weights) and os.path.getsize(weights) > 1E6): # weights exist and > 1MB
|
||||
os.system('rm ' + weights) # remove partial downloads
|
||||
raise Exception(msg)
|
||||
os.remove(weights) if os.path.exists(weights) else None # remove partial downloads
|
||||
s = "curl -L -o %s 'https://storage.googleapis.com/ultralytics/yolov5/ckpt/%s'" % (weights, file)
|
||||
r = os.system(s) # execute, capture return values
|
||||
|
||||
# Error check
|
||||
if not (r == 0 and os.path.exists(weights) and os.path.getsize(weights) > 1E6): # weights exist and > 1MB
|
||||
os.remove(weights) if os.path.exists(weights) else None # remove partial downloads
|
||||
raise Exception(msg)
|
||||
|
||||
|
||||
def gdrive_download(id='1HaXkef9z6y5l4vUnCYgdmEAj61c6bfWO', name='coco.zip'):
|
||||
|
|
|
@ -64,6 +64,16 @@ def check_best_possible_recall(dataset, anchors, thr):
|
|||
'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr
|
||||
|
||||
|
||||
def check_file(file):
|
||||
# Searches for file if not found locally
|
||||
if os.path.isfile(file):
|
||||
return file
|
||||
else:
|
||||
files = glob.glob('./**/' + file, recursive=True) # find file
|
||||
assert len(files), 'File Not Found: %s' % file # assert file was found
|
||||
return files[0] # return first file if multiple found
|
||||
|
||||
|
||||
def make_divisible(x, divisor):
|
||||
# Returns x evenly divisble by divisor
|
||||
return math.ceil(x / divisor) * divisor
|
||||
|
@ -518,7 +528,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
|
|||
fast |= conf_thres > 0.001 # fast mode
|
||||
if fast:
|
||||
merge = False
|
||||
multi_label = False
|
||||
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
|
||||
else:
|
||||
merge = True # merge for best mAP (adds 0.5ms/img)
|
||||
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
|
||||
|
|
Loading…
Reference in New Issue