diff --git a/configs/_base_/datasets/imagenet_bs128_poolformer_small_224_custom.py b/configs/_base_/datasets/imagenet_bs128_poolformer_small_224_custom.py index a1da9c27..16d410df 100644 --- a/configs/_base_/datasets/imagenet_bs128_poolformer_small_224_custom.py +++ b/configs/_base_/datasets/imagenet_bs128_poolformer_small_224_custom.py @@ -85,6 +85,7 @@ val_dataloader = dict( sampler=dict(type='DefaultSampler', shuffle=False), ) +# If you want standard test, please manually configure the test dataset test_dataloader = dict( batch_size=32, num_workers=5, @@ -100,7 +101,9 @@ val_evaluator = [ dict(type='AveragePrecision', prefix='val'), dict(type='SingleLabelMetric', prefix='val'), ] +test_evaluator = [ + dict(type='Accuracy', topk=(1,), prefix='test'), + dict(type='AveragePrecision', prefix='test'), + dict(type='SingleLabelMetric', prefix='test'), +] -# If you want standard test, please manually configure the test dataset -test_dataloader = test_dataloader -test_evaluator = val_evaluator diff --git a/docker/Dockerfile b/docker/Dockerfile index 450cb51b..90f653ea 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -21,8 +21,8 @@ RUN pip install openmim # Install MMPretrain RUN conda clean --all -RUN git clone https://github.com/logivations/mmpretrain.git -WORKDIR ./mmpretrain +RUN git clone https://github.com/logivations/mmpretrain.git /mmpretrain +WORKDIR /mmpretrain # Worked version mim list -> mmpretrain 1.0.0rc7 /workspace/mmpretrain RUN mim install --no-cache-dir -e . diff --git a/tools/train.py b/tools/train.py index c42d420e..9768ba5a 100644 --- a/tools/train.py +++ b/tools/train.py @@ -3,7 +3,7 @@ import argparse import os import os.path as osp from copy import deepcopy - +import mmengine from mmengine.config import Config, ConfigDict, DictAction from mmengine.runner import Runner from mmengine.utils import digit_version @@ -155,6 +155,8 @@ def main(): # start training runner.train() + metrics = runner.test() + mmengine.dump(metrics, os.path.join(args.work_dir, "metrics.json")) if __name__ == '__main__':