Update general.py (#9454)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/9458/head
Glenn Jocher 2022-09-17 15:42:24 +02:00 committed by GitHub
parent 4a4308001c
commit 6a9fffd19a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 7 deletions

View File

@ -469,8 +469,7 @@ def check_dataset(data, autodownload=True):
# Read yaml (optional)
if isinstance(data, (str, Path)):
with open(data, errors='ignore') as f:
data = yaml.safe_load(f) # dictionary
data = yaml_load(data) # dictionary
# Checks
for k in 'train', 'val', 'names':
@ -485,7 +484,13 @@ def check_dataset(data, autodownload=True):
path = (ROOT / path).resolve()
for k in 'train', 'val', 'test':
if data.get(k): # prepend path
data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
if isinstance(data[k], str):
x = (path / data[k]).resolve()
if not x.exists() and data[k].startswith('../'):
x = (path / data[k][3:]).resolve()
data[k] = str(x)
else:
data[k] = [str((path / x).resolve()) for x in data[k]]
# Parse yaml
train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
@ -496,13 +501,12 @@ def check_dataset(data, autodownload=True):
if not s or not autodownload:
raise Exception('Dataset not found ❌')
t = time.time()
root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
if s.startswith('http') and s.endswith('.zip'): # URL
f = Path(s).name # filename
LOGGER.info(f'Downloading {s} to {f}...')
torch.hub.download_url_to_file(s, f)
Path(root).mkdir(parents=True, exist_ok=True) # create root
ZipFile(f).extractall(path=root) # unzip
Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
ZipFile(f).extractall(path=DATASETS_DIR) # unzip
Path(f).unlink() # remove zip
r = None # success
elif s.startswith('bash '): # bash script
@ -511,7 +515,7 @@ def check_dataset(data, autodownload=True):
else: # python script
r = exec(s, {'yaml': data}) # return None
dt = f'({round(time.time() - t, 1)}s)'
s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt}"
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt}"
LOGGER.info(f"Dataset download {s}")
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
return data # dictionary