added test epoch to training model

pull/1879/head
rostyslavhereha 2024-03-04 19:39:26 +02:00
parent ba804d8815
commit 85626d924b
3 changed files with 11 additions and 6 deletions

View File

@ -85,6 +85,7 @@ val_dataloader = dict(
sampler=dict(type='DefaultSampler', shuffle=False),
)
# If you want standard test, please manually configure the test dataset
test_dataloader = dict(
batch_size=32,
num_workers=5,
@ -100,7 +101,9 @@ val_evaluator = [
dict(type='AveragePrecision', prefix='val'),
dict(type='SingleLabelMetric', prefix='val'),
]
test_evaluator = [
dict(type='Accuracy', topk=(1,), prefix='test'),
dict(type='AveragePrecision', prefix='test'),
dict(type='SingleLabelMetric', prefix='test'),
]
# If you want standard test, please manually configure the test dataset
test_dataloader = test_dataloader
test_evaluator = val_evaluator

View File

@ -21,8 +21,8 @@ RUN pip install openmim
# Install MMPretrain
RUN conda clean --all
RUN git clone https://github.com/logivations/mmpretrain.git
WORKDIR ./mmpretrain
RUN git clone https://github.com/logivations/mmpretrain.git /mmpretrain
WORKDIR /mmpretrain
# Worked version mim list -> mmpretrain 1.0.0rc7 /workspace/mmpretrain
RUN mim install --no-cache-dir -e .

View File

@ -3,7 +3,7 @@ import argparse
import os
import os.path as osp
from copy import deepcopy
import mmengine
from mmengine.config import Config, ConfigDict, DictAction
from mmengine.runner import Runner
from mmengine.utils import digit_version
@ -155,6 +155,8 @@ def main():
# start training
runner.train()
metrics = runner.test()
mmengine.dump(metrics, os.path.join(args.work_dir, "metrics.json"))
if __name__ == '__main__':