update training instruction

Summary: update dataset configuration and training instruction
pull/150/head
liaoxingyu 2020-06-16 19:43:36 +08:00
parent 727a746831
commit 8879db3fba
7 changed files with 54 additions and 9 deletions

View File

@ -12,3 +12,25 @@ Then you should set the pretrain model path in `configs/Base-bagtricks.yml`.
```bash
cd fastreid/evaluation/rank_cylib; make all
```
## Training & Evaluation in Command Line
We provide a script in "tools/train_net.py", that is made to train all the configs provided in fastreid.
You may want to use it as a reference to write your own training script.
To train a model with "train_net.py", first setup up the corresponding datasets following [datasets/README.md](https://github.com/JDAI-CV/fast-reid/tree/master/datasets), then run:
```bash
CUDA_VISIBLE_DEVICES=$gpus tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml
```
The configs are made for 1-GPU training.
To evaluate a model's performance, use
```bash
CUDA_VISIBLE_DEVICES=$gpus tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml \
--eval-only MODEL.WEIGHTS /path/to/checkpoint_file
```
For more options, see `./train_net.py -h`.

View File

@ -2,9 +2,9 @@
Fastreid has buildin support for a few datasets. The datasets are assumed to exist in a directory specified by the environment variable `FASTREID_DATASETS`. Under this directory, fastreid expects to find datasets in the structure described below.
You can set the location for builtin datasets by `export FASTREID_DATASETS=/path/to/datasets/`. If left unset, the default is `./datasets` relative to your current working directory.
You can set the location for builtin datasets by `export FASTREID_DATASETS=/path/to/datasets/`. If left unset, the default is `datasets/` relative to your current working directory.
The model zoo contains configs and models that use these buildin datasets.
The [model zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md) contains configs and models that use these buildin datasets.
## Expected dataset structure for Market1501
@ -20,4 +20,24 @@ datasets/
## Expected dataset structure for DukeMTMC
1. Download datasets to `datasets/`
2. Extract dataset. The dataset structure would like:
```bash
datasets/
DukeMTMC-reID/
bounding_box_train/
bounding_box_test/
```
## Expected dataset structure for MSMT17
1. Download datasets to `datasets/`
2. Extract dataset. The dataset structure would like:
```bash
datasets/
MSMT17_V2/
mask_train_v2/
mask_test_v2/
```

View File

@ -4,6 +4,7 @@
@contact: sherlockliao01@gmail.com
"""
import os
import torch
from torch._six import container_abcs, string_classes, int_classes
from torch.utils.data import DataLoader
@ -13,13 +14,15 @@ from .common import CommDataset
from .datasets import DATASET_REGISTRY
from .transforms import build_transforms
_root = os.getenv("FASTREID_DATASETS", "datasets")
def build_reid_train_loader(cfg):
train_transforms = build_transforms(cfg, is_train=True)
train_items = list()
for d in cfg.DATASETS.NAMES:
dataset = DATASET_REGISTRY.get(d)(combineall=cfg.DATASETS.COMBINEALL)
dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
dataset.show_train()
train_items.extend(dataset.train)
@ -50,7 +53,7 @@ def build_reid_train_loader(cfg):
def build_reid_test_loader(cfg, dataset_name):
test_transforms = build_transforms(cfg, is_train=False)
dataset = DATASET_REGISTRY.get(dataset_name)()
dataset = DATASET_REGISTRY.get(dataset_name)(root=_root)
dataset.show_test()
test_items = dataset.query + dataset.gallery

View File

@ -19,3 +19,5 @@ from .msmt17 import MSMT17
from .veri import VeRi
from .vehicleid import VehicleID, SmallVehicleID, MediumVehicleID, LargeVehicleID
from .veriwild import VeRiWild, SmallVeRiWild, MediumVeRiWild, LargeVeRiWild
__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]

View File

@ -5,11 +5,8 @@
"""
import copy
import os
import numpy as np
import torch
import logging
import os
class Dataset(object):

View File

@ -95,7 +95,7 @@ class Checkpointer(object):
if not path:
# no checkpoint provided
self.logger.info(
"No checkpoint found. Initializing model from scratch"
"No checkpoint found. Training model from scratch"
)
return {}
self.logger.info("Loading checkpoint from {}".format(path))

View File

@ -1,3 +1,4 @@
#!/usr/bin/env python
# encoding: utf-8
"""
@author: sherlock