From 47f54304f541fc009a963c434898b9edfcfe23a1 Mon Sep 17 00:00:00 2001 From: Tong Gao Date: Fri, 10 Mar 2023 15:33:12 +0800 Subject: [PATCH] [Enhancement] Make lanms-neo optional (#1772) * [Enhancement] Make lanms-neo optional * fix * rm --- mmocr/models/textdet/heads/drrg_head.py | 9 ++++++++- mmocr/models/textdet/module_losses/drrg_module_loss.py | 9 ++++++++- requirements/readthedocs.txt | 1 - requirements/runtime.txt | 1 - requirements/tests.txt | 1 + 5 files changed, 17 insertions(+), 4 deletions(-) diff --git a/mmocr/models/textdet/heads/drrg_head.py b/mmocr/models/textdet/heads/drrg_head.py index 353dc2ec..14f70858 100644 --- a/mmocr/models/textdet/heads/drrg_head.py +++ b/mmocr/models/textdet/heads/drrg_head.py @@ -6,7 +6,11 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from lanms import merge_quadrangle_n9 as la_nms + +try: + from lanms import merge_quadrangle_n9 as la_nms +except ImportError: + la_nms = None from mmcv.ops import RoIAlignRotated from mmengine.model import BaseModule from numpy import ndarray @@ -838,6 +842,9 @@ class ProposalLocalGraphs: self.comp_shrink_ratio, self.comp_w_h_ratio) + if la_nms is None: + raise ImportError('lanms-neo is not installed, ' + 'please run "pip install lanms-neo==1.0.2".') text_comps = la_nms(text_comps, self.nms_thr) text_comp_mask = np.zeros(mask_sz) text_comp_boxes = text_comps[:, :8].reshape( diff --git a/mmocr/models/textdet/module_losses/drrg_module_loss.py b/mmocr/models/textdet/module_losses/drrg_module_loss.py index 51923ef0..a07fbb2a 100644 --- a/mmocr/models/textdet/module_losses/drrg_module_loss.py +++ b/mmocr/models/textdet/module_losses/drrg_module_loss.py @@ -4,7 +4,11 @@ from typing import Dict, List, Sequence, Tuple import cv2 import numpy as np import torch -from lanms import merge_quadrangle_n9 as la_nms + +try: + from lanms import merge_quadrangle_n9 as la_nms +except ImportError: + la_nms = None from mmcv.image import imrescale from mmdet.models.utils import multi_apply from numpy import ndarray @@ -447,6 +451,9 @@ class DRRGModuleLoss(TextSnakeModuleLoss): score = np.ones((text_comps.shape[0], 1), dtype=np.float32) text_comps = np.hstack([text_comps, score]) + if la_nms is None: + raise ImportError('lanms-neo is not installed, ' + 'please run "pip install lanms-neo==1.0.2".') text_comps = la_nms(text_comps, self.text_comp_nms_thr) if text_comps.shape[0] >= 1: diff --git a/requirements/readthedocs.txt b/requirements/readthedocs.txt index 905a6a10..45edbc15 100644 --- a/requirements/readthedocs.txt +++ b/requirements/readthedocs.txt @@ -1,6 +1,5 @@ imgaug kwarray -lanms-neo==1.0.2 lmdb matplotlib mmcv>=2.0.0rc1 diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 02df3dcf..52a9eec3 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,5 +1,4 @@ imgaug -lanms-neo==1.0.2 lmdb matplotlib numpy diff --git a/requirements/tests.txt b/requirements/tests.txt index 563fc468..19711e10 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -5,6 +5,7 @@ interrogate isort # Note: used for kwarray.group_items, this may be ported to mmcv in the future. kwarray +lanms-neo==1.0.2 parameterized pytest pytest-cov