added test epoch to training model
parent
ba804d8815
commit
85626d924b
|
@ -85,6 +85,7 @@ val_dataloader = dict(
|
|||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=5,
|
||||
|
@ -100,7 +101,9 @@ val_evaluator = [
|
|||
dict(type='AveragePrecision', prefix='val'),
|
||||
dict(type='SingleLabelMetric', prefix='val'),
|
||||
]
|
||||
test_evaluator = [
|
||||
dict(type='Accuracy', topk=(1,), prefix='test'),
|
||||
dict(type='AveragePrecision', prefix='test'),
|
||||
dict(type='SingleLabelMetric', prefix='test'),
|
||||
]
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = test_dataloader
|
||||
test_evaluator = val_evaluator
|
||||
|
|
|
@ -21,8 +21,8 @@ RUN pip install openmim
|
|||
|
||||
# Install MMPretrain
|
||||
RUN conda clean --all
|
||||
RUN git clone https://github.com/logivations/mmpretrain.git
|
||||
WORKDIR ./mmpretrain
|
||||
RUN git clone https://github.com/logivations/mmpretrain.git /mmpretrain
|
||||
WORKDIR /mmpretrain
|
||||
|
||||
# Worked version mim list -> mmpretrain 1.0.0rc7 /workspace/mmpretrain
|
||||
RUN mim install --no-cache-dir -e .
|
||||
|
|
|
@ -3,7 +3,7 @@ import argparse
|
|||
import os
|
||||
import os.path as osp
|
||||
from copy import deepcopy
|
||||
|
||||
import mmengine
|
||||
from mmengine.config import Config, ConfigDict, DictAction
|
||||
from mmengine.runner import Runner
|
||||
from mmengine.utils import digit_version
|
||||
|
@ -155,6 +155,8 @@ def main():
|
|||
|
||||
# start training
|
||||
runner.train()
|
||||
metrics = runner.test()
|
||||
mmengine.dump(metrics, os.path.join(args.work_dir, "metrics.json"))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue