adapt_videopredictor_to_eas (#268)

pull/267/head^2
yhq 2023-01-16 14:05:23 +08:00 committed by GitHub
parent 8379127388
commit 3953bd2bd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 4 deletions

View File

@ -28,8 +28,10 @@ class ClipBertTwoStream(BaseModel):
self.test_cfg = test_cfg
self.multi_class = multi_class
self.vison_pretrained = get_checkpoint(vison_pretrained)
self.text_pretrained = get_checkpoint(text_pretrained)
self.vison_pretrained = get_checkpoint(
vison_pretrained) if vison_pretrained else vison_pretrained
self.text_pretrained = get_checkpoint(
text_pretrained) if text_pretrained else text_pretrained
loss_cls = dict(type='CrossEntropyLoss') if not multi_class else dict(
type='AsymmetricLoss')
self.loss_cls = builder.build_loss(loss_cls)

View File

@ -30,7 +30,7 @@ class Recognizer3D(BaseModel):
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.pretrained = get_checkpoint(
pretrained) if pretrained != None else pretrained
pretrained) if pretrained else pretrained
self.activate_fn = nn.Softmax(dim=1)
# aux_info is the list of tensor names beyond 'imgs' and 'label' which

View File

@ -8,7 +8,10 @@ from PIL import Image, ImageFile
from easycv.datasets.registry import PIPELINES
from easycv.file import io
from easycv.framework.errors import ValueError
from easycv.models.builder import build_model
from easycv.utils.misc import deprecated
from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab,
remove_adapt_for_mmlab)
from easycv.utils.registry import build_from_cfg
from .base import Predictor, PredictorV2
from .builder import PREDICTORS
@ -100,11 +103,16 @@ class VideoClassificationPredictor(PredictorV2):
"""
if self.pipelines is not None:
pipelines = self.pipelines
else:
elif 'test_pipeline' in self.cfg:
pipelines = self.cfg.get('test_pipeline', [])
else:
pipelines = self.cfg.get('val_pipeline', [])
for idx, pipeline in enumerate(pipelines):
if pipeline['type'] == 'Collect' and 'label' in pipeline['keys']:
pipeline['keys'].remove('label')
if pipeline['type'] == 'VideoToTensor' and 'label' in pipeline[
'keys']:
pipeline['keys'].remove('label')
pipelines[idx] = pipeline
pipelines = [build_from_cfg(p, PIPELINES) for p in pipelines]
@ -113,6 +121,19 @@ class VideoClassificationPredictor(PredictorV2):
processor = Compose(pipelines)
return processor
def _build_model(self):
# Use mmdet model
dynamic_adapt_for_mmlab(self.cfg)
if 'vison_pretrained' in self.cfg.model:
self.cfg.model.vison_pretrained = None
if 'text_pretrained' in self.cfg.model:
self.cfg.model.text_pretrained = None
model = build_model(self.cfg.model)
# remove adapt for mmdet to avoid conflict using mmdet models
remove_adapt_for_mmlab(self.cfg)
return model
def postprocess(self, inputs, *args, **kwargs):
"""Return top-k results."""
output_prob = inputs['prob'].data.cpu()