add table_attribute_code
parent
6b218caf1b
commit
18acc25f18
|
@ -0,0 +1,35 @@
|
|||
Global:
|
||||
infer_imgs: "images/PULC/table_attribute/val_3610.jpg"
|
||||
inference_model_dir: "./models/table_attribute_infer"
|
||||
batch_size: 1
|
||||
use_gpu: True
|
||||
enable_mkldnn: True
|
||||
cpu_num_threads: 10
|
||||
benchmark: False
|
||||
use_fp16: False
|
||||
ir_optim: True
|
||||
use_tensorrt: False
|
||||
gpu_mem: 8000
|
||||
enable_profile: False
|
||||
|
||||
PreProcess:
|
||||
transform_ops:
|
||||
- ResizeImage:
|
||||
size: [224, 224]
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
channel_num: 3
|
||||
- ToCHWImage:
|
||||
|
||||
PostProcess:
|
||||
main_indicator: TableAttribute
|
||||
TableAttribute:
|
||||
source_threshold: 0.5
|
||||
number_threshold: 0.5
|
||||
color_threshold: 0.5
|
||||
clarity_threshold : 0.5
|
||||
obstruction_threshold: 0.5
|
||||
angle_threshold: 0.5
|
Binary file not shown.
After Width: | Height: | Size: 85 KiB |
Binary file not shown.
After Width: | Height: | Size: 54 KiB |
|
@ -320,3 +320,42 @@ class VehicleAttribute(object):
|
|||
).astype(np.int8).tolist()
|
||||
batch_res.append({"attributes": label_res, "output": pred_res})
|
||||
return batch_res
|
||||
|
||||
|
||||
|
||||
class TableAttribute(object):
    """Convert per-sample table-attribute scores into readable labels.

    Each sample is expected to carry six scores, in this order: source,
    number, color, clarity, obstruction, angle. Every score is compared with
    its own threshold; a score strictly greater than the threshold selects
    the first label of the pair, otherwise the second one:

        source:      'Scanned'           / 'Photo'
        number:      'Little'            / 'Numerous'
        color:       'Black-and-White'   / 'Multicolor'
        clarity:     'Clear'             / 'Blurry'
        obstruction: 'Without-Obstacles' / 'With-Obstacles'
        angle:       'Horizontal'        / 'Tilted'

    Args:
        source_threshold: decision threshold for the source score.
        number_threshold: decision threshold for the number score.
        color_threshold: decision threshold for the color score.
        clarity_threshold: decision threshold for the clarity score.
        obstruction_threshold: decision threshold for the obstruction score.
        angle_threshold: decision threshold for the angle score.
    """

    def __init__(
            self,
            source_threshold=0.5,
            number_threshold=0.5,
            color_threshold=0.5,
            clarity_threshold=0.5,
            obstruction_threshold=0.5,
            angle_threshold=0.5, ):
        self.source_threshold = source_threshold
        self.number_threshold = number_threshold
        self.color_threshold = color_threshold
        self.clarity_threshold = clarity_threshold
        self.obstruction_threshold = obstruction_threshold
        self.angle_threshold = angle_threshold

    def __call__(self, batch_preds, file_names=None):
        """Postprocess the output of the predictor.

        Args:
            batch_preds: iterable of per-sample score vectors. Every item
                must provide ``tolist()`` and hold six values (presumably
                probabilities in [0, 1] -- confirm against the predictor).
            file_names: accepted for interface compatibility with the other
                postprocessors; not used here.

        Returns:
            A list with one dict per sample: ``"attributes"`` holds the six
            string labels and ``"output"`` the matching 0/1 decisions.
        """
        # The thresholds do not change between samples, so build them once.
        threshold_list = [
            self.source_threshold, self.number_threshold,
            self.color_threshold, self.clarity_threshold,
            self.obstruction_threshold, self.angle_threshold
        ]

        batch_res = []
        for res in batch_preds:
            res = res.tolist()
            source = 'Scanned' if res[0] > self.source_threshold else 'Photo'
            number = 'Little' if res[1] > self.number_threshold else 'Numerous'
            color = ('Black-and-White'
                     if res[2] > self.color_threshold else 'Multicolor')
            clarity = 'Clear' if res[3] > self.clarity_threshold else 'Blurry'
            # Obstruction and angle must use their own thresholds so that the
            # labels always agree with the numeric "output" computed below.
            obstruction = ('Without-Obstacles'
                           if res[4] > self.obstruction_threshold else
                           'With-Obstacles')
            angle = 'Horizontal' if res[5] > self.angle_threshold else 'Tilted'

            label_res = [source, number, color, clarity, obstruction, angle]

            pred_res = (np.array(res) > np.array(threshold_list)
                        ).astype(np.int8).tolist()
            batch_res.append({"attributes": label_res, "output": pred_res})
        return batch_res
