Refactor for simplification (#9054)
* Refactor for simplification * cleanup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/9055/head
parent
f258cf8b37
commit
c725511bfc
|
@ -46,7 +46,7 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
||||||
except Exception as e: # url2
|
except Exception as e: # url2
|
||||||
file.unlink(missing_ok=True) # remove partial downloads
|
file.unlink(missing_ok=True) # remove partial downloads
|
||||||
LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
|
LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
|
||||||
os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
|
os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
|
||||||
finally:
|
finally:
|
||||||
if not file.exists() or file.stat().st_size < min_bytes: # check
|
if not file.exists() or file.stat().st_size < min_bytes: # check
|
||||||
file.unlink(missing_ok=True) # remove partial downloads
|
file.unlink(missing_ok=True) # remove partial downloads
|
||||||
|
|
|
@ -582,7 +582,7 @@ def url2file(url):
|
||||||
|
|
||||||
|
|
||||||
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
|
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
|
||||||
# Multi-threaded file download and unzip function, used in data.yaml for autodownload
|
# Multithreaded file download and unzip function, used in data.yaml for autodownload
|
||||||
def download_one(url, dir):
|
def download_one(url, dir):
|
||||||
# Download 1 file
|
# Download 1 file
|
||||||
success = True
|
success = True
|
||||||
|
@ -594,7 +594,8 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry
|
||||||
for i in range(retry + 1):
|
for i in range(retry + 1):
|
||||||
if curl:
|
if curl:
|
||||||
s = 'sS' if threads > 1 else '' # silent
|
s = 'sS' if threads > 1 else '' # silent
|
||||||
r = os.system(f'curl -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
|
r = os.system(
|
||||||
|
f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
|
||||||
success = r == 0
|
success = r == 0
|
||||||
else:
|
else:
|
||||||
torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
|
torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
|
||||||
|
|
|
@ -141,7 +141,7 @@ class ConfusionMatrix:
|
||||||
"""
|
"""
|
||||||
if detections is None:
|
if detections is None:
|
||||||
gt_classes = labels.int()
|
gt_classes = labels.int()
|
||||||
for i, gc in enumerate(gt_classes):
|
for gc in gt_classes:
|
||||||
self.matrix[self.nc, gc] += 1 # background FN
|
self.matrix[self.nc, gc] += 1 # background FN
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
Plotting utils
|
Plotting utils
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from copy import copy
|
from copy import copy
|
||||||
|
@ -180,8 +181,7 @@ def output_to_target(output):
|
||||||
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
|
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
|
||||||
targets = []
|
targets = []
|
||||||
for i, o in enumerate(output):
|
for i, o in enumerate(output):
|
||||||
for *box, conf, cls in o.cpu().numpy():
|
targets.extend([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf] for *box, conf, cls in o.cpu().numpy())
|
||||||
targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
|
|
||||||
return np.array(targets)
|
return np.array(targets)
|
||||||
|
|
||||||
|
|
||||||
|
@ -357,10 +357,8 @@ def plot_labels(labels, names=(), save_dir=Path('')):
|
||||||
matplotlib.use('svg') # faster
|
matplotlib.use('svg') # faster
|
||||||
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
||||||
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
||||||
try: # color histogram bars by class
|
with contextlib.suppress(Exception): # color histogram bars by class
|
||||||
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
|
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
ax[0].set_ylabel('instances')
|
ax[0].set_ylabel('instances')
|
||||||
if 0 < len(names) < 30:
|
if 0 < len(names) < 30:
|
||||||
ax[0].set_xticks(range(len(names)))
|
ax[0].set_xticks(range(len(names)))
|
||||||
|
|
|
@ -45,11 +45,10 @@ def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
|
||||||
def smartCrossEntropyLoss(label_smoothing=0.0):
|
def smartCrossEntropyLoss(label_smoothing=0.0):
|
||||||
# Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
|
# Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
|
||||||
if check_version(torch.__version__, '1.10.0'):
|
if check_version(torch.__version__, '1.10.0'):
|
||||||
return nn.CrossEntropyLoss(label_smoothing=label_smoothing) # loss function
|
return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
|
||||||
else:
|
if label_smoothing > 0:
|
||||||
if label_smoothing > 0:
|
LOGGER.warning(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0')
|
||||||
LOGGER.warning(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0')
|
return nn.CrossEntropyLoss()
|
||||||
return nn.CrossEntropyLoss() # loss function
|
|
||||||
|
|
||||||
|
|
||||||
def smart_DDP(model):
|
def smart_DDP(model):
|
||||||
|
@ -118,7 +117,7 @@ def select_device(device='', batch_size=0, newline=True):
|
||||||
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
|
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
|
||||||
f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
|
f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
|
||||||
|
|
||||||
if not (cpu or mps) and torch.cuda.is_available(): # prefer GPU if available
|
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
||||||
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
||||||
n = len(devices) # device count
|
n = len(devices) # device count
|
||||||
if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
|
if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
|
||||||
|
|
Loading…
Reference in New Issue