mirror of https://github.com/alibaba/EasyCV.git
adapt_videopredictor_to_eas (#268)
parent
8379127388
commit
3953bd2bd4
|
@ -28,8 +28,10 @@ class ClipBertTwoStream(BaseModel):
|
|||
self.test_cfg = test_cfg
|
||||
self.multi_class = multi_class
|
||||
|
||||
self.vison_pretrained = get_checkpoint(vison_pretrained)
|
||||
self.text_pretrained = get_checkpoint(text_pretrained)
|
||||
self.vison_pretrained = get_checkpoint(
|
||||
vison_pretrained) if vison_pretrained else vison_pretrained
|
||||
self.text_pretrained = get_checkpoint(
|
||||
text_pretrained) if text_pretrained else text_pretrained
|
||||
loss_cls = dict(type='CrossEntropyLoss') if not multi_class else dict(
|
||||
type='AsymmetricLoss')
|
||||
self.loss_cls = builder.build_loss(loss_cls)
|
||||
|
|
|
@ -30,7 +30,7 @@ class Recognizer3D(BaseModel):
|
|||
self.train_cfg = train_cfg
|
||||
self.test_cfg = test_cfg
|
||||
self.pretrained = get_checkpoint(
|
||||
pretrained) if pretrained != None else pretrained
|
||||
pretrained) if pretrained else pretrained
|
||||
self.activate_fn = nn.Softmax(dim=1)
|
||||
|
||||
# aux_info is the list of tensor names beyond 'imgs' and 'label' which
|
||||
|
|
|
@ -8,7 +8,10 @@ from PIL import Image, ImageFile
|
|||
from easycv.datasets.registry import PIPELINES
|
||||
from easycv.file import io
|
||||
from easycv.framework.errors import ValueError
|
||||
from easycv.models.builder import build_model
|
||||
from easycv.utils.misc import deprecated
|
||||
from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab,
|
||||
remove_adapt_for_mmlab)
|
||||
from easycv.utils.registry import build_from_cfg
|
||||
from .base import Predictor, PredictorV2
|
||||
from .builder import PREDICTORS
|
||||
|
@ -100,11 +103,16 @@ class VideoClassificationPredictor(PredictorV2):
|
|||
"""
|
||||
if self.pipelines is not None:
|
||||
pipelines = self.pipelines
|
||||
else:
|
||||
elif 'test_pipeline' in self.cfg:
|
||||
pipelines = self.cfg.get('test_pipeline', [])
|
||||
else:
|
||||
pipelines = self.cfg.get('val_pipeline', [])
|
||||
for idx, pipeline in enumerate(pipelines):
|
||||
if pipeline['type'] == 'Collect' and 'label' in pipeline['keys']:
|
||||
pipeline['keys'].remove('label')
|
||||
if pipeline['type'] == 'VideoToTensor' and 'label' in pipeline[
|
||||
'keys']:
|
||||
pipeline['keys'].remove('label')
|
||||
pipelines[idx] = pipeline
|
||||
|
||||
pipelines = [build_from_cfg(p, PIPELINES) for p in pipelines]
|
||||
|
@ -113,6 +121,19 @@ class VideoClassificationPredictor(PredictorV2):
|
|||
processor = Compose(pipelines)
|
||||
return processor
|
||||
|
||||
def _build_model(self):
|
||||
# Use mmdet model
|
||||
dynamic_adapt_for_mmlab(self.cfg)
|
||||
if 'vison_pretrained' in self.cfg.model:
|
||||
self.cfg.model.vison_pretrained = None
|
||||
if 'text_pretrained' in self.cfg.model:
|
||||
self.cfg.model.text_pretrained = None
|
||||
|
||||
model = build_model(self.cfg.model)
|
||||
# remove adapt for mmdet to avoid conflict using mmdet models
|
||||
remove_adapt_for_mmlab(self.cfg)
|
||||
return model
|
||||
|
||||
def postprocess(self, inputs, *args, **kwargs):
|
||||
"""Return top-k results."""
|
||||
output_prob = inputs['prob'].data.cpu()
|
||||
|
|
Loading…
Reference in New Issue