add github config path for yolox-pai (#166)

* add github config path
pull/170/head
zouxinyi0625 2022-08-26 13:50:48 +08:00 committed by GitHub
parent 6821cdc9db
commit 19e570adde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 426 additions and 39 deletions

View File

@ -67,7 +67,7 @@ jobs:
PYTHONPATH=. python tests/run.py
# blade test env will be updated!
# blade test env will be updated in docker images!
# ut-torch181-blade:
# # The type of runner that the job will run on
# runs-on: [unittest-t4]

View File

@ -0,0 +1,188 @@
_base_ = '../../base.py'
# model settings s m l x
model = dict(
type='YOLOX',
test_conf=0.01,
nms_thre=0.65,
backbone='RepVGGYOLOX',
model_type='s', # s m l x tiny nano
head=dict(
type='YOLOXHead',
model_type='s',
obj_loss_type='BCE',
reg_loss_type='giou',
num_classes=80,
decode_in_inference=
True # set to False when test speed to ignore decode and nms
))
# s m l x
img_scale = (640, 640)
random_size = (14, 26)
scale_ratio = (0.1, 2)
CLASSES = [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush'
]
# dataset settings
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='MMMosaic', img_scale=img_scale, pad_val=114.0),
dict(
type='MMRandomAffine',
scaling_ratio_range=scale_ratio,
border=(-img_scale[0] // 2, -img_scale[1] // 2)),
dict(
type='MMMixUp', # s m x l; tiny nano will detele
img_scale=img_scale,
ratio_range=(0.8, 1.6),
pad_val=114.0),
dict(
type='MMPhotoMetricDistortion',
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18),
dict(type='MMRandomFlip', flip_ratio=0.5),
dict(type='MMResize', keep_ratio=True),
dict(type='MMPad', pad_to_square=True, pad_val=(114.0, 114.0, 114.0)),
dict(type='MMNormalize', **img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
dict(type='MMResize', img_scale=img_scale, keep_ratio=True),
dict(type='MMPad', pad_to_square=True, pad_val=(114.0, 114.0, 114.0)),
dict(type='MMNormalize', **img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img'])
]
train_dataset = dict(
type='DetImagesMixDataset',
data_source=dict(
type='DetSourceCoco',
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=[
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True)
],
classes=CLASSES,
filter_empty_gt=True,
iscrowd=False),
pipeline=train_pipeline,
dynamic_scale=img_scale)
val_dataset = dict(
type='DetImagesMixDataset',
imgs_per_gpu=2,
data_source=dict(
type='DetSourceCoco',
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=[
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True)
],
classes=CLASSES,
filter_empty_gt=False,
test_mode=True,
iscrowd=True),
pipeline=test_pipeline,
dynamic_scale=None,
label_padding=False)
data = dict(
imgs_per_gpu=16, workers_per_gpu=4, train=train_dataset, val=val_dataset)
# additional hooks
interval = 10
custom_hooks = [
dict(
type='YOLOXModeSwitchHook',
no_aug_epochs=15,
skip_type_keys=('MMMosaic', 'MMRandomAffine', 'MMMixUp'),
priority=48),
dict(
type='SyncRandomSizeHook',
ratio_range=random_size,
img_scale=img_scale,
interval=interval,
priority=48),
dict(
type='SyncNormHook',
num_last_epochs=15,
interval=interval,
priority=48)
]
# evaluation
eval_config = dict(
interval=10,
gpu_collect=False,
visualization_config=dict(
vis_num=10,
score_thr=0.5,
) # show by TensorboardLoggerHookV2 and WandbLoggerHookV2
)
eval_pipelines = [
dict(
mode='test',
data=data['val'],
evaluators=[dict(type='CocoDetectionEvaluator', classes=CLASSES)],
)
]
checkpoint_config = dict(interval=interval)
# optimizer
optimizer = dict(
type='SGD', lr=0.02, momentum=0.9, weight_decay=5e-4, nesterov=True)
optimizer_config = {}
# learning policy
lr_config = dict(
policy='YOLOX',
warmup='exp',
by_epoch=False,
warmup_by_epoch=True,
warmup_ratio=1,
warmup_iters=5, # 5 epoch
num_last_epochs=15,
min_lr_ratio=0.05)
# exponetial model average
ema = dict(decay=0.9998)
# runtime settings
total_epochs = 300
# yapf:disable
log_config = dict(
interval=100,
hooks=[
dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHookV2'),
# dict(type='WandbLoggerHookV2'),
])
export = dict(export_type = 'ori', preprocess_jit = False, batch_size=1, blade_config=dict(enable_fp16=True, fp16_fallback_op_ratio=0.01), use_trt_efficientnms=False)

View File

@ -1,22 +1,27 @@
# model settings
# models s m l x
_base_ = '../../base.py'
# model settings s m l x
model = dict(
type='YOLOX',
num_classes=80,
model_type='tiny', # s m l x tiny nano
test_conf=0.01,
nms_thre=0.65)
nms_thre=0.65,
backbone='RepVGGYOLOX',
model_type='s', # s m l x tiny nano
use_att='ASFF',
head=dict(
type='YOLOXHead',
model_type='s',
obj_loss_type='BCE',
reg_loss_type='giou',
num_classes=80,
decode_in_inference=
False # set to False when test speed to ignore decode and nms
))
# s m l x
# img_scale = (640, 640)
# random_size = (14, 26)
# scale_ratio = (0.1, 2)
# tiny nano ; without mixup
img_scale = (416, 416)
random_size = (10, 20)
scale_ratio = (0.5, 1.5)
img_scale = (640, 640)
random_size = (14, 26)
scale_ratio = (0.1, 2)
CLASSES = [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
@ -36,6 +41,7 @@ CLASSES = [
# dataset settings
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
@ -45,6 +51,11 @@ train_pipeline = [
type='MMRandomAffine',
scaling_ratio_range=scale_ratio,
border=(-img_scale[0] // 2, -img_scale[1] // 2)),
dict(
type='MMMixUp', # s m x l; tiny nano will detele
img_scale=img_scale,
ratio_range=(0.8, 1.6),
pad_val=114.0),
dict(
type='MMPhotoMetricDistortion',
brightness_delta=32,
@ -125,7 +136,14 @@ custom_hooks = [
]
# evaluation
eval_config = dict(interval=10, gpu_collect=False)
eval_config = dict(
interval=10,
gpu_collect=False,
visualization_config=dict(
vis_num=10,
score_thr=0.5,
) # show by TensorboardLoggerHookV2 and WandbLoggerHookV2
)
eval_pipelines = [
dict(
mode='test',
@ -137,9 +155,8 @@ eval_pipelines = [
checkpoint_config = dict(interval=interval)
# optimizer
# basic_lr_per_img = 0.01 / 64.0
optimizer = dict(
type='SGD', lr=0.01, momentum=0.9, weight_decay=5e-4, nesterov=True)
type='SGD', lr=0.02, momentum=0.9, weight_decay=5e-4, nesterov=True)
optimizer_config = {}
# learning policy
@ -164,15 +181,8 @@ log_config = dict(
interval=100,
hooks=[
dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')
dict(type='TensorboardLoggerHookV2'),
# dict(type='WandbLoggerHookV2'),
])
# yapf:enable
# runtime settings
dist_params = dict(backend='nccl')
cudnn_benchmark = True
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
export = dict(use_jit=False)
export = dict(export_type = 'ori', preprocess_jit = False, batch_size=1, blade_config=dict(enable_fp16=True, fp16_fallback_op_ratio=0.01), use_trt_efficientnms=False)

View File

@ -0,0 +1,189 @@
_base_ = '../../base.py'
# model settings s m l x
model = dict(
type='YOLOX',
test_conf=0.01,
nms_thre=0.65,
backbone='RepVGGYOLOX',
model_type='s', # s m l x tiny nano
use_att='ASFF',
head=dict(
type='TOODHead',
model_type='s',
obj_loss_type='BCE',
reg_loss_type='giou',
num_classes=80,
decode_in_inference=
True # set to False when test speed to ignore decode and nms
))
# s m l x
img_scale = (640, 640)
random_size = (14, 26)
scale_ratio = (0.1, 2)
CLASSES = [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush'
]
# dataset settings
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='MMMosaic', img_scale=img_scale, pad_val=114.0),
dict(
type='MMRandomAffine',
scaling_ratio_range=scale_ratio,
border=(-img_scale[0] // 2, -img_scale[1] // 2)),
dict(
type='MMMixUp', # s m x l; tiny nano will detele
img_scale=img_scale,
ratio_range=(0.8, 1.6),
pad_val=114.0),
dict(
type='MMPhotoMetricDistortion',
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18),
dict(type='MMRandomFlip', flip_ratio=0.5),
dict(type='MMResize', keep_ratio=True),
dict(type='MMPad', pad_to_square=True, pad_val=(114.0, 114.0, 114.0)),
dict(type='MMNormalize', **img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
dict(type='MMResize', img_scale=img_scale, keep_ratio=True),
dict(type='MMPad', pad_to_square=True, pad_val=(114.0, 114.0, 114.0)),
dict(type='MMNormalize', **img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img'])
]
train_dataset = dict(
type='DetImagesMixDataset',
data_source=dict(
type='DetSourceCoco',
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=[
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True)
],
classes=CLASSES,
filter_empty_gt=True,
iscrowd=False),
pipeline=train_pipeline,
dynamic_scale=img_scale)
val_dataset = dict(
type='DetImagesMixDataset',
imgs_per_gpu=2,
data_source=dict(
type='DetSourceCoco',
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=[
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True)
],
classes=CLASSES,
filter_empty_gt=False,
test_mode=True,
iscrowd=True),
pipeline=test_pipeline,
dynamic_scale=None,
label_padding=False)
data = dict(
imgs_per_gpu=16, workers_per_gpu=4, train=train_dataset, val=val_dataset)
# additional hooks
interval = 10
custom_hooks = [
dict(
type='YOLOXModeSwitchHook',
no_aug_epochs=15,
skip_type_keys=('MMMosaic', 'MMRandomAffine', 'MMMixUp'),
priority=48),
dict(
type='SyncRandomSizeHook',
ratio_range=random_size,
img_scale=img_scale,
interval=interval,
priority=48),
dict(
type='SyncNormHook',
num_last_epochs=15,
interval=interval,
priority=48)
]
# evaluation
eval_config = dict(
interval=10,
gpu_collect=False,
visualization_config=dict(
vis_num=10,
score_thr=0.5,
) # show by TensorboardLoggerHookV2 and WandbLoggerHookV2
)
eval_pipelines = [
dict(
mode='test',
data=data['val'],
evaluators=[dict(type='CocoDetectionEvaluator', classes=CLASSES)],
)
]
checkpoint_config = dict(interval=interval)
# optimizer
optimizer = dict(
type='SGD', lr=0.02, momentum=0.9, weight_decay=5e-4, nesterov=True)
optimizer_config = {}
# learning policy
lr_config = dict(
policy='YOLOX',
warmup='exp',
by_epoch=False,
warmup_by_epoch=True,
warmup_ratio=1,
warmup_iters=5, # 5 epoch
num_last_epochs=15,
min_lr_ratio=0.05)
# exponetial model average
ema = dict(decay=0.9998)
# runtime settings
total_epochs = 300
# yapf:disable
log_config = dict(
interval=100,
hooks=[
dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHookV2'),
# dict(type='WandbLoggerHookV2'),
])
export = dict(export_type = 'ori', preprocess_jit = False, batch_size=1, blade_config=dict(enable_fp16=True, fp16_fallback_op_ratio=0.01), use_trt_efficientnms=False)

View File

@ -1,20 +1,20 @@
# Detection Model Zoo
## YOLOX
## YOLOX-PAI
Pretrained on COCO2017 dataset.
| Algorithm | Config | Params | Speed<sup>V100<br/><sub>fp16 b32 </sub> | mAP<sup>val<br/><sub>0.5:0.95</sub> | AP<sup>val<br/><sub>50</sub> | Download |
|-----------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------|-----------------------------------------|-------------------------------------|------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| YOLOX-s | [yolox_s_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_s_8xb16_300e_coco.py) | 9M | 0.68ms | 40.0 | 58.9 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_s_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_s_bs16_lr002/log.txt) |
| PAI-YOLOXs | [yoloxs_pai_8xb16_300e_coco](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/config/pai_yoloxs.py) | 16M | 0.71ms | 41.4 | 60.0 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/model/pai_yoloxs.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/log/pai_yoloxs.json) |
| PAI-YOLOXs-ASFF | [yoloxs_pai_asff_8xb16_300e_coco](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/config/pai_yoloxs_asff.py) | 21M | 0.87ms | 42.8 | 61.8 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/model/pai_yoloxs_asff.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/log/pai_yoloxs_asff.json) |
| PAI-YOLOXs-ASFF-TOOD3 | [yoloxs_pai_asff_tood3_8xb16_300e_coco](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/config/pai_yoloxs_asff_tood3.py) | 24M | 1.15ms | 43.9 | 62.1 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/model/pai_yoloxs_asff_tood3.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/log/pai_yoloxs_asff_tood3.json) |
| YOLOX-m | [yolox_m_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_m_8xb16_300e_coco.py) | 25M | 1.52ms | 46.3 | 64.9 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_m_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_m_bs16_lr002/log.txt) |
| YOLOX-l | [yolox_l_8xb8_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_m_8xb8_300e_coco.py) | 54M | 2.47ms | 48.9 | 67.5 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_l_bs8_lr001/epoch_290.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_l_bs8_lr001/log.txt) |
| YOLOX-x | [yolox_x_8xb8_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_x_8xb8_300e_coco.py) | 99M | 4.74ms | 50.9 | 69.2 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_x_bs8_lr001/epoch_290.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_x_bs8_lr001/log.txt) |
| YOLOX-tiny | [yolox_tiny_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_tiny_8xb16_300e_coco.py) | 5M | 0.28ms | 31.5 | 49.2 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_tiny_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_tiny_bs16_lr002/log.txt) |
| YOLOX-nano | [yolox_nano_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_tiny_8xb16_300e_coco.py) | 2.2M | 0.19ms | 26.5 | 42.6 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_nano_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_nano_bs16_lr002/log.txt) |
| Algorithm | Config | Params | Speed<sup>V100<br/><sub>fp16 b32 </sub> | mAP<sup>val<br/><sub>0.5:0.95</sub> | AP<sup>val<br/><sub>50</sub> | Download |
|-----------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|--------|-----------------------------------------|-------------------------------------|------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| YOLOX-s | [yolox_s_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_s_8xb16_300e_coco.py) | 9M | 0.68ms | 40.0 | 58.9 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_s_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_s_bs16_lr002/log.txt) |
| PAI-YOLOXs | [yoloxs_pai_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/pai_yoloxs_8xb16_300e_coco.py) | 16M | 0.71ms | 41.4 | 60.0 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/model/pai_yoloxs.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/log/pai_yoloxs.json) |
| PAI-YOLOXs-ASFF | [yoloxs_pai_asff_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/pai_yoloxs_asff_8xb16_300e_coco.py) | 21M | 0.87ms | 42.8 | 61.8 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/model/pai_yoloxs_asff.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/log/pai_yoloxs_asff.json) |
| PAI-YOLOXs-ASFF-TOOD3 | [yoloxs_pai_asff_tood3_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/pai_yoloxs_asff_tood3_8xb16_300e_coco.py) | 24M | 1.15ms | 43.9 | 62.1 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/model/pai_yoloxs_asff_tood3.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/log/pai_yoloxs_asff_tood3.json) |
| YOLOX-m | [yolox_m_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_m_8xb16_300e_coco.py) | 25M | 1.52ms | 46.3 | 64.9 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_m_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_m_bs16_lr002/log.txt) |
| YOLOX-l | [yolox_l_8xb8_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_m_8xb8_300e_coco.py) | 54M | 2.47ms | 48.9 | 67.5 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_l_bs8_lr001/epoch_290.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_l_bs8_lr001/log.txt) |
| YOLOX-x | [yolox_x_8xb8_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_x_8xb8_300e_coco.py) | 99M | 4.74ms | 50.9 | 69.2 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_x_bs8_lr001/epoch_290.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_x_bs8_lr001/log.txt) |
| YOLOX-tiny | [yolox_tiny_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_tiny_8xb16_300e_coco.py) | 5M | 0.28ms | 31.5 | 49.2 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_tiny_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_tiny_bs16_lr002/log.txt) |
| YOLOX-nano | [yolox_nano_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_tiny_8xb16_300e_coco.py) | 2.2M | 0.19ms | 26.5 | 42.6 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_nano_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_nano_bs16_lr002/log.txt) |
## ViTDet