update training instruction

Summary: update dataset configuration and training instruction
pull/150/head
liaoxingyu 2020-06-16 19:43:36 +08:00
parent 727a746831
commit 8879db3fba
7 changed files with 54 additions and 9 deletions

View File

@ -12,3 +12,25 @@ Then you should set the pretrain model path in `configs/Base-bagtricks.yml`.
```bash ```bash
cd fastreid/evaluation/rank_cylib; make all cd fastreid/evaluation/rank_cylib; make all
``` ```
## Training & Evaluation in Command Line
We provide a script in "tools/train_net.py", that is made to train all the configs provided in fastreid.
You may want to use it as a reference to write your own training script.
To train a model with "train_net.py", first setup up the corresponding datasets following [datasets/README.md](https://github.com/JDAI-CV/fast-reid/tree/master/datasets), then run:
```bash
CUDA_VISIBLE_DEVICES=$gpus tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml
```
The configs are made for 1-GPU training.
To evaluate a model's performance, use
```bash
CUDA_VISIBLE_DEVICES=$gpus tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml \
--eval-only MODEL.WEIGHTS /path/to/checkpoint_file
```
For more options, see `./train_net.py -h`.

View File

@ -2,9 +2,9 @@
Fastreid has buildin support for a few datasets. The datasets are assumed to exist in a directory specified by the environment variable `FASTREID_DATASETS`. Under this directory, fastreid expects to find datasets in the structure described below. Fastreid has buildin support for a few datasets. The datasets are assumed to exist in a directory specified by the environment variable `FASTREID_DATASETS`. Under this directory, fastreid expects to find datasets in the structure described below.
You can set the location for builtin datasets by `export FASTREID_DATASETS=/path/to/datasets/`. If left unset, the default is `./datasets` relative to your current working directory. You can set the location for builtin datasets by `export FASTREID_DATASETS=/path/to/datasets/`. If left unset, the default is `datasets/` relative to your current working directory.
The model zoo contains configs and models that use these buildin datasets. The [model zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md) contains configs and models that use these buildin datasets.
## Expected dataset structure for Market1501 ## Expected dataset structure for Market1501
@ -20,4 +20,24 @@ datasets/
## Expected dataset structure for DukeMTMC ## Expected dataset structure for DukeMTMC
1. Download datasets to `datasets/`
2. Extract dataset. The dataset structure would like:
```bash
datasets/
DukeMTMC-reID/
bounding_box_train/
bounding_box_test/
```
## Expected dataset structure for MSMT17 ## Expected dataset structure for MSMT17
1. Download datasets to `datasets/`
2. Extract dataset. The dataset structure would like:
```bash
datasets/
MSMT17_V2/
mask_train_v2/
mask_test_v2/
```

View File

@ -4,6 +4,7 @@
@contact: sherlockliao01@gmail.com @contact: sherlockliao01@gmail.com
""" """
import os
import torch import torch
from torch._six import container_abcs, string_classes, int_classes from torch._six import container_abcs, string_classes, int_classes
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -13,13 +14,15 @@ from .common import CommDataset
from .datasets import DATASET_REGISTRY from .datasets import DATASET_REGISTRY
from .transforms import build_transforms from .transforms import build_transforms
_root = os.getenv("FASTREID_DATASETS", "datasets")
def build_reid_train_loader(cfg): def build_reid_train_loader(cfg):
train_transforms = build_transforms(cfg, is_train=True) train_transforms = build_transforms(cfg, is_train=True)
train_items = list() train_items = list()
for d in cfg.DATASETS.NAMES: for d in cfg.DATASETS.NAMES:
dataset = DATASET_REGISTRY.get(d)(combineall=cfg.DATASETS.COMBINEALL) dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
dataset.show_train() dataset.show_train()
train_items.extend(dataset.train) train_items.extend(dataset.train)
@ -50,7 +53,7 @@ def build_reid_train_loader(cfg):
def build_reid_test_loader(cfg, dataset_name): def build_reid_test_loader(cfg, dataset_name):
test_transforms = build_transforms(cfg, is_train=False) test_transforms = build_transforms(cfg, is_train=False)
dataset = DATASET_REGISTRY.get(dataset_name)() dataset = DATASET_REGISTRY.get(dataset_name)(root=_root)
dataset.show_test() dataset.show_test()
test_items = dataset.query + dataset.gallery test_items = dataset.query + dataset.gallery

View File

@ -19,3 +19,5 @@ from .msmt17 import MSMT17
from .veri import VeRi from .veri import VeRi
from .vehicleid import VehicleID, SmallVehicleID, MediumVehicleID, LargeVehicleID from .vehicleid import VehicleID, SmallVehicleID, MediumVehicleID, LargeVehicleID
from .veriwild import VeRiWild, SmallVeRiWild, MediumVeRiWild, LargeVeRiWild from .veriwild import VeRiWild, SmallVeRiWild, MediumVeRiWild, LargeVeRiWild
__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]

View File

@ -5,11 +5,8 @@
""" """
import copy import copy
import os
import numpy as np
import torch
import logging import logging
import os
class Dataset(object): class Dataset(object):

View File

@ -95,7 +95,7 @@ class Checkpointer(object):
if not path: if not path:
# no checkpoint provided # no checkpoint provided
self.logger.info( self.logger.info(
"No checkpoint found. Initializing model from scratch" "No checkpoint found. Training model from scratch"
) )
return {} return {}
self.logger.info("Loading checkpoint from {}".format(path)) self.logger.info("Loading checkpoint from {}".format(path))

View File

@ -1,3 +1,4 @@
#!/usr/bin/env python
# encoding: utf-8 # encoding: utf-8
""" """
@author: sherlock @author: sherlock