|
||||
|
|
|
@ -11,17 +11,21 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from paddleclas.deploy.utils import logger, config
|
||||
from paddleclas.deploy.utils.predictor import Predictor
|
||||
from paddleclas.deploy.utils.get_image_list import get_image_list
|
||||
from paddleclas.deploy.python.preprocess import create_operators
|
||||
from paddleclas.deploy.python.postprocess import build_postprocess
|
||||
from utils import logger
|
||||
from utils import config
|
||||
from utils.predictor import Predictor
|
||||
from utils.get_image_list import get_image_list
|
||||
from python.preprocess import create_operators
|
||||
from python.postprocess import build_postprocess
|
||||
|
||||
|
||||
class ClsPredictor(Predictor):
|
||||
|
@ -136,7 +140,7 @@ def main(config):
|
|||
for number, result_dict in enumerate(batch_results):
|
||||
if "PersonAttribute" in config[
|
||||
"PostProcess"] or "VehicleAttribute" in config[
|
||||
"PostProcess"]:
|
||||
"PostProcess"] or "TableAttribute" in config["PostProcess"]:
|
||||
filename = batch_names[number]
|
||||
print("{}:\t {}".format(filename, result_dict))
|
||||
else:
|
||||
|
|
|
@ -191,7 +191,8 @@ PULC_MODEL_BASE_DOWNLOAD_URL = "https://paddleclas.bj.bcebos.com/models/PULC/inf
|
|||
PULC_MODELS = [
|
||||
"car_exists", "language_classification", "person_attribute",
|
||||
"person_exists", "safety_helmet", "text_image_orientation",
|
||||
"textline_orientation", "traffic_sign", "vehicle_attribute"
|
||||
"textline_orientation", "traffic_sign", "vehicle_attribute",
|
||||
"table_attribute"
|
||||
]
|
||||
|
||||
|
||||
|
@ -271,7 +272,25 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
|
|||
if "type_threshold" in kwargs and kwargs["type_threshold"]:
|
||||
cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[
|
||||
"type_threshold"]
|
||||
|
||||
if "TableAttribute" in cfg.PostProcess:
    # Forward each user-supplied table-attribute threshold to the key of the
    # same name under PostProcess.TableAttribute, so that every override
    # reaches the postprocessor it was meant for.
    table_attr_cfg = cfg.PostProcess.TableAttribute
    for threshold_name in ("source_threshold", "number_threshold",
                           "color_threshold", "clarity_threshold",
                           "obstruction_threshold", "angle_threshold"):
        # Same guard as the neighbouring overrides: missing or falsy values
        # leave the configured default untouched.
        if threshold_name in kwargs and kwargs[threshold_name]:
            setattr(table_attr_cfg, threshold_name, kwargs[threshold_name])
