diff --git a/mmocr/models/textdet/heads/base.py b/mmocr/models/textdet/heads/base.py
index 75722b26..82dee4df 100644
--- a/mmocr/models/textdet/heads/base.py
+++ b/mmocr/models/textdet/heads/base.py
@@ -108,7 +108,7 @@ class BaseTextDetHead(BaseModule):
 
         outs = self(x, data_samples)
         losses = self.module_loss(outs, data_samples)
-        predictions = self.postprocessor(outs, data_samples)
+        predictions = self.postprocessor(outs, data_samples, self.training)
         return losses, predictions
 
     def predict(self, x: torch.Tensor,
diff --git a/projects/ABCNet/abcnet/model/__init__.py b/projects/ABCNet/abcnet/model/__init__.py
index 34d1c628..75997be0 100644
--- a/projects/ABCNet/abcnet/model/__init__.py
+++ b/projects/ABCNet/abcnet/model/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .abcnet import ABCNet
 from .abcnet_det_head import ABCNetDetHead
+from .abcnet_det_module_loss import ABCNetDetModuleLoss
 from .abcnet_det_postprocessor import ABCNetDetPostprocessor
 from .abcnet_postprocessor import ABCNetPostprocessor
 from .abcnet_rec import ABCNetRec
@@ -8,10 +9,11 @@ from .abcnet_rec_backbone import ABCNetRecBackbone
 from .abcnet_rec_decoder import ABCNetRecDecoder
 from .abcnet_rec_encoder import ABCNetRecEncoder
 from .bezier_roi_extractor import BezierRoIExtractor
-from .only_rec_roi_head import OnlyRecRoIHead
+from .rec_roi_head import RecRoIHead
 
 __all__ = [
     'ABCNetDetHead', 'ABCNetDetPostprocessor', 'ABCNetRecBackbone',
     'ABCNetRecDecoder', 'ABCNetRecEncoder', 'ABCNet', 'ABCNetRec',
-    'BezierRoIExtractor', 'OnlyRecRoIHead', 'ABCNetPostprocessor'
+    'BezierRoIExtractor', 'RecRoIHead', 'ABCNetPostprocessor',
+    'ABCNetDetModuleLoss'
 ]
diff --git a/projects/ABCNet/abcnet/model/abcnet_det_module_loss.py b/projects/ABCNet/abcnet/model/abcnet_det_module_loss.py
new file mode 100644
index 00000000..a8becc48
--- /dev/null
+++ b/projects/ABCNet/abcnet/model/abcnet_det_module_loss.py
@@ -0,0 +1,359 @@
+# Copyright (c) OpenMMLab. All rights reserved.
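+# FCOS-style module loss for the ABCNet detection head: multi-level points,
+# regress-range filtering, optional center sampling and a centerness branch,
+# plus a 16-channel branch regressing the 8 Bezier control points.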
+from typing import Dict, List, Tuple
+
+import torch
+from mmdet.models.task_modules.prior_generators import MlvlPointGenerator
+from mmdet.models.utils import multi_apply
+from mmdet.utils import reduce_mean
+from torch import Tensor
+
+from mmocr.models.textdet.module_losses.base import BaseTextDetModuleLoss
+from mmocr.registry import MODELS, TASK_UTILS
+from mmocr.structures import TextDetDataSample
+from mmocr.utils import ConfigType, DetSampleList, RangeType
+from ..utils import poly2bezier
+
+INF = 1e8
+
+
+@MODELS.register_module()
+class ABCNetDetModuleLoss(BaseTextDetModuleLoss):
+    # TODO add docs
+
+    def __init__(
+        self,
+        num_classes: int = 1,
+        bbox_coder: ConfigType = dict(type='mmdet.DistancePointBBoxCoder'),
+        regress_ranges: RangeType = ((-1, 64), (64, 128), (128, 256),
+                                     (256, 512), (512, INF)),
+        strides: List[int] = (8, 16, 32, 64, 128),
+        center_sampling: bool = True,
+        center_sample_radius: float = 1.5,
+        norm_on_bbox: bool = True,
+        loss_cls: ConfigType = dict(
+            type='mmdet.FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox: ConfigType = dict(type='mmdet.GIoULoss', loss_weight=1.0),
+        loss_centerness: ConfigType = dict(
+            type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bezier: ConfigType = dict(
+            type='mmdet.SmoothL1Loss', reduction='mean', loss_weight=1.0)
+    ) -> None:
+        super().__init__()
+        self.num_classes = num_classes
+        self.strides = strides
+        self.prior_generator = MlvlPointGenerator(strides)
+        self.regress_ranges = regress_ranges
+        self.center_sampling = center_sampling
+        self.center_sample_radius = center_sample_radius
+        self.norm_on_bbox = norm_on_bbox
+        self.loss_centerness = MODELS.build(loss_centerness)
+        self.loss_cls = MODELS.build(loss_cls)
+        self.loss_bbox = MODELS.build(loss_bbox)
+        self.loss_bezier = MODELS.build(loss_bezier)
+        self.bbox_coder = TASK_UTILS.build(bbox_coder)
+        use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+        if use_sigmoid_cls:
+            self.cls_out_channels = num_classes
+        else:
+            self.cls_out_channels = num_classes + 1
+ """ + cls_scores, bbox_preds, centernesses, beizer_preds = inputs + assert len(cls_scores) == len(bbox_preds) == len(centernesses) == len( + beizer_preds) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + all_level_points = self.prior_generator.grid_priors( + featmap_sizes, + dtype=bbox_preds[0].dtype, + device=bbox_preds[0].device) + labels, bbox_targets, bezier_targets = self.get_targets( + all_level_points, data_samples) + + num_imgs = cls_scores[0].size(0) + # flatten cls_scores, bbox_preds and centerness + flatten_cls_scores = [ + cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels) + for cls_score in cls_scores + ] + flatten_bbox_preds = [ + bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) + for bbox_pred in bbox_preds + ] + flatten_centerness = [ + centerness.permute(0, 2, 3, 1).reshape(-1) + for centerness in centernesses + ] + flatten_bezier_preds = [ + bezier_pred.permute(0, 2, 3, 1).reshape(-1, 16) + for bezier_pred in beizer_preds + ] + flatten_cls_scores = torch.cat(flatten_cls_scores) + flatten_bbox_preds = torch.cat(flatten_bbox_preds) + flatten_centerness = torch.cat(flatten_centerness) + flatten_bezier_preds = torch.cat(flatten_bezier_preds) + flatten_labels = torch.cat(labels) + flatten_bbox_targets = torch.cat(bbox_targets) + flatten_bezier_targets = torch.cat(bezier_targets) + # repeat points to align with bbox_preds + flatten_points = torch.cat( + [points.repeat(num_imgs, 1) for points in all_level_points]) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((flatten_labels >= 0) + & (flatten_labels < bg_class_ind)).nonzero().reshape(-1) + num_pos = torch.tensor( + len(pos_inds), dtype=torch.float, device=bbox_preds[0].device) + num_pos = max(reduce_mean(num_pos), 1.0) + loss_cls = self.loss_cls( + flatten_cls_scores, flatten_labels, avg_factor=num_pos) + + pos_bbox_preds = flatten_bbox_preds[pos_inds] + pos_centerness = flatten_centerness[pos_inds] + pos_bezier_preds = flatten_bezier_preds[pos_inds] + pos_bbox_targets = flatten_bbox_targets[pos_inds] + pos_centerness_targets = self.centerness_target(pos_bbox_targets) + pos_bezier_targets = flatten_bezier_targets[pos_inds] + # centerness weighted iou loss + centerness_denorm = max( + reduce_mean(pos_centerness_targets.sum().detach()), 1e-6) + + if len(pos_inds) > 0: + pos_points = flatten_points[pos_inds] + pos_decoded_bbox_preds = self.bbox_coder.decode( + pos_points, pos_bbox_preds) + pos_decoded_target_preds = self.bbox_coder.decode( + pos_points, pos_bbox_targets) + loss_bbox = self.loss_bbox( + pos_decoded_bbox_preds, + pos_decoded_target_preds, + weight=pos_centerness_targets, + avg_factor=centerness_denorm) + loss_centerness = self.loss_centerness( + pos_centerness, pos_centerness_targets, avg_factor=num_pos) + loss_bezier = self.loss_bezier( + pos_bezier_preds, + pos_bezier_targets, + weight=pos_centerness_targets[:, None], + avg_factor=centerness_denorm) + else: + loss_bbox = pos_bbox_preds.sum() + loss_centerness = pos_centerness.sum() + loss_bezier = pos_bezier_preds.sum() + + return dict( + loss_cls=loss_cls, + loss_bbox=loss_bbox, + loss_centerness=loss_centerness, + loss_bezier=loss_bezier) + + def get_targets(self, points: List[Tensor], data_samples: DetSampleList + ) -> Tuple[List[Tensor], List[Tensor]]: + """Compute regression, classification and centerness targets for points + in multiple images. + + Args: + points (list[Tensor]): Points of each fpn level, each has shape + (num_points, 2). 
+
+    def get_targets(self, points: List[Tensor], data_samples: DetSampleList
+                    ) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
+        """Compute regression, classification and centerness targets for
+        points in multiple images.
+
+        Args:
+            points (list[Tensor]): Points of each fpn level, each has shape
+                (num_points, 2).
+            data_samples (list[TextDetDataSample]): Batch of data samples.
+                Each data sample contains a gt_instance, which usually
+                includes bboxes and labels attributes.
+
+        Returns:
+            tuple: Targets of each level.
+
+            - concat_lvl_labels (list[Tensor]): Labels of each level.
+            - concat_lvl_bbox_targets (list[Tensor]): BBox targets of each
+              level.
+            - concat_lvl_bezier_targets (list[Tensor]): Bezier targets of
+              each level.
+        """
+        assert len(points) == len(self.regress_ranges)
+        num_levels = len(points)
+        # expand regress ranges to align with points
+        expanded_regress_ranges = [
+            points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
+                points[i]) for i in range(num_levels)
+        ]
+        # concat all levels points and regress ranges
+        concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
+        concat_points = torch.cat(points, dim=0)
+
+        # the number of points per img, per lvl
+        num_points = [center.size(0) for center in points]
+
+        # get labels and bbox_targets of each image
+        labels_list, bbox_targets_list, bezier_targets_list = multi_apply(
+            self._get_targets_single,
+            data_samples,
+            points=concat_points,
+            regress_ranges=concat_regress_ranges,
+            num_points_per_lvl=num_points)
+
+        # split to per img, per level
+        labels_list = [labels.split(num_points, 0) for labels in labels_list]
+        bbox_targets_list = [
+            bbox_targets.split(num_points, 0)
+            for bbox_targets in bbox_targets_list
+        ]
+        bezier_targets_list = [
+            bezier_targets.split(num_points, 0)
+            for bezier_targets in bezier_targets_list
+        ]
+        # concat the targets of each level across images
+        concat_lvl_labels = []
+        concat_lvl_bbox_targets = []
+        concat_lvl_bezier_targets = []
+        for i in range(num_levels):
+            concat_lvl_labels.append(
+                torch.cat([labels[i] for labels in labels_list]))
+            bbox_targets = torch.cat(
+                [bbox_targets[i] for bbox_targets in bbox_targets_list])
+            bezier_targets = torch.cat(
+                [bezier_targets[i] for bezier_targets in bezier_targets_list])
+            if self.norm_on_bbox:
+                bbox_targets = bbox_targets / self.strides[i]
+                bezier_targets = bezier_targets / self.strides[i]
+            concat_lvl_bbox_targets.append(bbox_targets)
+            concat_lvl_bezier_targets.append(bezier_targets)
+        return (concat_lvl_labels, concat_lvl_bbox_targets,
+                concat_lvl_bezier_targets)
+
+    def _get_targets_single(self, data_sample: TextDetDataSample,
+                            points: Tensor, regress_ranges: Tensor,
+                            num_points_per_lvl: List[int]
+                            ) -> Tuple[Tensor, Tensor, Tensor]:
+        """Compute regression and classification targets for a single
+        image."""
+        num_points = points.size(0)
+        gt_instances = data_sample.gt_instances
+        gt_instances = gt_instances[~gt_instances.ignored]
+        num_gts = len(gt_instances)
+        gt_bboxes = gt_instances.bboxes
+        gt_labels = gt_instances.labels
+        data_sample.gt_instances = gt_instances
+        polygons = gt_instances.polygons
+        beziers = gt_bboxes.new([poly2bezier(poly) for poly in polygons])
+        gt_instances.beziers = beziers
+        if num_gts == 0:
+            return gt_labels.new_full((num_points,), self.num_classes), \
+                   gt_bboxes.new_zeros((num_points, 4)), \
+                   gt_bboxes.new_zeros((num_points, 16))
+
+        areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
+            gt_bboxes[:, 3] - gt_bboxes[:, 1])
+        # TODO: figure out why these two are different
+        # areas = areas[None].expand(num_points, num_gts)
+        areas = areas[None].repeat(num_points, 1)
+        regress_ranges = regress_ranges[:, None, :].expand(
+            num_points, num_gts, 2)
+        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
+        xs, ys = points[:, 0], points[:, 1]
+        xs = xs[:, None].expand(num_points, num_gts)
+        ys = ys[:, None].expand(num_points, num_gts)
+
+        left = xs - gt_bboxes[..., 0]
+        right = gt_bboxes[..., 2] - xs
+        top = ys - gt_bboxes[..., 1]
+        bottom = gt_bboxes[..., 3] - ys
+        bbox_targets = torch.stack((left, top, right, bottom), -1)
+
+        # per-point (dx, dy) offsets to the 8 Bezier control points
+        beziers = beziers.reshape(-1, 8,
+                                  2)[None].expand(num_points, num_gts, 8, 2)
+        beziers_left = beziers[..., 0] - xs[..., None]
+        beziers_right = beziers[..., 1] - ys[..., None]
+        bezier_targets = torch.stack((beziers_left, beziers_right), dim=-1)
+        bezier_targets = bezier_targets.view(num_points, num_gts, 16)
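+        # Center sampling (as in FCOS): a point only counts as positive for
+        # a box if it lies within `center_sample_radius * stride` of the box
+        # centre, not merely anywhere inside the box.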
+        if self.center_sampling:
+            # condition1: inside a `center bbox`
+            radius = self.center_sample_radius
+            center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2
+            center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2
+            center_gts = torch.zeros_like(gt_bboxes)
+            stride = center_xs.new_zeros(center_xs.shape)
+
+            # project the points on current lvl back to the `original` sizes
+            lvl_begin = 0
+            for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
+                lvl_end = lvl_begin + num_points_lvl
+                stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
+                lvl_begin = lvl_end
+
+            x_mins = center_xs - stride
+            y_mins = center_ys - stride
+            x_maxs = center_xs + stride
+            y_maxs = center_ys + stride
+            # clip the `center bbox` to the gt bbox
+            center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0],
+                                             x_mins, gt_bboxes[..., 0])
+            center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1],
+                                             y_mins, gt_bboxes[..., 1])
+            center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2],
+                                             gt_bboxes[..., 2], x_maxs)
+            center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3],
+                                             gt_bboxes[..., 3], y_maxs)
+
+            cb_dist_left = xs - center_gts[..., 0]
+            cb_dist_right = center_gts[..., 2] - xs
+            cb_dist_top = ys - center_gts[..., 1]
+            cb_dist_bottom = center_gts[..., 3] - ys
+            center_bbox = torch.stack(
+                (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom),
+                -1)
+            inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
+        else:
+            # condition1: inside a gt bbox
+            inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0
+
+        # condition2: limit the regression range for each location
+        max_regress_distance = bbox_targets.max(-1)[0]
+        inside_regress_range = (
+            (max_regress_distance >= regress_ranges[..., 0])
+            & (max_regress_distance <= regress_ranges[..., 1]))
+
+        # if a location still matches more than one object,
+        # we choose the one with minimal area
+        areas[inside_gt_bbox_mask == 0] = INF
+        areas[inside_regress_range == 0] = INF
+        min_area, min_area_inds = areas.min(dim=1)
+
+        labels = gt_labels[min_area_inds]
+        labels[min_area == INF] = self.num_classes  # set as BG
+        bbox_targets = bbox_targets[range(num_points), min_area_inds]
+        bezier_targets = bezier_targets[range(num_points), min_area_inds]
+
+        return labels, bbox_targets, bezier_targets
+ """ + # only calculate pos centerness targets, otherwise there may be nan + left_right = pos_bbox_targets[:, [0, 2]] + top_bottom = pos_bbox_targets[:, [1, 3]] + if len(left_right) == 0: + centerness_targets = left_right[..., 0] + else: + centerness_targets = ( + left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * ( + top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]) + return torch.sqrt(centerness_targets) diff --git a/projects/ABCNet/abcnet/model/only_rec_roi_head.py b/projects/ABCNet/abcnet/model/rec_roi_head.py similarity index 80% rename from projects/ABCNet/abcnet/model/only_rec_roi_head.py rename to projects/ABCNet/abcnet/model/rec_roi_head.py index fec59419..13813bb8 100644 --- a/projects/ABCNet/abcnet/model/only_rec_roi_head.py +++ b/projects/ABCNet/abcnet/model/rec_roi_head.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Tuple +from mmengine.structures import LabelData from torch import Tensor from mmocr.registry import MODELS, TASK_UTILS @@ -10,7 +11,7 @@ from .base_roi_head import BaseRoIHead @MODELS.register_module() -class OnlyRecRoIHead(BaseRoIHead): +class RecRoIHead(BaseRoIHead): """Simplest base roi head including one bbox head and one mask head.""" def __init__(self, @@ -39,8 +40,17 @@ class OnlyRecRoIHead(BaseRoIHead): Returns: dict[str, Tensor]: A dictionary of loss components """ + proposals = [ + ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples + ] - pass + proposals = [p for p in proposals if len(p) > 0] + bbox_feats = self.roi_extractor(inputs, proposals) + rec_data_samples = [ + TextRecogDataSample(gt_text=LabelData(item=text)) + for proposal in proposals for text in proposal.texts + ] + return self.rec_head.loss(bbox_feats, rec_data_samples) def predict(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> RecSampleList: diff --git a/projects/ABCNet/abcnet/model/two_stage_text_spotting.py b/projects/ABCNet/abcnet/model/two_stage_text_spotting.py index 13a7c9f3..4a9bd8ef 100644 --- a/projects/ABCNet/abcnet/model/two_stage_text_spotting.py +++ b/projects/ABCNet/abcnet/model/two_stage_text_spotting.py @@ -70,7 +70,14 @@ class TwoStageTextSpotter(BaseTextDetector): def loss(self, inputs: torch.Tensor, data_samples: OptDetSampleList) -> Dict: - pass + losses = dict() + inputs = self.extract_feat(inputs) + det_loss, data_samples = self.det_head.loss_and_predict( + inputs, data_samples) + roi_losses = self.roi_head.loss(inputs, data_samples) + losses.update(det_loss) + losses.update(roi_losses) + return losses def predict(self, inputs: torch.Tensor, data_samples: OptDetSampleList) -> OptDetSampleList: diff --git a/projects/ABCNet/config/_base_/schedules/schedule_sgd_500e.py b/projects/ABCNet/config/_base_/schedules/schedule_sgd_500e.py new file mode 100644 index 00000000..431c48ff --- /dev/null +++ b/projects/ABCNet/config/_base_/schedules/schedule_sgd_500e.py @@ -0,0 +1,12 @@ +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001), + clip_grad=dict(type='value', clip_value=1)) +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500, val_interval=20) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') +# learning policy +param_scheduler = [ + dict(type='LinearLR', end=1000, start_factor=0.001, by_epoch=False), +] diff --git a/projects/ABCNet/config/abcnet/_base_abcnet-det_resnet50_fpn.py b/projects/ABCNet/config/abcnet/_base_abcnet_resnet50_fpn.py similarity index 64% rename from 
+param_scheduler = [
+    dict(type='LinearLR', end=1000, start_factor=0.001, by_epoch=False),
+]
diff --git a/projects/ABCNet/config/abcnet/_base_abcnet-det_resnet50_fpn.py b/projects/ABCNet/config/abcnet/_base_abcnet_resnet50_fpn.py
similarity index 64%
rename from projects/ABCNet/config/abcnet/_base_abcnet-det_resnet50_fpn.py
rename to projects/ABCNet/config/abcnet/_base_abcnet_resnet50_fpn.py
index 68891db3..62630c36 100644
--- a/projects/ABCNet/config/abcnet/_base_abcnet-det_resnet50_fpn.py
+++ b/projects/ABCNet/config/abcnet/_base_abcnet_resnet50_fpn.py
@@ -67,21 +67,37 @@ model = dict(
             std=0.01,
             bias=-4.59511985013459),  # -log((1-p)/p) where p=0.01
         ),
-        module_loss=None,
+        module_loss=dict(
+            type='ABCNetDetModuleLoss',
+            num_classes=num_classes,
+            strides=strides,
+            center_sampling=True,
+            center_sample_radius=1.5,
+            bbox_coder=bbox_coder,
+            norm_on_bbox=norm_on_bbox,
+            loss_cls=dict(
+                type='mmdet.FocalLoss',
+                use_sigmoid=use_sigmoid_cls,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=1.0),
+            loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=1.0),
+            loss_centerness=dict(
+                type='mmdet.CrossEntropyLoss',
+                use_sigmoid=True,
+                loss_weight=1.0)),
         postprocessor=dict(
            type='ABCNetDetPostprocessor',
-            # rescale_fields=['polygons', 'bboxes'],
            use_sigmoid_cls=use_sigmoid_cls,
            strides=[8, 16, 32, 64, 128],
            bbox_coder=dict(type='mmdet.DistancePointBBoxCoder'),
            with_bezier=True,
            test_cfg=dict(
-                # rescale_fields=['polygon', 'bboxes', 'bezier'],
                nms_pre=1000,
                nms=dict(type='nms', iou_threshold=0.5),
                score_thr=0.3))),
     roi_head=dict(
-        type='OnlyRecRoIHead',
+        type='RecRoIHead',
         roi_extractor=dict(
             type='BezierRoIExtractor',
             roi_layer=dict(
@@ -95,7 +111,14 @@ model = dict(
             decoder=dict(
                 type='ABCNetRecDecoder',
                 dictionary=dictionary,
-                postprocessor=dict(type='AttentionPostprocessor'),
+                postprocessor=dict(
+                    type='AttentionPostprocessor',
+                    ignore_chars=['padding', 'unknown']),
+                module_loss=dict(
+                    type='CEModuleLoss',
+                    ignore_first_char=False,
+                    ignore_char=-1,
+                    reduction='mean'),
                 max_seq_len=25))),
     postprocessor=dict(
         type='ABCNetPostprocessor',
@@ -118,3 +141,32 @@ test_pipeline = [
         type='PackTextDetInputs',
         meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
 ]
+
+train_pipeline = [
+    dict(
+        type='LoadImageFromFile',
+        file_client_args=file_client_args,
+        color_type='color_ignore_orientation'),
+    dict(
+        type='LoadOCRAnnotations',
+        with_polygon=True,
+        with_bbox=True,
+        with_label=True,
+        with_text=True),
+    dict(type='RemoveIgnored'),
+    dict(type='RandomCrop', min_side_ratio=0.1),
+    dict(
+        type='RandomRotate',
+        max_angle=30,
+        pad_with_fixed_color=True,
+        use_canvas=True),
+    dict(
+        type='RandomChoiceResize',
+        scales=[(980, 2900), (1044, 2900), (1108, 2900), (1172, 2900),
+                (1236, 2900), (1300, 2900), (1364, 2900), (1428, 2900),
+                (1492, 2900)],
+        keep_ratio=True),
+    dict(
+        type='PackTextDetInputs',
+        meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
+]
diff --git a/projects/ABCNet/config/abcnet/abcnet_resnet50_fpn.py b/projects/ABCNet/config/abcnet/abcnet_resnet50_fpn.py
deleted file mode 100644
index 2f4515a9..00000000
--- a/projects/ABCNet/config/abcnet/abcnet_resnet50_fpn.py
+++ /dev/null
@@ -1,24 +0,0 @@
-_base_ = [
-    '_base_abcnet-det_resnet50_fpn.py',
-    '../_base_/datasets/icdar2015.py',
-    '../_base_/default_runtime.py',
-]
-
-# dataset settings
-icdar2015_textspotting_test = _base_.icdar2015_textspotting_test
-icdar2015_textspotting_test.pipeline = _base_.test_pipeline
-
-val_dataloader = dict(
-    batch_size=1,
-    num_workers=4,
-    persistent_workers=True,
-    sampler=dict(type='DefaultSampler', shuffle=False),
-    dataset=icdar2015_textspotting_test)
-
-test_dataloader = val_dataloader
-
-val_cfg = dict(type='ValLoop')
-test_cfg = dict(type='TestLoop')
-
-custom_imports = dict(
-    imports=['projects.ABCNet.abcnet'], allow_failed_imports=False)
diff --git a/projects/ABCNet/config/abcnet/abcnet_resnet50_fpn_500e_icdar2015.py b/projects/ABCNet/config/abcnet/abcnet_resnet50_fpn_500e_icdar2015.py
new file mode 100644
index 00000000..424a3525
--- /dev/null
+++ b/projects/ABCNet/config/abcnet/abcnet_resnet50_fpn_500e_icdar2015.py
@@ -0,0 +1,37 @@
+_base_ = [
+    '_base_abcnet_resnet50_fpn.py',
+    '../_base_/datasets/icdar2015.py',
+    '../_base_/default_runtime.py',
+    '../_base_/schedules/schedule_sgd_500e.py',
+]
+
+# dataset settings
+icdar2015_textspotting_train = _base_.icdar2015_textspotting_train
+icdar2015_textspotting_train.pipeline = _base_.train_pipeline
+icdar2015_textspotting_test = _base_.icdar2015_textspotting_test
+icdar2015_textspotting_test.pipeline = _base_.test_pipeline
+
+train_dataloader = dict(
+    batch_size=2,
+    num_workers=8,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=icdar2015_textspotting_train)
+
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=icdar2015_textspotting_test)
+
+test_dataloader = val_dataloader
+
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+custom_imports = dict(imports=['abcnet'], allow_failed_imports=False)
+
+load_from = 'https://download.openmmlab.com/mmocr/textspotting/abcnet/abcnet_resnet50_fpn_500e_icdar2015/abcnet_resnet50_fpn_pretrain-d060636c.pth'  # noqa
+
+find_unused_parameters = True
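
A quick sanity check of the centerness target introduced in
abcnet_det_module_loss.py above: a minimal standalone sketch (not part of the
patch) mirroring the formula in `centerness_target`:

    import torch

    # (left, top, right, bottom) distances for three hypothetical positives
    pos_bbox_targets = torch.tensor([
        [4., 4., 4., 4.],  # at the box centre -> centerness 1.0
        [1., 4., 7., 4.],  # off-centre horizontally
        [1., 1., 7., 7.],  # off-centre on both axes
    ])
    left_right = pos_bbox_targets[:, [0, 2]]
    top_bottom = pos_bbox_targets[:, [1, 3]]
    centerness = torch.sqrt(
        (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
        (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
    print(centerness)  # tensor([1.0000, 0.3780, 0.1429])

Usage note: since the new config sets custom_imports = dict(imports=['abcnet']),
training is presumably meant to be launched from projects/ABCNet/, e.g.
`mim train mmocr config/abcnet/abcnet_resnet50_fpn_500e_icdar2015.py`.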