semi-supervised benchmark
parent
bc3c11e441
commit
038d727ad4
|
@ -24,7 +24,7 @@ Below is the relations among Unsupervised Learning, Self-Supervised Learning and
|
||||||
| [Rotation-Pred](https://arxiv.org/abs/1803.07728) | ✓ |
|
| [Rotation-Pred](https://arxiv.org/abs/1803.07728) | ✓ |
|
||||||
| [DeepCluster](https://arxiv.org/abs/1807.05520) | ✓ |
|
| [DeepCluster](https://arxiv.org/abs/1807.05520) | ✓ |
|
||||||
| [ODC](http://openaccess.thecvf.com/content_CVPR_2020/papers/Zhan_Online_Deep_Clustering_for_Unsupervised_Representation_Learning_CVPR_2020_paper.pdf) | [34m~\~S |
|
| [ODC](http://openaccess.thecvf.com/content_CVPR_2020/papers/Zhan_Online_Deep_Clustering_for_Unsupervised_Representation_Learning_CVPR_2020_paper.pdf) | [34m~\~S |
|
||||||
| [NIPD](https://arxiv.org/abs/1805.01978) | ✓ |
|
| [NPID](https://arxiv.org/abs/1805.01978) | ✓ |
|
||||||
| [MoCo](https://arxiv.org/abs/1911.05722) | ✓ |
|
| [MoCo](https://arxiv.org/abs/1911.05722) | ✓ |
|
||||||
| [MoCo v2](https://arxiv.org/abs/2003.04297) | ✓ |
|
| [MoCo v2](https://arxiv.org/abs/2003.04297) | ✓ |
|
||||||
| [SimCLR](https://arxiv.org/abs/2002.05709) | ✓ |
|
| [SimCLR](https://arxiv.org/abs/2002.05709) | ✓ |
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
set -x
|
||||||
|
|
||||||
|
CFG=$1
|
||||||
|
EPOCH=$2
|
||||||
|
PERCENT=$3
|
||||||
|
PY_ARGS=${@:4}
|
||||||
|
GPUS=${GPUS:-8}
|
||||||
|
|
||||||
|
WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
|
||||||
|
CHECKPOINT=$WORK_DIR/epoch_${EPOCH}.pth
|
||||||
|
WORK_DIR_EVAL=$WORK_DIR/imagenet_semi_${PERCENT}percent_at_epoch_${EPOCH}/
|
||||||
|
|
||||||
|
if [ ! "$PERCENT" == "1" ] && [ ! "$PERCENT" == 10 ]; then
|
||||||
|
echo "ERROR: PERCENT must in {1, 10}"
|
||||||
|
exit
|
||||||
|
fi
|
||||||
|
|
||||||
|
# extract backbone
|
||||||
|
if [ ! -f "${CHECKPOINT::(-4)}_extracted.pth" ]; then
|
||||||
|
python tools/extract_backbone_weights.py $CHECKPOINT \
|
||||||
|
--save-path ${CHECKPOINT::(-4)}_extracted.pth
|
||||||
|
fi
|
||||||
|
|
||||||
|
# train
|
||||||
|
python -m torch.distributed.launch --nproc_per_node=$GPUS \
|
||||||
|
tools/train.py \
|
||||||
|
configs/semisup_classification/imagenet_${PERCENT}percent/r50.py \
|
||||||
|
--pretrained ${CHECKPOINT::(-4)}_extracted.pth \
|
||||||
|
--work_dir ${WORK_DIR_EVAL} --seed 0 --launcher="pytorch" ${PY_ARGS}
|
||||||
|
|
||||||
|
# test
|
||||||
|
python -m torch.distributed.launch --nproc_per_node=$GPUS \
|
||||||
|
tools/test.py \
|
||||||
|
configs/semisup_classification/imagenet_${PERCENT}percent/r50.py \
|
||||||
|
${WORK_DIR_EVAL}/latest.pth \
|
||||||
|
--work_dir ${WORK_DIR_EVAL} --launcher="pytorch"
|
|
@ -9,8 +9,8 @@ EPOCH=$3
|
||||||
PERCENT=$4
|
PERCENT=$4
|
||||||
PY_ARGS=${@:5}
|
PY_ARGS=${@:5}
|
||||||
JOB_NAME="openselfsup"
|
JOB_NAME="openselfsup"
|
||||||
GPUS=${GPUS:-1}
|
GPUS=${GPUS:-8}
|
||||||
GPUS_PER_NODE=${GPUS_PER_NODE:-1}
|
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
|
||||||
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
|
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
|
||||||
SRUN_ARGS=${SRUN_ARGS:-""}
|
SRUN_ARGS=${SRUN_ARGS:-""}
|
||||||
|
|
||||||
|
|
|
@ -72,6 +72,15 @@ Augments:
|
||||||
- `${DATASET}` in `['imagenet', 'places205']`.
|
- `${DATASET}` in `['imagenet', 'places205']`.
|
||||||
- Optional arguments include `--resume_from ${CHECKPOINT_FILE}` that resume from a previous checkpoint file.
|
- Optional arguments include `--resume_from ${CHECKPOINT_FILE}` that resume from a previous checkpoint file.
|
||||||
|
|
||||||
|
### ImageNet Semi-Supervised Classification
|
||||||
|
|
||||||
|
```shell
|
||||||
|
bash benchmarks/dist_test_semi.sh ${CONFIG_FILE} ${EPOCH} ${PERCENT} [optional arguments]
|
||||||
|
``
|
||||||
|
Arguments:
|
||||||
|
- `${PERCENT}` in `[1, 10]`.
|
||||||
|
- Other arguments are the same as in ImageNet / Places205 Linear Classification.
|
||||||
|
|
||||||
### VOC07+12 / COCO17 Object Detection
|
### VOC07+12 / COCO17 Object Detection
|
||||||
|
|
||||||
1. First, extract backbone weights:
|
1. First, extract backbone weights:
|
||||||
|
|
|
@ -11,7 +11,13 @@ class ImageList(object):
|
||||||
def __init__(self, root, list_file, memcached, mclient_path):
|
def __init__(self, root, list_file, memcached, mclient_path):
|
||||||
with open(list_file, 'r') as f:
|
with open(list_file, 'r') as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
self.fns = [os.path.join(root, l.strip()) for l in lines]
|
self.has_labels = len(lines[0].split()) == 2
|
||||||
|
if self.has_labels:
|
||||||
|
self.fns, self.labels = zip(*[l.strip().split() for l in lines])
|
||||||
|
self.labels = [int(l) for l in self.labels]
|
||||||
|
else:
|
||||||
|
self.fns = [l.strip() for l in lines]
|
||||||
|
self.fns = [os.path.join(root, fn) for fn in self.fns]
|
||||||
self.memcached = memcached
|
self.memcached = memcached
|
||||||
self.mclient_path = mclient_path
|
self.mclient_path = mclient_path
|
||||||
self.initialized = False
|
self.initialized = False
|
||||||
|
@ -33,4 +39,5 @@ class ImageList(object):
|
||||||
else:
|
else:
|
||||||
img = Image.open(self.fns[idx])
|
img = Image.open(self.fns[idx])
|
||||||
img = img.convert('RGB')
|
img = img.convert('RGB')
|
||||||
return img
|
target = self.labels[idx] if self.has_labels else None
|
||||||
|
return img, target
|
||||||
|
|
|
@ -1,43 +1,10 @@
|
||||||
import os
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from ..registry import DATASOURCES
|
from ..registry import DATASOURCES
|
||||||
from .utils import McLoader
|
from .image_list import ImageList
|
||||||
|
|
||||||
|
|
||||||
@DATASOURCES.register_module
|
@DATASOURCES.register_module
|
||||||
class ImageNet(object):
|
class ImageNet(ImageList):
|
||||||
|
|
||||||
def __init__(self, root, list_file, memcached, mclient_path):
|
def __init__(self, root, list_file, memcached, mclient_path):
|
||||||
with open(list_file, 'r') as f:
|
super(ImageNet, self).__init__(
|
||||||
lines = f.readlines()
|
root, list_file, memcached, mclient_path)
|
||||||
self.has_labels = len(lines[0].split()) == 2
|
|
||||||
if self.has_labels:
|
|
||||||
self.fns, self.labels = zip(*[l.strip().split() for l in lines])
|
|
||||||
self.labels = [int(l) for l in self.labels]
|
|
||||||
else:
|
|
||||||
self.fns = [l.strip() for l in lines]
|
|
||||||
self.fns = [os.path.join(root, fn) for fn in self.fns]
|
|
||||||
self.memcached = memcached
|
|
||||||
self.mclient_path = mclient_path
|
|
||||||
self.initialized = False
|
|
||||||
|
|
||||||
def _init_memcached(self):
|
|
||||||
if not self.initialized:
|
|
||||||
assert self.mclient_path is not None
|
|
||||||
self.mc_loader = McLoader(self.mclient_path)
|
|
||||||
self.initialized = True
|
|
||||||
|
|
||||||
def get_length(self):
|
|
||||||
return len(self.fns)
|
|
||||||
|
|
||||||
def get_sample(self, idx):
|
|
||||||
if self.memcached:
|
|
||||||
self._init_memcached()
|
|
||||||
if self.memcached:
|
|
||||||
img = self.mc_loader(self.fns[idx])
|
|
||||||
else:
|
|
||||||
img = Image.open(self.fns[idx])
|
|
||||||
img = img.convert('RGB')
|
|
||||||
target = self.labels[idx] if self.has_labels else None
|
|
||||||
return img, target
|
|
||||||
|
|
Loading…
Reference in New Issue