Update general.py (#9454)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/9458/head
parent
4a4308001c
commit
6a9fffd19a
|
@ -469,8 +469,7 @@ def check_dataset(data, autodownload=True):
|
|||
|
||||
# Read yaml (optional)
|
||||
if isinstance(data, (str, Path)):
|
||||
with open(data, errors='ignore') as f:
|
||||
data = yaml.safe_load(f) # dictionary
|
||||
data = yaml_load(data) # dictionary
|
||||
|
||||
# Checks
|
||||
for k in 'train', 'val', 'names':
|
||||
|
@ -485,7 +484,13 @@ def check_dataset(data, autodownload=True):
|
|||
path = (ROOT / path).resolve()
|
||||
for k in 'train', 'val', 'test':
|
||||
if data.get(k): # prepend path
|
||||
data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
|
||||
if isinstance(data[k], str):
|
||||
x = (path / data[k]).resolve()
|
||||
if not x.exists() and data[k].startswith('../'):
|
||||
x = (path / data[k][3:]).resolve()
|
||||
data[k] = str(x)
|
||||
else:
|
||||
data[k] = [str((path / x).resolve()) for x in data[k]]
|
||||
|
||||
# Parse yaml
|
||||
train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
|
||||
|
@ -496,13 +501,12 @@ def check_dataset(data, autodownload=True):
|
|||
if not s or not autodownload:
|
||||
raise Exception('Dataset not found ❌')
|
||||
t = time.time()
|
||||
root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
|
||||
if s.startswith('http') and s.endswith('.zip'): # URL
|
||||
f = Path(s).name # filename
|
||||
LOGGER.info(f'Downloading {s} to {f}...')
|
||||
torch.hub.download_url_to_file(s, f)
|
||||
Path(root).mkdir(parents=True, exist_ok=True) # create root
|
||||
ZipFile(f).extractall(path=root) # unzip
|
||||
Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
|
||||
ZipFile(f).extractall(path=DATASETS_DIR) # unzip
|
||||
Path(f).unlink() # remove zip
|
||||
r = None # success
|
||||
elif s.startswith('bash '): # bash script
|
||||
|
@ -511,7 +515,7 @@ def check_dataset(data, autodownload=True):
|
|||
else: # python script
|
||||
r = exec(s, {'yaml': data}) # return None
|
||||
dt = f'({round(time.time() - t, 1)}s)'
|
||||
s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
|
||||
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
|
||||
LOGGER.info(f"Dataset download {s}")
|
||||
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
|
||||
return data # dictionary
|
||||
|
|
Loading…
Reference in New Issue