semi-supervised benchmark
parent
bc3c11e441
commit
038d727ad4
|
@ -24,7 +24,7 @@ Below is the relations among Unsupervised Learning, Self-Supervised Learning and
|
|||
| [Rotation-Pred](https://arxiv.org/abs/1803.07728) | ✓ |
|
||||
| [DeepCluster](https://arxiv.org/abs/1807.05520) | ✓ |
|
||||
| [ODC](http://openaccess.thecvf.com/content_CVPR_2020/papers/Zhan_Online_Deep_Clustering_for_Unsupervised_Representation_Learning_CVPR_2020_paper.pdf) | [34m~\~S |
|
||||
| [NIPD](https://arxiv.org/abs/1805.01978) | ✓ |
|
||||
| [NPID](https://arxiv.org/abs/1805.01978) | ✓ |
|
||||
| [MoCo](https://arxiv.org/abs/1911.05722) | ✓ |
|
||||
| [MoCo v2](https://arxiv.org/abs/2003.04297) | ✓ |
|
||||
| [SimCLR](https://arxiv.org/abs/2002.05709) | ✓ |
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
CFG=$1
|
||||
EPOCH=$2
|
||||
PERCENT=$3
|
||||
PY_ARGS=${@:4}
|
||||
GPUS=${GPUS:-8}
|
||||
|
||||
WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
|
||||
CHECKPOINT=$WORK_DIR/epoch_${EPOCH}.pth
|
||||
WORK_DIR_EVAL=$WORK_DIR/imagenet_semi_${PERCENT}percent_at_epoch_${EPOCH}/
|
||||
|
||||
if [ ! "$PERCENT" == "1" ] && [ ! "$PERCENT" == 10 ]; then
|
||||
echo "ERROR: PERCENT must in {1, 10}"
|
||||
exit
|
||||
fi
|
||||
|
||||
# extract backbone
|
||||
if [ ! -f "${CHECKPOINT::(-4)}_extracted.pth" ]; then
|
||||
python tools/extract_backbone_weights.py $CHECKPOINT \
|
||||
--save-path ${CHECKPOINT::(-4)}_extracted.pth
|
||||
fi
|
||||
|
||||
# train
|
||||
python -m torch.distributed.launch --nproc_per_node=$GPUS \
|
||||
tools/train.py \
|
||||
configs/semisup_classification/imagenet_${PERCENT}percent/r50.py \
|
||||
--pretrained ${CHECKPOINT::(-4)}_extracted.pth \
|
||||
--work_dir ${WORK_DIR_EVAL} --seed 0 --launcher="pytorch" ${PY_ARGS}
|
||||
|
||||
# test
|
||||
python -m torch.distributed.launch --nproc_per_node=$GPUS \
|
||||
tools/test.py \
|
||||
configs/semisup_classification/imagenet_${PERCENT}percent/r50.py \
|
||||
${WORK_DIR_EVAL}/latest.pth \
|
||||
--work_dir ${WORK_DIR_EVAL} --launcher="pytorch"
|
|
@ -9,8 +9,8 @@ EPOCH=$3
|
|||
PERCENT=$4
|
||||
PY_ARGS=${@:5}
|
||||
JOB_NAME="openselfsup"
|
||||
GPUS=${GPUS:-1}
|
||||
GPUS_PER_NODE=${GPUS_PER_NODE:-1}
|
||||
GPUS=${GPUS:-8}
|
||||
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
|
||||
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
|
||||
SRUN_ARGS=${SRUN_ARGS:-""}
|
||||
|
||||
|
|
|
@ -72,6 +72,15 @@ Augments:
|
|||
- `${DATASET}` in `['imagenet', 'places205']`.
|
||||
- Optional arguments include `--resume_from ${CHECKPOINT_FILE}` that resume from a previous checkpoint file.
|
||||
|
||||
### ImageNet Semi-Supervised Classification
|
||||
|
||||
```shell
|
||||
bash benchmarks/dist_test_semi.sh ${CONFIG_FILE} ${EPOCH} ${PERCENT} [optional arguments]
|
||||
``
|
||||
Arguments:
|
||||
- `${PERCENT}` in `[1, 10]`.
|
||||
- Other arguments are the same as in ImageNet / Places205 Linear Classification.
|
||||
|
||||
### VOC07+12 / COCO17 Object Detection
|
||||
|
||||
1. First, extract backbone weights:
|
||||
|
|
|
@ -11,7 +11,13 @@ class ImageList(object):
|
|||
def __init__(self, root, list_file, memcached, mclient_path):
|
||||
with open(list_file, 'r') as f:
|
||||
lines = f.readlines()
|
||||
self.fns = [os.path.join(root, l.strip()) for l in lines]
|
||||
self.has_labels = len(lines[0].split()) == 2
|
||||
if self.has_labels:
|
||||
self.fns, self.labels = zip(*[l.strip().split() for l in lines])
|
||||
self.labels = [int(l) for l in self.labels]
|
||||
else:
|
||||
self.fns = [l.strip() for l in lines]
|
||||
self.fns = [os.path.join(root, fn) for fn in self.fns]
|
||||
self.memcached = memcached
|
||||
self.mclient_path = mclient_path
|
||||
self.initialized = False
|
||||
|
@ -33,4 +39,5 @@ class ImageList(object):
|
|||
else:
|
||||
img = Image.open(self.fns[idx])
|
||||
img = img.convert('RGB')
|
||||
return img
|
||||
target = self.labels[idx] if self.has_labels else None
|
||||
return img, target
|
||||
|
|
|
@ -1,43 +1,10 @@
|
|||
import os
|
||||
from PIL import Image
|
||||
|
||||
from ..registry import DATASOURCES
|
||||
from .utils import McLoader
|
||||
from .image_list import ImageList
|
||||
|
||||
|
||||
@DATASOURCES.register_module
|
||||
class ImageNet(object):
|
||||
class ImageNet(ImageList):
|
||||
|
||||
def __init__(self, root, list_file, memcached, mclient_path):
|
||||
with open(list_file, 'r') as f:
|
||||
lines = f.readlines()
|
||||
self.has_labels = len(lines[0].split()) == 2
|
||||
if self.has_labels:
|
||||
self.fns, self.labels = zip(*[l.strip().split() for l in lines])
|
||||
self.labels = [int(l) for l in self.labels]
|
||||
else:
|
||||
self.fns = [l.strip() for l in lines]
|
||||
self.fns = [os.path.join(root, fn) for fn in self.fns]
|
||||
self.memcached = memcached
|
||||
self.mclient_path = mclient_path
|
||||
self.initialized = False
|
||||
|
||||
def _init_memcached(self):
|
||||
if not self.initialized:
|
||||
assert self.mclient_path is not None
|
||||
self.mc_loader = McLoader(self.mclient_path)
|
||||
self.initialized = True
|
||||
|
||||
def get_length(self):
|
||||
return len(self.fns)
|
||||
|
||||
def get_sample(self, idx):
|
||||
if self.memcached:
|
||||
self._init_memcached()
|
||||
if self.memcached:
|
||||
img = self.mc_loader(self.fns[idx])
|
||||
else:
|
||||
img = Image.open(self.fns[idx])
|
||||
img = img.convert('RGB')
|
||||
target = self.labels[idx] if self.has_labels else None
|
||||
return img, target
|
||||
super(ImageNet, self).__init__(
|
||||
root, list_file, memcached, mclient_path)
|
||||
|
|
Loading…
Reference in New Issue