diff --git a/easycv/models/video_recognition/ClipBertTwoStream.py b/easycv/models/video_recognition/ClipBertTwoStream.py index 832e3cf7..b43fb98e 100644 --- a/easycv/models/video_recognition/ClipBertTwoStream.py +++ b/easycv/models/video_recognition/ClipBertTwoStream.py @@ -28,8 +28,10 @@ class ClipBertTwoStream(BaseModel): self.test_cfg = test_cfg self.multi_class = multi_class - self.vison_pretrained = get_checkpoint(vison_pretrained) - self.text_pretrained = get_checkpoint(text_pretrained) + self.vison_pretrained = get_checkpoint( + vison_pretrained) if vison_pretrained else vison_pretrained + self.text_pretrained = get_checkpoint( + text_pretrained) if text_pretrained else text_pretrained loss_cls = dict(type='CrossEntropyLoss') if not multi_class else dict( type='AsymmetricLoss') self.loss_cls = builder.build_loss(loss_cls) diff --git a/easycv/models/video_recognition/recognizer3d.py b/easycv/models/video_recognition/recognizer3d.py index 40052247..6bbd5f62 100644 --- a/easycv/models/video_recognition/recognizer3d.py +++ b/easycv/models/video_recognition/recognizer3d.py @@ -30,7 +30,7 @@ class Recognizer3D(BaseModel): self.train_cfg = train_cfg self.test_cfg = test_cfg self.pretrained = get_checkpoint( - pretrained) if pretrained != None else pretrained + pretrained) if pretrained else pretrained self.activate_fn = nn.Softmax(dim=1) # aux_info is the list of tensor names beyond 'imgs' and 'label' which diff --git a/easycv/predictors/video_classifier.py b/easycv/predictors/video_classifier.py index 1e5de6bc..8b7abc60 100644 --- a/easycv/predictors/video_classifier.py +++ b/easycv/predictors/video_classifier.py @@ -8,7 +8,10 @@ from PIL import Image, ImageFile from easycv.datasets.registry import PIPELINES from easycv.file import io from easycv.framework.errors import ValueError +from easycv.models.builder import build_model from easycv.utils.misc import deprecated +from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab, + remove_adapt_for_mmlab) from easycv.utils.registry import build_from_cfg from .base import Predictor, PredictorV2 from .builder import PREDICTORS @@ -100,11 +103,16 @@ class VideoClassificationPredictor(PredictorV2): """ if self.pipelines is not None: pipelines = self.pipelines - else: + elif 'test_pipeline' in self.cfg: pipelines = self.cfg.get('test_pipeline', []) + else: + pipelines = self.cfg.get('val_pipeline', []) for idx, pipeline in enumerate(pipelines): if pipeline['type'] == 'Collect' and 'label' in pipeline['keys']: pipeline['keys'].remove('label') + if pipeline['type'] == 'VideoToTensor' and 'label' in pipeline[ + 'keys']: + pipeline['keys'].remove('label') pipelines[idx] = pipeline pipelines = [build_from_cfg(p, PIPELINES) for p in pipelines] @@ -113,6 +121,19 @@ class VideoClassificationPredictor(PredictorV2): processor = Compose(pipelines) return processor + def _build_model(self): + # Use mmdet model + dynamic_adapt_for_mmlab(self.cfg) + if 'vison_pretrained' in self.cfg.model: + self.cfg.model.vison_pretrained = None + if 'text_pretrained' in self.cfg.model: + self.cfg.model.text_pretrained = None + + model = build_model(self.cfg.model) + # remove adapt for mmdet to avoid conflict using mmdet models + remove_adapt_for_mmlab(self.cfg) + return model + def postprocess(self, inputs, *args, **kwargs): """Return top-k results.""" output_prob = inputs['prob'].data.cpu()