|
||||
if "save_dir" in kwargs and kwargs["save_dir"]:
|
||||
cfg.PostProcess.SavePreLabel.save_dir = kwargs["save_dir"]
|
||||
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: "./output/"
|
||||
device: "gpu"
|
||||
save_interval: 1
|
||||
eval_during_train: True
|
||||
eval_interval: 1
|
||||
epochs: 20
|
||||
print_batch_step: 10
|
||||
use_visualdl: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: "./inference"
|
||||
use_multilabel: True
|
||||
|
||||
# model architecture
|
||||
Arch:
|
||||
name: "PPLCNet_x1_0"
|
||||
pretrained: True
|
||||
use_ssld: True
|
||||
class_num: 6
|
||||
|
||||
|
||||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MultiLabelLoss:
|
||||
weight: 1.0
|
||||
weight_ratio: True
|
||||
size_sum: True
|
||||
Eval:
|
||||
- MultiLabelLoss:
|
||||
weight: 1.0
|
||||
weight_ratio: True
|
||||
size_sum: True
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.01
|
||||
warmup_epoch: 5
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
coeff: 0.0005
|
||||
|
||||
# data loader for train and eval
|
||||
DataLoader:
|
||||
Train:
|
||||
dataset:
|
||||
name: MultiLabelDataset
|
||||
image_root: "dataset/table_attribute/"
|
||||
cls_label_path: "dataset/table_attribute/train_list.txt"
|
||||
label_ratio: True
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
size: [224, 224]
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: True
|
||||
shuffle: True
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
Eval:
|
||||
dataset:
|
||||
name: MultiLabelDataset
|
||||
image_root: "dataset/table_attribute/"
|
||||
cls_label_path: "dataset/table_attribute/val_list.txt"
|
||||
label_ratio: True
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
size: [224, 224]
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Infer:
|
||||
infer_imgs: deploy/images/PULC/table_attribute/val_3610.jpg
|
||||
batch_size: 10
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
size: [224, 224]
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
PostProcess:
|
||||
name: TableAttribute
|
||||
source_threshold: 0.5
|
||||
number_threshold: 0.5
|
||||
color_threshold: 0.5
|
||||
clarity_threshold : 0.5
|
||||
obstruction_threshold: 0.5
|
||||
angle_threshold: 0.5
|
||||
|
||||
Metric:
|
||||
Eval:
|
||||
- ATTRMetric:
|
||||
|
||||
|
|
@ -18,7 +18,7 @@ from . import topk, threshoutput
|
|||
|
||||
from .topk import Topk, MultiLabelTopk
|
||||
from .threshoutput import ThreshOutput
|
||||
from .attr_rec import VehicleAttribute, PersonAttribute
|
||||
from .attr_rec import VehicleAttribute, PersonAttribute, TableAttribute
|
||||
|
||||
|
||||
def build_postprocess(config):
|
||||
|
|
|
@ -171,3 +171,47 @@ class PersonAttribute(object):
|
|||
batch_res.append({"attributes": label_res, "output": pred_res})
|
||||
return batch_res
|
||||
|
||||
|
||||
class TableAttribute(object):
    """Convert table-attribute logits into readable labels.

    Each sample is expected to carry six values, in this order: source,
    number, color, clarity, obstruction, angle. After a sigmoid, every score
    is compared with its own threshold; a score strictly greater than the
    threshold selects the first label of the pair, otherwise the second one:

        source:      'Scanned'           / 'Photo'
        number:      'Little'            / 'Numerous'
        color:       'Black-and-White'   / 'Multicolor'
        clarity:     'Clear'             / 'Blurry'
        obstruction: 'Without-Obstacles' / 'With-Obstacles'
        angle:       'Horizontal'        / 'Tilted'

    Args:
        source_threshold: decision threshold for the source score.
        number_threshold: decision threshold for the number score.
        color_threshold: decision threshold for the color score.
        clarity_threshold: decision threshold for the clarity score.
        obstruction_threshold: decision threshold for the obstruction score.
        angle_threshold: decision threshold for the angle score.
    """

    def __init__(
            self,
            source_threshold=0.5,
            number_threshold=0.5,
            color_threshold=0.5,
            clarity_threshold=0.5,
            obstruction_threshold=0.5,
            angle_threshold=0.5, ):
        self.source_threshold = source_threshold
        self.number_threshold = number_threshold
        self.color_threshold = color_threshold
        self.clarity_threshold = clarity_threshold
        self.obstruction_threshold = obstruction_threshold
        self.angle_threshold = angle_threshold

    def __call__(self, x, file_names=None):
        """Postprocess the raw model output.

        Args:
            x: a ``paddle.Tensor`` of logits, or a dict holding that tensor
                under the ``'logits'`` key. The first dimension is the batch.
            file_names: optional sequence with one file name per sample.

        Returns:
            A list with one dict per sample: ``"attributes"`` holds the six
            string labels, ``"output"`` the matching 0/1 decisions and
            ``"file_name"`` the sample's file name (``None`` when
            ``file_names`` was not given).
        """
        if isinstance(x, dict):
            x = x['logits']
        assert isinstance(x, paddle.Tensor)
        if file_names is not None:
            assert x.shape[0] == len(file_names)
        x = F.sigmoid(x).numpy()

        # The thresholds do not change between samples, so build them once.
        threshold_list = [
            self.source_threshold, self.number_threshold,
            self.color_threshold, self.clarity_threshold,
            self.obstruction_threshold, self.angle_threshold
        ]

        batch_res = []
        for idx, res in enumerate(x):
            res = res.tolist()
            source = 'Scanned' if res[0] > self.source_threshold else 'Photo'
            number = 'Little' if res[1] > self.number_threshold else 'Numerous'
            color = ('Black-and-White'
                     if res[2] > self.color_threshold else 'Multicolor')
            clarity = 'Clear' if res[3] > self.clarity_threshold else 'Blurry'
            # Obstruction and angle must use their own thresholds so that the
            # labels always agree with the numeric "output" computed below.
            obstruction = ('Without-Obstacles'
                           if res[4] > self.obstruction_threshold else
                           'With-Obstacles')
            angle = 'Horizontal' if res[5] > self.angle_threshold else 'Tilted'

            label_res = [source, number, color, clarity, obstruction, angle]

            pred_res = (np.array(res) > np.array(threshold_list)
                        ).astype(np.int8).tolist()
            batch_res.append({
                "attributes": label_res,
                "output": pred_res,
                # file_names is optional; indexing None would raise TypeError.
                "file_name":
                file_names[idx] if file_names is not None else None
            })
        return batch_res
|
||||
|
||||
|
|
Loading…
Reference in New Issue