Dataset autodownload feature addition (#685)
* initial commit * move download scripts into data/scripts * new check_dataset() function in general.py * move check_dataset() out of with context * Update general.py * DDP update * Update general.pypull/688/head
parent
3d8ed0a76b
commit
41523e2c91
|
@ -1,5 +1,4 @@
|
|||
# COCO 2017 dataset http://cocodataset.org
|
||||
# Download command: bash yolov5/data/get_coco2017.sh
|
||||
# Train command: python train.py --data coco.yaml
|
||||
# Default dataset location is next to /yolov5:
|
||||
# /parent_folder
|
||||
|
@ -7,6 +6,9 @@
|
|||
# /yolov5
|
||||
|
||||
|
||||
# download command/URL (optional)
|
||||
download: bash data/scripts/get_coco.sh
|
||||
|
||||
# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
|
||||
train: ../coco/train2017.txt # 118287 images
|
||||
val: ../coco/val2017.txt # 5000 images
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# COCO 2017 dataset http://cocodataset.org - first 128 training images
|
||||
# Download command: python -c "from yolov5.utils.google_utils import *; gdrive_download('1n_oKgR81BJtqk75b00eAjdv03qVCQn2f', 'coco128.zip')"
|
||||
# Train command: python train.py --data coco128.yaml
|
||||
# Default dataset location is next to /yolov5:
|
||||
# /parent_folder
|
||||
|
@ -7,6 +6,9 @@
|
|||
# /yolov5
|
||||
|
||||
|
||||
# download command/URL (optional)
|
||||
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
|
||||
|
||||
# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
|
||||
train: ../coco128/images/train2017/ # 128 images
|
||||
val: ../coco128/images/train2017/ # 128 images
|
||||
|
|
|
@ -1,30 +0,0 @@
|
|||
#!/bin/bash
|
||||
# COCO 2017 dataset http://cocodataset.org
|
||||
# Download command: bash yolov5/data/get_coco2017.sh
|
||||
# Train command: python train.py --data coco.yaml
|
||||
# Default dataset location is next to /yolov5:
|
||||
# /parent_folder
|
||||
# /coco
|
||||
# /yolov5
|
||||
|
||||
|
||||
# Download labels from Google Drive, accepting presented query
|
||||
filename="coco2017labels.zip"
|
||||
fileid="1cXZR_ckHki6nddOmcysCuuJFM--T-Q6L"
|
||||
curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null
|
||||
curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename}
|
||||
rm ./cookie
|
||||
|
||||
# Unzip labels
|
||||
unzip -q ${filename} # for coco.zip
|
||||
# tar -xzf ${filename} # for coco.tar.gz
|
||||
rm ${filename}
|
||||
|
||||
# Download and unzip images
|
||||
cd coco/images
|
||||
f="train2017.zip" && curl http://images.cocodataset.org/zips/$f -o $f && unzip -q $f && rm $f # 19G, 118k images
|
||||
f="val2017.zip" && curl http://images.cocodataset.org/zips/$f -o $f && unzip -q $f && rm $f # 1G, 5k images
|
||||
# f="test2017.zip" && curl http://images.cocodataset.org/zips/$f -o $f && unzip -q $f && rm $f # 7G, 41k images
|
||||
|
||||
# cd out
|
||||
cd ../..
|
|
@ -0,0 +1,21 @@
|
|||
#!/bin/bash
|
||||
# COCO 2017 dataset http://cocodataset.org
|
||||
# Download command: bash data/scripts/get_coco.sh
|
||||
# Train command: python train.py --data coco.yaml
|
||||
# Default dataset location is next to /yolov5:
|
||||
# /parent_folder
|
||||
# /coco
|
||||
# /yolov5
|
||||
|
||||
# Download/unzip labels
|
||||
echo 'Downloading COCO 2017 labels ...'
|
||||
d='../' # unzip directory
|
||||
f='coco2017labels.zip' && curl -L https://github.com/ultralytics/yolov5/releases/download/v1.0/$f -o $f
|
||||
unzip -q $f -d $d && rm $f
|
||||
|
||||
# Download/unzip images
|
||||
echo 'Downloading COCO 2017 images ...'
|
||||
d='../coco/images' # unzip directory
|
||||
f='train2017.zip' && curl http://images.cocodataset.org/zips/$f -o $f && unzip -q $f -d $d && rm $f # 19G, 118k images
|
||||
f='val2017.zip' && curl http://images.cocodataset.org/zips/$f -o $f && unzip -q $f -d $d && rm $f # 1G, 5k images
|
||||
# f='test2017.zip' && curl http://images.cocodataset.org/zips/$f -o $f && unzip -q $f -d $d && rm $f # 7G, 41k images
|
|
@ -1,33 +1,32 @@
|
|||
#!/bin/bash
|
||||
# PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC/
|
||||
# Download command: bash ./data/get_voc.sh
|
||||
# Download command: bash data/scripts/get_voc.sh
|
||||
# Train command: python train.py --data voc.yaml
|
||||
# Default dataset location is next to /yolov5:
|
||||
# /parent_folder
|
||||
# /VOC
|
||||
# /yolov5
|
||||
|
||||
|
||||
start=`date +%s`
|
||||
start=$(date +%s)
|
||||
|
||||
# handle optional download dir
|
||||
if [ -z "$1" ]
|
||||
then
|
||||
# navigate to ~/tmp
|
||||
echo "navigating to ../tmp/ ..."
|
||||
mkdir -p ../tmp
|
||||
cd ../tmp/
|
||||
else
|
||||
# check if is valid directory
|
||||
if [ ! -d $1 ]; then
|
||||
echo $1 "is not a valid directory"
|
||||
exit 0
|
||||
fi
|
||||
echo "navigating to" $1 "..."
|
||||
cd $1
|
||||
if [ -z "$1" ]; then
|
||||
# navigate to ~/tmp
|
||||
echo "navigating to ../tmp/ ..."
|
||||
mkdir -p ../tmp
|
||||
cd ../tmp/
|
||||
else
|
||||
# check if is valid directory
|
||||
if [ ! -d $1 ]; then
|
||||
echo $1 "is not a valid directory"
|
||||
exit 0
|
||||
fi
|
||||
echo "navigating to" $1 "..."
|
||||
cd $1
|
||||
fi
|
||||
|
||||
echo "Downloading VOC2007 trainval ..."
|
||||
# Download the data.
|
||||
# Download data
|
||||
curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
|
||||
echo "Downloading VOC2007 test data ..."
|
||||
curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
|
||||
|
@ -42,44 +41,42 @@ echo "removing tars ..."
|
|||
rm VOCtrainval_06-Nov-2007.tar
|
||||
rm VOCtest_06-Nov-2007.tar
|
||||
|
||||
end=`date +%s`
|
||||
runtime=$((end-start))
|
||||
end=$(date +%s)
|
||||
runtime=$((end - start))
|
||||
|
||||
echo "Completed in" $runtime "seconds"
|
||||
|
||||
start=`date +%s`
|
||||
start=$(date +%s)
|
||||
|
||||
# handle optional download dir
|
||||
if [ -z "$1" ]
|
||||
then
|
||||
# navigate to ~/tmp
|
||||
echo "navigating to ../tmp/ ..."
|
||||
mkdir -p ../tmp
|
||||
cd ../tmp/
|
||||
else
|
||||
# check if is valid directory
|
||||
if [ ! -d $1 ]; then
|
||||
echo $1 "is not a valid directory"
|
||||
exit 0
|
||||
fi
|
||||
echo "navigating to" $1 "..."
|
||||
cd $1
|
||||
if [ -z "$1" ]; then
|
||||
# navigate to ~/tmp
|
||||
echo "navigating to ../tmp/ ..."
|
||||
mkdir -p ../tmp
|
||||
cd ../tmp/
|
||||
else
|
||||
# check if is valid directory
|
||||
if [ ! -d $1 ]; then
|
||||
echo $1 "is not a valid directory"
|
||||
exit 0
|
||||
fi
|
||||
echo "navigating to" $1 "..."
|
||||
cd $1
|
||||
fi
|
||||
|
||||
echo "Downloading VOC2012 trainval ..."
|
||||
# Download the data.
|
||||
# Download data
|
||||
curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
|
||||
echo "Done downloading."
|
||||
|
||||
|
||||
# Extract data
|
||||
echo "Extracting trainval ..."
|
||||
tar -xf VOCtrainval_11-May-2012.tar
|
||||
echo "removing tar ..."
|
||||
rm VOCtrainval_11-May-2012.tar
|
||||
|
||||
end=`date +%s`
|
||||
runtime=$((end-start))
|
||||
end=$(date +%s)
|
||||
runtime=$((end - start))
|
||||
|
||||
echo "Completed in" $runtime "seconds"
|
||||
|
||||
|
@ -144,8 +141,8 @@ for year, image_set in sets:
|
|||
|
||||
END
|
||||
|
||||
cat 2007_train.txt 2007_val.txt 2012_train.txt 2012_val.txt > train.txt
|
||||
cat 2007_train.txt 2007_val.txt 2007_test.txt 2012_train.txt 2012_val.txt > train.all.txt
|
||||
cat 2007_train.txt 2007_val.txt 2012_train.txt 2012_val.txt >train.txt
|
||||
cat 2007_train.txt 2007_val.txt 2007_test.txt 2012_train.txt 2012_val.txt >train.all.txt
|
||||
|
||||
python3 - "$@" <<END
|
||||
|
||||
|
@ -211,5 +208,5 @@ for line in lines:
|
|||
|
||||
END
|
||||
|
||||
rm -rf ../tmp # remove temporary directory
|
||||
rm -rf ../tmp # remove temporary directory
|
||||
echo "VOC download done."
|
|
@ -1,5 +1,4 @@
|
|||
# PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC/
|
||||
# Download command: bash ./data/get_voc.sh
|
||||
# Train command: python train.py --data voc.yaml
|
||||
# Default dataset location is next to /yolov5:
|
||||
# /parent_folder
|
||||
|
@ -7,6 +6,9 @@
|
|||
# /yolov5
|
||||
|
||||
|
||||
# download command/URL (optional)
|
||||
download: bash data/scripts/get_voc.sh
|
||||
|
||||
# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
|
||||
train: ../VOC/images/train/ # 16551 images
|
||||
val: ../VOC/images/val/ # 4952 images
|
||||
|
|
3
test.py
3
test.py
|
@ -13,7 +13,7 @@ from tqdm import tqdm
|
|||
from models.experimental import attempt_load
|
||||
from utils.datasets import create_dataloader
|
||||
from utils.general import (
|
||||
coco80_to_coco91_class, check_file, check_img_size, compute_loss, non_max_suppression,
|
||||
coco80_to_coco91_class, check_dataset, check_file, check_img_size, compute_loss, non_max_suppression,
|
||||
scale_coords, xyxy2xywh, clip_coords, plot_images, xywh2xyxy, box_iou, output_to_target, ap_per_class)
|
||||
from utils.torch_utils import select_device, time_synchronized
|
||||
|
||||
|
@ -68,6 +68,7 @@ def test(data,
|
|||
model.eval()
|
||||
with open(data) as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
||||
check_dataset(data) # check
|
||||
nc = 1 if single_cls else int(data['nc']) # number of classes
|
||||
iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
|
||||
niou = iouv.numel()
|
||||
|
|
8
train.py
8
train.py
|
@ -21,9 +21,9 @@ import test # import test.py to get mAP after each epoch
|
|||
from models.yolo import Model
|
||||
from utils.datasets import create_dataloader
|
||||
from utils.general import (
|
||||
check_img_size, torch_distributed_zero_first, labels_to_class_weights, plot_labels, check_anchors,
|
||||
labels_to_image_weights, compute_loss, plot_images, fitness, strip_optimizer, plot_results,
|
||||
get_latest_run, check_git_status, check_file, increment_dir, print_mutation, plot_evolution)
|
||||
torch_distributed_zero_first, labels_to_class_weights, plot_labels, check_anchors, labels_to_image_weights,
|
||||
compute_loss, plot_images, fitness, strip_optimizer, plot_results, get_latest_run, check_dataset, check_file,
|
||||
check_git_status, check_img_size, increment_dir, print_mutation, plot_evolution)
|
||||
from utils.google_utils import attempt_download
|
||||
from utils.torch_utils import init_seeds, ModelEMA, select_device, intersect_dicts
|
||||
|
||||
|
@ -51,6 +51,8 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
init_seeds(2 + rank)
|
||||
with open(opt.data) as f:
|
||||
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
||||
with torch_distributed_zero_first(rank):
|
||||
check_dataset(data_dict) # check
|
||||
train_path = data_dict['train']
|
||||
test_path = data_dict['val']
|
||||
nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names']) # number classes, names
|
||||
|
|
|
@ -128,6 +128,25 @@ def check_file(file):
|
|||
return files[0] # return first file if multiple found
|
||||
|
||||
|
||||
def check_dataset(dict):
|
||||
# Download dataset if not found
|
||||
train, val = os.path.abspath(dict['train']), os.path.abspath(dict['val']) # data paths
|
||||
if not (os.path.exists(train) and os.path.exists(val)):
|
||||
print('\nWARNING: Dataset not found, nonexistant paths: %s' % [train, val])
|
||||
if 'download' in dict:
|
||||
s = dict['download']
|
||||
print('Attempting autodownload from: %s' % s)
|
||||
if s.startswith('http') and s.endswith('.zip'): # URL
|
||||
f = Path(s).name # filename
|
||||
torch.hub.download_url_to_file(s, f)
|
||||
r = os.system('unzip -q %s -d ../ && rm %s' % (f, f))
|
||||
else: # bash script
|
||||
r = os.system(s)
|
||||
print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure')) # analyze return value
|
||||
else:
|
||||
Exception('Dataset autodownload unavailable.')
|
||||
|
||||
|
||||
def make_divisible(x, divisor):
|
||||
# Returns x evenly divisble by divisor
|
||||
return math.ceil(x / divisor) * divisor
|
||||
|
|
Loading…
Reference in New Issue