Merge pull request #4545 from LDOUBLEV/test_v10
add refer to east and sast preprocesspull/4556/head
commit
357dcc1498
|
@ -90,7 +90,7 @@ Optimizer:
|
||||||
|
|
||||||
PostProcess:
|
PostProcess:
|
||||||
name: DistillationDBPostProcess
|
name: DistillationDBPostProcess
|
||||||
model_name: ["Student", "Student2"]
|
model_name: ["Student"]
|
||||||
key: head_out
|
key: head_out
|
||||||
thresh: 0.3
|
thresh: 0.3
|
||||||
box_thresh: 0.6
|
box_thresh: 0.6
|
||||||
|
|
|
@ -11,7 +11,10 @@
|
||||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
#See the License for the specific language governing permissions and
|
#See the License for the specific language governing permissions and
|
||||||
#limitations under the License.
|
#limitations under the License.
|
||||||
|
"""
|
||||||
|
This code is refered from:
|
||||||
|
https://github.com/songdejia/EAST/blob/master/data_utils.py
|
||||||
|
"""
|
||||||
import math
|
import math
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -24,10 +27,10 @@ __all__ = ['EASTProcessTrain']
|
||||||
|
|
||||||
class EASTProcessTrain(object):
|
class EASTProcessTrain(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
image_shape = [512, 512],
|
image_shape=[512, 512],
|
||||||
background_ratio = 0.125,
|
background_ratio=0.125,
|
||||||
min_crop_side_ratio = 0.1,
|
min_crop_side_ratio=0.1,
|
||||||
min_text_size = 10,
|
min_text_size=10,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
self.input_size = image_shape[1]
|
self.input_size = image_shape[1]
|
||||||
self.random_scale = np.array([0.5, 1, 2.0, 3.0])
|
self.random_scale = np.array([0.5, 1, 2.0, 3.0])
|
||||||
|
@ -282,12 +285,7 @@ class EASTProcessTrain(object):
|
||||||
1.0 / max(min(poly_h, poly_w), 1.0)
|
1.0 / max(min(poly_h, poly_w), 1.0)
|
||||||
return score_map, geo_map, training_mask
|
return score_map, geo_map, training_mask
|
||||||
|
|
||||||
def crop_area(self,
|
def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
|
||||||
im,
|
|
||||||
polys,
|
|
||||||
tags,
|
|
||||||
crop_background=False,
|
|
||||||
max_tries=50):
|
|
||||||
"""
|
"""
|
||||||
make random crop from the input image
|
make random crop from the input image
|
||||||
:param im:
|
:param im:
|
||||||
|
@ -435,5 +433,4 @@ class EASTProcessTrain(object):
|
||||||
data['score_map'] = score_map
|
data['score_map'] = score_map
|
||||||
data['geo_map'] = geo_map
|
data['geo_map'] = geo_map
|
||||||
data['training_mask'] = training_mask
|
data['training_mask'] = training_mask
|
||||||
# print(im.shape, score_map.shape, geo_map.shape, training_mask.shape)
|
return data
|
||||||
return data
|
|
||||||
|
|
|
@ -11,7 +11,10 @@
|
||||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
#See the License for the specific language governing permissions and
|
#See the License for the specific language governing permissions and
|
||||||
#limitations under the License.
|
#limitations under the License.
|
||||||
|
"""
|
||||||
|
This part code is refered from:
|
||||||
|
https://github.com/songdejia/EAST/blob/master/data_utils.py
|
||||||
|
"""
|
||||||
import math
|
import math
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
|
@ -11,7 +11,10 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This code is refered from:
|
||||||
|
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
|
||||||
|
"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
@ -190,7 +193,8 @@ class DBPostProcess(object):
|
||||||
|
|
||||||
|
|
||||||
class DistillationDBPostProcess(object):
|
class DistillationDBPostProcess(object):
|
||||||
def __init__(self, model_name=["student"],
|
def __init__(self,
|
||||||
|
model_name=["student"],
|
||||||
key=None,
|
key=None,
|
||||||
thresh=0.3,
|
thresh=0.3,
|
||||||
box_thresh=0.6,
|
box_thresh=0.6,
|
||||||
|
@ -201,12 +205,13 @@ class DistillationDBPostProcess(object):
|
||||||
**kwargs):
|
**kwargs):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.key = key
|
self.key = key
|
||||||
self.post_process = DBPostProcess(thresh=thresh,
|
self.post_process = DBPostProcess(
|
||||||
box_thresh=box_thresh,
|
thresh=thresh,
|
||||||
max_candidates=max_candidates,
|
box_thresh=box_thresh,
|
||||||
unclip_ratio=unclip_ratio,
|
max_candidates=max_candidates,
|
||||||
use_dilation=use_dilation,
|
unclip_ratio=unclip_ratio,
|
||||||
score_mode=score_mode)
|
use_dilation=use_dilation,
|
||||||
|
score_mode=score_mode)
|
||||||
|
|
||||||
def __call__(self, predicts, shape_list):
|
def __call__(self, predicts, shape_list):
|
||||||
results = {}
|
results = {}
|
||||||
|
|
Loading…
Reference in New Issue