mirror of https://github.com/alibaba/EasyCV.git
parent 6821cdc9db
commit 19e570adde
@ -67,7 +67,7 @@ jobs:
PYTHONPATH=. python tests/run.py

# blade test env will be updated!
# blade test env will be updated in docker images!
# ut-torch181-blade:
# # The type of runner that the job will run on
# runs-on: [unittest-t4]
@ -0,0 +1,188 @@
_base_ = '../../base.py'

# model settings s m l x
model = dict(
    type='YOLOX',
    test_conf=0.01,
    nms_thre=0.65,
    backbone='RepVGGYOLOX',
    model_type='s',  # s m l x tiny nano
    head=dict(
        type='YOLOXHead',
        model_type='s',
        obj_loss_type='BCE',
        reg_loss_type='giou',
        num_classes=80,
        decode_in_inference=True  # set to False when testing speed, to skip decode and nms
    ))

# s m l x
img_scale = (640, 640)
random_size = (14, 26)
scale_ratio = (0.1, 2)

CLASSES = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
    'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
    'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush'
]

# dataset settings
data_root = 'data/coco/'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
    dict(type='MMMosaic', img_scale=img_scale, pad_val=114.0),
    dict(
        type='MMRandomAffine',
        scaling_ratio_range=scale_ratio,
        border=(-img_scale[0] // 2, -img_scale[1] // 2)),
    dict(
        type='MMMixUp',  # used for s m l x; deleted for tiny and nano
        img_scale=img_scale,
        ratio_range=(0.8, 1.6),
        pad_val=114.0),
    dict(
        type='MMPhotoMetricDistortion',
        brightness_delta=32,
        contrast_range=(0.5, 1.5),
        saturation_range=(0.5, 1.5),
        hue_delta=18),
    dict(type='MMRandomFlip', flip_ratio=0.5),
    dict(type='MMResize', keep_ratio=True),
    dict(type='MMPad', pad_to_square=True, pad_val=(114.0, 114.0, 114.0)),
    dict(type='MMNormalize', **img_norm_cfg),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
    dict(type='MMResize', img_scale=img_scale, keep_ratio=True),
    dict(type='MMPad', pad_to_square=True, pad_val=(114.0, 114.0, 114.0)),
    dict(type='MMNormalize', **img_norm_cfg),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img'])
]

train_dataset = dict(
    type='DetImagesMixDataset',
    data_source=dict(
        type='DetSourceCoco',
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=[
            dict(type='LoadImageFromFile', to_float32=True),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        classes=CLASSES,
        filter_empty_gt=True,
        iscrowd=False),
    pipeline=train_pipeline,
    dynamic_scale=img_scale)

val_dataset = dict(
    type='DetImagesMixDataset',
    imgs_per_gpu=2,
    data_source=dict(
        type='DetSourceCoco',
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=[
            dict(type='LoadImageFromFile', to_float32=True),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        classes=CLASSES,
        filter_empty_gt=False,
        test_mode=True,
        iscrowd=True),
    pipeline=test_pipeline,
    dynamic_scale=None,
    label_padding=False)

data = dict(
    imgs_per_gpu=16, workers_per_gpu=4, train=train_dataset, val=val_dataset)

# additional hooks
interval = 10
custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        no_aug_epochs=15,
        skip_type_keys=('MMMosaic', 'MMRandomAffine', 'MMMixUp'),
        priority=48),
    dict(
        type='SyncRandomSizeHook',
        ratio_range=random_size,
        img_scale=img_scale,
        interval=interval,
        priority=48),
    dict(
        type='SyncNormHook',
        num_last_epochs=15,
        interval=interval,
        priority=48)
]

# evaluation
eval_config = dict(
    interval=10,
    gpu_collect=False,
    visualization_config=dict(
        vis_num=10,
        score_thr=0.5,
    )  # shown by TensorboardLoggerHookV2 and WandbLoggerHookV2
)
eval_pipelines = [
    dict(
        mode='test',
        data=data['val'],
        evaluators=[dict(type='CocoDetectionEvaluator', classes=CLASSES)],
    )
]

checkpoint_config = dict(interval=interval)

# optimizer
optimizer = dict(
    type='SGD', lr=0.02, momentum=0.9, weight_decay=5e-4, nesterov=True)
optimizer_config = {}

# learning policy
lr_config = dict(
    policy='YOLOX',
    warmup='exp',
    by_epoch=False,
    warmup_by_epoch=True,
    warmup_ratio=1,
    warmup_iters=5,  # 5 epochs
    num_last_epochs=15,
    min_lr_ratio=0.05)

# exponential moving average of model weights
ema = dict(decay=0.9998)

# runtime settings
total_epochs = 300

# yapf:disable
log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHookV2'),
        # dict(type='WandbLoggerHookV2'),
    ])

export = dict(
    export_type='ori',
    preprocess_jit=False,
    batch_size=1,
    blade_config=dict(enable_fp16=True, fp16_fallback_op_ratio=0.01),
    use_trt_efficientnms=False)
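The new config above is self-contained once its `_base_` is resolved. As a hedged sketch (not part of this commit; the file path below is an assumed location for the config), it can be loaded and retargeted to a smaller dataset with mmcv's `Config`, keeping every place that carries the class list in sync:

```python
# Illustrative only: the config filename is an assumption; mmcv's Config
# resolves the `_base_ = '../../base.py'` inheritance when loading.
from mmcv import Config

cfg = Config.fromfile('configs/detection/yolox/pai_yoloxs_8xb16_300e_coco.py')

# Retarget to a hypothetical 3-class dataset: the head, the data sources and
# the evaluator each hold their own copy of the class list.
cfg.CLASSES = ['cat', 'dog', 'bird']
cfg.model.head.num_classes = len(cfg.CLASSES)
cfg.data.train.data_source.classes = cfg.CLASSES
cfg.data.val.data_source.classes = cfg.CLASSES
cfg.eval_pipelines[0].evaluators[0].classes = cfg.CLASSES

print(cfg.pretty_text)  # dump the merged config for inspection
```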
@ -1,22 +1,27 @@
# model settings
# models s m l x
_base_ = '../../base.py'

# model settings s m l x
model = dict(
    type='YOLOX',
    num_classes=80,
    model_type='tiny',  # s m l x tiny nano
    test_conf=0.01,
    nms_thre=0.65)
    nms_thre=0.65,
    backbone='RepVGGYOLOX',
    model_type='s',  # s m l x tiny nano
    use_att='ASFF',
    head=dict(
        type='YOLOXHead',
        model_type='s',
        obj_loss_type='BCE',
        reg_loss_type='giou',
        num_classes=80,
        decode_in_inference=False  # set to False when testing speed, to skip decode and nms
    ))

# s m l x
# img_scale = (640, 640)
# random_size = (14, 26)
# scale_ratio = (0.1, 2)

# tiny nano; without mixup
img_scale = (416, 416)
random_size = (10, 20)
scale_ratio = (0.5, 1.5)
img_scale = (640, 640)
random_size = (14, 26)
scale_ratio = (0.1, 2)

CLASSES = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',

@ -36,6 +41,7 @@ CLASSES = [

# dataset settings
data_root = 'data/coco/'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

@ -45,6 +51,11 @@ train_pipeline = [
        type='MMRandomAffine',
        scaling_ratio_range=scale_ratio,
        border=(-img_scale[0] // 2, -img_scale[1] // 2)),
    dict(
        type='MMMixUp',  # used for s m l x; deleted for tiny and nano
        img_scale=img_scale,
        ratio_range=(0.8, 1.6),
        pad_val=114.0),
    dict(
        type='MMPhotoMetricDistortion',
        brightness_delta=32,

@ -125,7 +136,14 @@ custom_hooks = [
]

# evaluation
eval_config = dict(interval=10, gpu_collect=False)
eval_config = dict(
    interval=10,
    gpu_collect=False,
    visualization_config=dict(
        vis_num=10,
        score_thr=0.5,
    )  # shown by TensorboardLoggerHookV2 and WandbLoggerHookV2
)
eval_pipelines = [
    dict(
        mode='test',

@ -137,9 +155,8 @@ eval_pipelines = [
checkpoint_config = dict(interval=interval)

# optimizer
# basic_lr_per_img = 0.01 / 64.0
optimizer = dict(
    type='SGD', lr=0.01, momentum=0.9, weight_decay=5e-4, nesterov=True)
    type='SGD', lr=0.02, momentum=0.9, weight_decay=5e-4, nesterov=True)
optimizer_config = {}

# learning policy

@ -164,15 +181,8 @@ log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook')
        dict(type='TensorboardLoggerHookV2'),
        # dict(type='WandbLoggerHookV2'),
    ])
# yapf:enable
# runtime settings
dist_params = dict(backend='nccl')
cudnn_benchmark = True
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]

export = dict(use_jit=False)
export = dict(
    export_type='ori',
    preprocess_jit=False,
    batch_size=1,
    blade_config=dict(enable_fp16=True, fp16_fallback_op_ratio=0.01),
    use_trt_efficientnms=False)
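Two related changes in this hunk are the removal of the `# basic_lr_per_img = 0.01 / 64.0` comment and the base learning rate moving from 0.01 to 0.02. Assuming the usual linear-scaling rule and the 8 GPU x 16 images-per-GPU setup implied by the `8xb16` config names and `imgs_per_gpu=16`, the new value is just the scaled form of the old per-image rate:

```python
# Sanity check of the updated learning rate (8 GPUs is an assumption taken
# from the "8xb16" naming; imgs_per_gpu=16 comes from the config itself).
basic_lr_per_img = 0.01 / 64.0
num_gpus = 8
imgs_per_gpu = 16

lr = basic_lr_per_img * num_gpus * imgs_per_gpu
print(lr)  # 0.02, matching the optimizer's new lr
```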
@ -0,0 +1,189 @@
_base_ = '../../base.py'

# model settings s m l x
model = dict(
    type='YOLOX',
    test_conf=0.01,
    nms_thre=0.65,
    backbone='RepVGGYOLOX',
    model_type='s',  # s m l x tiny nano
    use_att='ASFF',
    head=dict(
        type='TOODHead',
        model_type='s',
        obj_loss_type='BCE',
        reg_loss_type='giou',
        num_classes=80,
        decode_in_inference=True  # set to False when testing speed, to skip decode and nms
    ))

# s m l x
img_scale = (640, 640)
random_size = (14, 26)
scale_ratio = (0.1, 2)

CLASSES = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
    'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
    'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush'
]

# dataset settings
data_root = 'data/coco/'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
    dict(type='MMMosaic', img_scale=img_scale, pad_val=114.0),
    dict(
        type='MMRandomAffine',
        scaling_ratio_range=scale_ratio,
        border=(-img_scale[0] // 2, -img_scale[1] // 2)),
    dict(
        type='MMMixUp',  # used for s m l x; deleted for tiny and nano
        img_scale=img_scale,
        ratio_range=(0.8, 1.6),
        pad_val=114.0),
    dict(
        type='MMPhotoMetricDistortion',
        brightness_delta=32,
        contrast_range=(0.5, 1.5),
        saturation_range=(0.5, 1.5),
        hue_delta=18),
    dict(type='MMRandomFlip', flip_ratio=0.5),
    dict(type='MMResize', keep_ratio=True),
    dict(type='MMPad', pad_to_square=True, pad_val=(114.0, 114.0, 114.0)),
    dict(type='MMNormalize', **img_norm_cfg),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
    dict(type='MMResize', img_scale=img_scale, keep_ratio=True),
    dict(type='MMPad', pad_to_square=True, pad_val=(114.0, 114.0, 114.0)),
    dict(type='MMNormalize', **img_norm_cfg),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img'])
]

train_dataset = dict(
    type='DetImagesMixDataset',
    data_source=dict(
        type='DetSourceCoco',
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=[
            dict(type='LoadImageFromFile', to_float32=True),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        classes=CLASSES,
        filter_empty_gt=True,
        iscrowd=False),
    pipeline=train_pipeline,
    dynamic_scale=img_scale)

val_dataset = dict(
    type='DetImagesMixDataset',
    imgs_per_gpu=2,
    data_source=dict(
        type='DetSourceCoco',
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=[
            dict(type='LoadImageFromFile', to_float32=True),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        classes=CLASSES,
        filter_empty_gt=False,
        test_mode=True,
        iscrowd=True),
    pipeline=test_pipeline,
    dynamic_scale=None,
    label_padding=False)

data = dict(
    imgs_per_gpu=16, workers_per_gpu=4, train=train_dataset, val=val_dataset)

# additional hooks
interval = 10
custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        no_aug_epochs=15,
        skip_type_keys=('MMMosaic', 'MMRandomAffine', 'MMMixUp'),
        priority=48),
    dict(
        type='SyncRandomSizeHook',
        ratio_range=random_size,
        img_scale=img_scale,
        interval=interval,
        priority=48),
    dict(
        type='SyncNormHook',
        num_last_epochs=15,
        interval=interval,
        priority=48)
]

# evaluation
eval_config = dict(
    interval=10,
    gpu_collect=False,
    visualization_config=dict(
        vis_num=10,
        score_thr=0.5,
    )  # shown by TensorboardLoggerHookV2 and WandbLoggerHookV2
)
eval_pipelines = [
    dict(
        mode='test',
        data=data['val'],
        evaluators=[dict(type='CocoDetectionEvaluator', classes=CLASSES)],
    )
]

checkpoint_config = dict(interval=interval)

# optimizer
optimizer = dict(
    type='SGD', lr=0.02, momentum=0.9, weight_decay=5e-4, nesterov=True)
optimizer_config = {}

# learning policy
lr_config = dict(
    policy='YOLOX',
    warmup='exp',
    by_epoch=False,
    warmup_by_epoch=True,
    warmup_ratio=1,
    warmup_iters=5,  # 5 epochs
    num_last_epochs=15,
    min_lr_ratio=0.05)

# exponential moving average of model weights
ema = dict(decay=0.9998)

# runtime settings
total_epochs = 300

# yapf:disable
log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHookV2'),
        # dict(type='WandbLoggerHookV2'),
    ])

export = dict(
    export_type='ori',
    preprocess_jit=False,
    batch_size=1,
    blade_config=dict(enable_fp16=True, fp16_fallback_op_ratio=0.01),
    use_trt_efficientnms=False)
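All three configs normalize with the same `img_norm_cfg`; the values are the standard ImageNet statistics expressed on a 0-255 pixel scale, which the quick check below confirms:

```python
# img_norm_cfg's mean/std are the usual ImageNet statistics multiplied by 255.
mean = [123.675, 116.28, 103.53]
std = [58.395, 57.12, 57.375]

print([round(m / 255, 3) for m in mean])  # [0.485, 0.456, 0.406]
print([round(s / 255, 3) for s in std])   # [0.229, 0.224, 0.225]
```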
@ -1,20 +1,20 @@
# Detection Model Zoo
## YOLOX
## YOLOX-PAI
Pretrained on COCO2017 dataset.
| Algorithm | Config | Params | Speed<sup>V100<br/><sub>fp16 b32 </sub> | mAP<sup>val<br/><sub>0.5:0.95</sub> | AP<sup>val<br/><sub>50</sub> | Download |
|-----------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------|-----------------------------------------|-------------------------------------|------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| YOLOX-s | [yolox_s_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_s_8xb16_300e_coco.py) | 9M | 0.68ms | 40.0 | 58.9 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_s_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_s_bs16_lr002/log.txt) |
| PAI-YOLOXs | [yoloxs_pai_8xb16_300e_coco](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/config/pai_yoloxs.py) | 16M | 0.71ms | 41.4 | 60.0 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/model/pai_yoloxs.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/log/pai_yoloxs.json) |
| PAI-YOLOXs-ASFF | [yoloxs_pai_asff_8xb16_300e_coco](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/config/pai_yoloxs_asff.py) | 21M | 0.87ms | 42.8 | 61.8 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/model/pai_yoloxs_asff.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/log/pai_yoloxs_asff.json) |
| PAI-YOLOXs-ASFF-TOOD3 | [yoloxs_pai_asff_tood3_8xb16_300e_coco](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/config/pai_yoloxs_asff_tood3.py) | 24M | 1.15ms | 43.9 | 62.1 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/model/pai_yoloxs_asff_tood3.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/log/pai_yoloxs_asff_tood3.json) |
| YOLOX-m | [yolox_m_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_m_8xb16_300e_coco.py) | 25M | 1.52ms | 46.3 | 64.9 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_m_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_m_bs16_lr002/log.txt) |
| YOLOX-l | [yolox_l_8xb8_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_m_8xb8_300e_coco.py) | 54M | 2.47ms | 48.9 | 67.5 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_l_bs8_lr001/epoch_290.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_l_bs8_lr001/log.txt) |
| YOLOX-x | [yolox_x_8xb8_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_x_8xb8_300e_coco.py) | 99M | 4.74ms | 50.9 | 69.2 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_x_bs8_lr001/epoch_290.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_x_bs8_lr001/log.txt) |
| YOLOX-tiny | [yolox_tiny_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_tiny_8xb16_300e_coco.py) | 5M | 0.28ms | 31.5 | 49.2 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_tiny_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_tiny_bs16_lr002/log.txt) |
| YOLOX-nano | [yolox_nano_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_tiny_8xb16_300e_coco.py) | 2.2M | 0.19ms | 26.5 | 42.6 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_nano_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_nano_bs16_lr002/log.txt) |
| Algorithm | Config | Params | Speed<sup>V100<br/><sub>fp16 b32 </sub> | mAP<sup>val<br/><sub>0.5:0.95</sub> | AP<sup>val<br/><sub>50</sub> | Download |
|-----------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|--------|-----------------------------------------|-------------------------------------|------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| YOLOX-s | [yolox_s_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_s_8xb16_300e_coco.py) | 9M | 0.68ms | 40.0 | 58.9 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_s_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_s_bs16_lr002/log.txt) |
| PAI-YOLOXs | [yoloxs_pai_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/pai_yoloxs_8xb16_300e_coco.py) | 16M | 0.71ms | 41.4 | 60.0 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/model/pai_yoloxs.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/log/pai_yoloxs.json) |
| PAI-YOLOXs-ASFF | [yoloxs_pai_asff_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/pai_yoloxs_asff_8xb16_300e_coco.py) | 21M | 0.87ms | 42.8 | 61.8 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/model/pai_yoloxs_asff.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/log/pai_yoloxs_asff.json) |
| PAI-YOLOXs-ASFF-TOOD3 | [yoloxs_pai_asff_tood3_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/pai_yoloxs_asff_tood3_8xb16_300e_coco.py) | 24M | 1.15ms | 43.9 | 62.1 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/model/pai_yoloxs_asff_tood3.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox-pai/log/pai_yoloxs_asff_tood3.json) |
| YOLOX-m | [yolox_m_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_m_8xb16_300e_coco.py) | 25M | 1.52ms | 46.3 | 64.9 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_m_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_m_bs16_lr002/log.txt) |
| YOLOX-l | [yolox_l_8xb8_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_m_8xb8_300e_coco.py) | 54M | 2.47ms | 48.9 | 67.5 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_l_bs8_lr001/epoch_290.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_l_bs8_lr001/log.txt) |
| YOLOX-x | [yolox_x_8xb8_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_x_8xb8_300e_coco.py) | 99M | 4.74ms | 50.9 | 69.2 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_x_bs8_lr001/epoch_290.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_x_bs8_lr001/log.txt) |
| YOLOX-tiny | [yolox_tiny_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_tiny_8xb16_300e_coco.py) | 5M | 0.28ms | 31.5 | 49.2 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_tiny_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_tiny_bs16_lr002/log.txt) |
| YOLOX-nano | [yolox_nano_8xb16_300e_coco](https://github.com/alibaba/EasyCV/tree/master/configs/detection/yolox/yolox_tiny_8xb16_300e_coco.py) | 2.2M | 0.19ms | 26.5 | 42.6 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_nano_bs16_lr002/epoch_300.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_nano_bs16_lr002/log.txt) |
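The model links in the tables are plain PyTorch checkpoint files. A hedged sketch of pulling one down and inspecting it, using only the standard library and torch (the URL is copied from the PAI-YOLOXs row; the key layout is an assumption about mmcv-style checkpoints):

```python
import torch
from urllib.request import urlretrieve

# URL taken from the PAI-YOLOXs row above.
url = ('http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/'
       'modelzoo/detection/yolox/yolox-pai/model/pai_yoloxs.pth')
urlretrieve(url, 'pai_yoloxs.pth')

ckpt = torch.load('pai_yoloxs.pth', map_location='cpu')
print(list(ckpt.keys()))                   # often ['meta', 'state_dict', ...]
state_dict = ckpt.get('state_dict', ckpt)  # fall back if weights are stored flat
print(len(state_dict), 'parameter tensors')
```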
## ViTDet