Refactor for simplification (#9054)

* Refactor for simplification

* cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/9055/head
Glenn Jocher 2022-08-21 01:34:03 +02:00 committed by GitHub
parent f258cf8b37
commit c725511bfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 13 additions and 15 deletions

View File

@ -46,7 +46,7 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
except Exception as e: # url2 except Exception as e: # url2
file.unlink(missing_ok=True) # remove partial downloads file.unlink(missing_ok=True) # remove partial downloads
LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...') LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
finally: finally:
if not file.exists() or file.stat().st_size < min_bytes: # check if not file.exists() or file.stat().st_size < min_bytes: # check
file.unlink(missing_ok=True) # remove partial downloads file.unlink(missing_ok=True) # remove partial downloads

View File

@ -582,7 +582,7 @@ def url2file(url):
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3): def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
# Multi-threaded file download and unzip function, used in data.yaml for autodownload # Multithreaded file download and unzip function, used in data.yaml for autodownload
def download_one(url, dir): def download_one(url, dir):
# Download 1 file # Download 1 file
success = True success = True
@ -594,7 +594,8 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry
for i in range(retry + 1): for i in range(retry + 1):
if curl: if curl:
s = 'sS' if threads > 1 else '' # silent s = 'sS' if threads > 1 else '' # silent
r = os.system(f'curl -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue r = os.system(
f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
success = r == 0 success = r == 0
else: else:
torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download

View File

@ -141,7 +141,7 @@ class ConfusionMatrix:
""" """
if detections is None: if detections is None:
gt_classes = labels.int() gt_classes = labels.int()
for i, gc in enumerate(gt_classes): for gc in gt_classes:
self.matrix[self.nc, gc] += 1 # background FN self.matrix[self.nc, gc] += 1 # background FN
return return

View File

@ -3,6 +3,7 @@
Plotting utils Plotting utils
""" """
import contextlib
import math import math
import os import os
from copy import copy from copy import copy
@ -180,8 +181,7 @@ def output_to_target(output):
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf] # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
targets = [] targets = []
for i, o in enumerate(output): for i, o in enumerate(output):
for *box, conf, cls in o.cpu().numpy(): targets.extend([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf] for *box, conf, cls in o.cpu().numpy())
targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
return np.array(targets) return np.array(targets)
@ -357,10 +357,8 @@ def plot_labels(labels, names=(), save_dir=Path('')):
matplotlib.use('svg') # faster matplotlib.use('svg') # faster
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
try: # color histogram bars by class with contextlib.suppress(Exception): # color histogram bars by class
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195 [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
except Exception:
pass
ax[0].set_ylabel('instances') ax[0].set_ylabel('instances')
if 0 < len(names) < 30: if 0 < len(names) < 30:
ax[0].set_xticks(range(len(names))) ax[0].set_xticks(range(len(names)))

View File

@ -45,11 +45,10 @@ def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
def smartCrossEntropyLoss(label_smoothing=0.0): def smartCrossEntropyLoss(label_smoothing=0.0):
# Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0 # Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
if check_version(torch.__version__, '1.10.0'): if check_version(torch.__version__, '1.10.0'):
return nn.CrossEntropyLoss(label_smoothing=label_smoothing) # loss function return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
else: if label_smoothing > 0:
if label_smoothing > 0: LOGGER.warning(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0')
LOGGER.warning(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0') return nn.CrossEntropyLoss()
return nn.CrossEntropyLoss() # loss function
def smart_DDP(model): def smart_DDP(model):
@ -118,7 +117,7 @@ def select_device(device='', batch_size=0, newline=True):
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \ assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)" f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
if not (cpu or mps) and torch.cuda.is_available(): # prefer GPU if available if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7 devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
n = len(devices) # device count n = len(devices) # device count
if n > 1 and batch_size > 0: # check batch_size is divisible by device_count if n > 1 and batch_size > 0: # check batch_size is divisible by device_count