semi-supervised benchmark

pull/4/head
xiaohangzhan 2020-06-16 16:30:46 +08:00
parent bc3c11e441
commit 038d727ad4
6 changed files with 64 additions and 42 deletions

View File

@ -24,7 +24,7 @@ Below is the relations among Unsupervised Learning, Self-Supervised Learning and
| [Rotation-Pred](https://arxiv.org/abs/1803.07728) | ✓ |
| [DeepCluster](https://arxiv.org/abs/1807.05520) | ✓ |
| [ODC](http://openaccess.thecvf.com/content_CVPR_2020/papers/Zhan_Online_Deep_Clustering_for_Unsupervised_Representation_Learning_CVPR_2020_paper.pdf) | ✓ |
| [NIPD](https://arxiv.org/abs/1805.01978) | ✓ |
| [NPID](https://arxiv.org/abs/1805.01978) | ✓ |
| [MoCo](https://arxiv.org/abs/1911.05722) | ✓ |
| [MoCo v2](https://arxiv.org/abs/2003.04297) | ✓ |
| [SimCLR](https://arxiv.org/abs/2002.05709) | ✓ |

View File

@ -0,0 +1,39 @@
#!/usr/bin/env bash
# ImageNet semi-supervised benchmark driver.
#
# Usage: <this script> CFG EPOCH PERCENT [extra arguments for tools/train.py]
#   CFG      config path; the work dir is derived from it by dropping the
#            extension and replacing "configs" with "work_dirs".
#   EPOCH    selects the checkpoint ${WORK_DIR}/epoch_${EPOCH}.pth.
#   PERCENT  labelled-data split to use; must be 1 or 10.
# Environment:
#   GPUS     number of processes per node for torch.distributed.launch
#            (default 8).
#
# Steps: extract backbone weights from the checkpoint (skipped when the
# extracted file already exists), train with the semi-supervised config,
# then test the resulting latest.pth.

set -e
set -x

CFG=$1
EPOCH=$2
PERCENT=$3
# Keep the extra arguments as an array so that arguments containing spaces
# survive being forwarded to tools/train.py.
PY_ARGS=("${@:4}")
GPUS=${GPUS:-8}

WORK_DIR=$(echo "${CFG%.*}" | sed -e "s/configs/work_dirs/g")/
CHECKPOINT=$WORK_DIR/epoch_${EPOCH}.pth
# Strip the ".pth" suffix once instead of repeating a negative-length
# substring expansion (which needs bash >= 4.2) at every use.
EXTRACTED=${CHECKPOINT%.pth}_extracted.pth
WORK_DIR_EVAL=$WORK_DIR/imagenet_semi_${PERCENT}percent_at_epoch_${EPOCH}/
SEMI_CFG=configs/semisup_classification/imagenet_${PERCENT}percent/r50.py

if [ "$PERCENT" != "1" ] && [ "$PERCENT" != "10" ]; then
    echo "ERROR: PERCENT must in {1, 10}" >&2
    # A bare `exit` would return echo's status (0) and hide the failure.
    exit 1
fi

# extract backbone
if [ ! -f "$EXTRACTED" ]; then
    python tools/extract_backbone_weights.py "$CHECKPOINT" \
        --save-path "$EXTRACTED"
fi

# train
python -m torch.distributed.launch --nproc_per_node="$GPUS" \
    tools/train.py \
    "$SEMI_CFG" \
    --pretrained "$EXTRACTED" \
    --work_dir "${WORK_DIR_EVAL}" --seed 0 --launcher="pytorch" "${PY_ARGS[@]}"

# test
python -m torch.distributed.launch --nproc_per_node="$GPUS" \
    tools/test.py \
    "$SEMI_CFG" \
    "${WORK_DIR_EVAL}/latest.pth" \
    --work_dir "${WORK_DIR_EVAL}" --launcher="pytorch"

View File

@ -9,8 +9,8 @@ EPOCH=$3
PERCENT=$4
PY_ARGS=${@:5}
JOB_NAME="openselfsup"
GPUS=${GPUS:-1}
GPUS_PER_NODE=${GPUS_PER_NODE:-1}
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}

View File

@ -72,6 +72,15 @@ Augments:
- `${DATASET}` in `['imagenet', 'places205']`.
- Optional arguments include `--resume_from ${CHECKPOINT_FILE}`, which resumes from a previous checkpoint file.
### ImageNet Semi-Supervised Classification
```shell
bash benchmarks/dist_test_semi.sh ${CONFIG_FILE} ${EPOCH} ${PERCENT} [optional arguments]
```
Arguments:
- `${PERCENT}` in `[1, 10]`.
- Other arguments are the same as in ImageNet / Places205 Linear Classification.
### VOC07+12 / COCO17 Object Detection
1. First, extract backbone weights:

View File

@ -11,7 +11,13 @@ class ImageList(object):
def __init__(self, root, list_file, memcached, mclient_path):
with open(list_file, 'r') as f:
lines = f.readlines()
self.fns = [os.path.join(root, l.strip()) for l in lines]
self.has_labels = len(lines[0].split()) == 2
if self.has_labels:
self.fns, self.labels = zip(*[l.strip().split() for l in lines])
self.labels = [int(l) for l in self.labels]
else:
self.fns = [l.strip() for l in lines]
self.fns = [os.path.join(root, fn) for fn in self.fns]
self.memcached = memcached
self.mclient_path = mclient_path
self.initialized = False
@ -33,4 +39,5 @@ class ImageList(object):
else:
img = Image.open(self.fns[idx])
img = img.convert('RGB')
return img
target = self.labels[idx] if self.has_labels else None
return img, target

View File

@ -1,43 +1,10 @@
import os
from PIL import Image
from ..registry import DATASOURCES
from .utils import McLoader
from .image_list import ImageList
@DATASOURCES.register_module
class ImageNet(object):
class ImageNet(ImageList):
def __init__(self, root, list_file, memcached, mclient_path):
with open(list_file, 'r') as f:
lines = f.readlines()
self.has_labels = len(lines[0].split()) == 2
if self.has_labels:
self.fns, self.labels = zip(*[l.strip().split() for l in lines])
self.labels = [int(l) for l in self.labels]
else:
self.fns = [l.strip() for l in lines]
self.fns = [os.path.join(root, fn) for fn in self.fns]
self.memcached = memcached
self.mclient_path = mclient_path
self.initialized = False
def _init_memcached(self):
if not self.initialized:
assert self.mclient_path is not None
self.mc_loader = McLoader(self.mclient_path)
self.initialized = True
def get_length(self):
return len(self.fns)
def get_sample(self, idx):
if self.memcached:
self._init_memcached()
if self.memcached:
img = self.mc_loader(self.fns[idx])
else:
img = Image.open(self.fns[idx])
img = img.convert('RGB')
target = self.labels[idx] if self.has_labels else None
return img, target
super(ImageNet, self).__init__(
root, list_file, memcached, mclient_path)