parent
2751cb3a11
commit
e49e491417
|
@ -0,0 +1,144 @@
|
|||
Global:
|
||||
debug: false
|
||||
use_gpu: true
|
||||
epoch_num: 200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/svtr_large/
|
||||
save_epoch_step: 10
|
||||
# evaluation is run every 2000 iterations after the 0th iteration
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: true
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
|
||||
max_text_length: &max_text_length 40
|
||||
infer_mode: false
|
||||
use_space_char: true
|
||||
distributed: true
|
||||
save_res_path: ./output/rec/predicts_svtr_large.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: AdamW
|
||||
beta1: 0.9
|
||||
beta2: 0.99
|
||||
epsilon: 1.0e-08
|
||||
weight_decay: 0.05
|
||||
no_weight_decay_name: norm pos_embed char_node_embed pos_node_embed char_pos_embed vis_pos_embed
|
||||
one_dim_param_no_weight_decay: true
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.00025 # 8gpus 64bs
|
||||
warmup_epoch: 5
|
||||
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: SVTR_LCNet
|
||||
Transform: null
|
||||
Backbone:
|
||||
name: SVTRNet
|
||||
img_size:
|
||||
- 48
|
||||
- 320
|
||||
out_char_num: 40
|
||||
out_channels: 512
|
||||
patch_merging: Conv
|
||||
embed_dim: [192, 256, 512]
|
||||
depth: [6, 6, 9]
|
||||
num_heads: [6, 8, 16]
|
||||
mixer: ['Conv','Conv','Conv','Conv','Conv','Conv','Conv','Conv','Conv','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
|
||||
local_mixer: [[5, 5], [5, 5], [5, 5]]
|
||||
last_stage: False
|
||||
prenorm: True
|
||||
Head:
|
||||
name: MultiHead
|
||||
use_pool: true
|
||||
use_pos: true
|
||||
head_list:
|
||||
- CTCHead:
|
||||
Neck:
|
||||
name: svtr
|
||||
dims: 256
|
||||
depth: 2
|
||||
hidden_dims: 256
|
||||
kernel_size: [1, 3]
|
||||
use_guide: True
|
||||
Head:
|
||||
fc_decay: 0.00001
|
||||
- NRTRHead:
|
||||
nrtr_dim: 512
|
||||
max_text_length: *max_text_length
|
||||
|
||||
Loss:
|
||||
name: MultiLoss
|
||||
loss_config_list:
|
||||
- CTCLoss:
|
||||
- NRTRLoss:
|
||||
|
||||
PostProcess:
|
||||
name: CTCLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
ignore_space: true
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/
|
||||
ext_op_transform_idx: 1
|
||||
label_file_list:
|
||||
- ./train_data/train_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- RecAug:
|
||||
- MultiLabelEncode:
|
||||
gtc_encode: NRTRLabelEncode
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 48, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label_ctc
|
||||
- label_gtc
|
||||
- length
|
||||
- valid_ratio
|
||||
loader:
|
||||
shuffle: true
|
||||
batch_size_per_card: 64
|
||||
drop_last: true
|
||||
num_workers: 8
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data
|
||||
label_file_list:
|
||||
- ./train_data/val_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- MultiLabelEncode:
|
||||
gtc_encode: NRTRLabelEncode
|
||||
- SVTRRecResizeImg:
|
||||
image_shape: [3, 48, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label_ctc
|
||||
- label_gtc
|
||||
- length
|
||||
- valid_ratio
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 128
|
||||
num_workers: 4
|
|
@ -47,7 +47,7 @@ class RecAug(object):
|
|||
if h >= 20 and w >= 20:
|
||||
img = tia_distort(img, random.randint(3, 6))
|
||||
img = tia_stretch(img, random.randint(3, 6))
|
||||
img = tia_perspective(img)
|
||||
img = tia_perspective(img)
|
||||
|
||||
# bda
|
||||
data['image'] = img
|
||||
|
|
|
@ -24,6 +24,7 @@ def build_backbone(config, model_type):
|
|||
from .det_pp_lcnet import PPLCNet
|
||||
from .rec_lcnetv3 import PPLCNetV3
|
||||
from .rec_hgnet import PPHGNet_small
|
||||
from .rec_vit import ViT
|
||||
support_dict = [
|
||||
"MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet",
|
||||
"PPLCNetV3", "PPHGNet_small"
|
||||
|
@ -55,7 +56,7 @@ def build_backbone(config, model_type):
|
|||
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
|
||||
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
|
||||
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
|
||||
'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small', 'ViTParseQ'
|
||||
'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small', 'ViTParseQ', 'ViT'
|
||||
]
|
||||
elif model_type == 'e2e':
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
|
|
|
@ -0,0 +1,258 @@
|
|||
# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from paddle import ParamAttr
|
||||
from paddle.nn.initializer import KaimingNormal
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
|
||||
|
||||
trunc_normal_ = TruncatedNormal(std=.02)
|
||||
normal_ = Normal
|
||||
zeros_ = Constant(value=0.)
|
||||
ones_ = Constant(value=1.)
|
||||
|
||||
|
||||
def drop_path(x, drop_prob=0., training=False):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
|
||||
"""
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
keep_prob = paddle.to_tensor(1 - drop_prob)
|
||||
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
|
||||
random_tensor = paddle.floor(random_tensor) # binarize
|
||||
output = x.divide(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class DropPath(nn.Layer):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
class Identity(nn.Layer):
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
|
||||
class Mlp(nn.Layer):
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Layer):
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.dim = dim
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
qkv = paddle.reshape(self.qkv(x), (0, -1, 3, self.num_heads, self.dim //
|
||||
self.num_heads)).transpose((2, 0, 3, 1, 4))
|
||||
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
||||
|
||||
attn = (q.matmul(k.transpose((0, 1, 3, 2))))
|
||||
attn = nn.functional.softmax(attn, axis=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, -1, self.dim))
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Layer):
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer='nn.LayerNorm',
|
||||
epsilon=1e-6,
|
||||
prenorm=True):
|
||||
super().__init__()
|
||||
if isinstance(norm_layer, str):
|
||||
self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
|
||||
else:
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.mixer = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop)
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
|
||||
if isinstance(norm_layer, str):
|
||||
self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
|
||||
else:
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.mlp = Mlp(in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop)
|
||||
self.prenorm = prenorm
|
||||
|
||||
def forward(self, x):
|
||||
if self.prenorm:
|
||||
x = self.norm1(x + self.drop_path(self.mixer(x)))
|
||||
x = self.norm2(x + self.drop_path(self.mlp(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.mixer(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class ViT(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
img_size=[32, 128],
|
||||
patch_size=[4,4],
|
||||
in_channels=3,
|
||||
embed_dim=384,
|
||||
depth=12,
|
||||
num_heads=6,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.1,
|
||||
norm_layer='nn.LayerNorm',
|
||||
epsilon=1e-6,
|
||||
act='nn.GELU',
|
||||
prenorm=False,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.out_channels = embed_dim
|
||||
self.prenorm = prenorm
|
||||
self.patch_embed = nn.Conv2D(in_channels, embed_dim, patch_size, patch_size, padding=(0, 0))
|
||||
self.pos_embed = self.create_parameter(
|
||||
shape=[1, 257, embed_dim], default_initializer=zeros_)
|
||||
self.add_parameter("pos_embed", self.pos_embed)
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
dpr = np.linspace(0, drop_path_rate, depth)
|
||||
self.blocks1 = nn.LayerList([
|
||||
Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer=eval(act),
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=epsilon,
|
||||
prenorm=prenorm) for i in range(depth)
|
||||
])
|
||||
if not prenorm:
|
||||
self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
|
||||
|
||||
self.avg_pool = nn.AdaptiveAvgPool2D([1, 25])
|
||||
self.last_conv = nn.Conv2D(
|
||||
in_channels=embed_dim,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias_attr=False)
|
||||
self.hardswish = nn.Hardswish()
|
||||
self.dropout = nn.Dropout(p=0.1, mode="downscale_in_infer")
|
||||
|
||||
trunc_normal_(self.pos_embed)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
zeros_(m.bias)
|
||||
ones_(m.weight)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x).flatten(2).transpose((0, 2, 1))
|
||||
x = x + self.pos_embed[:, 1:, :] #[:, :paddle.shape(x)[1], :]
|
||||
x = self.pos_drop(x)
|
||||
for blk in self.blocks1:
|
||||
x = blk(x)
|
||||
if not self.prenorm:
|
||||
x = self.norm(x)
|
||||
|
||||
x = self.avg_pool(x.transpose([0, 2, 1]).reshape(
|
||||
[0, self.embed_dim, -1, 25]))
|
||||
x = self.last_conv(x)
|
||||
x = self.hardswish(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
|
@ -22,7 +22,7 @@ from paddle import ParamAttr
|
|||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from ppocr.modeling.necks.rnn import Im2Seq, EncoderWithRNN, EncoderWithFC, SequenceEncoder, EncoderWithSVTR
|
||||
from ppocr.modeling.necks.rnn import Im2Seq, EncoderWithRNN, EncoderWithFC, SequenceEncoder, EncoderWithSVTR, trunc_normal_, zeros_
|
||||
from .rec_ctc_head import CTCHead
|
||||
from .rec_sar_head import SARHead
|
||||
from .rec_nrtr_head import Transformer
|
||||
|
@ -41,12 +41,28 @@ class FCTranspose(nn.Layer):
|
|||
else:
|
||||
return self.fc(x.transpose([0, 2, 1]))
|
||||
|
||||
class AddPos(nn.Layer):
|
||||
def __init__(self, dim, w):
|
||||
super().__init__()
|
||||
self.dec_pos_embed = self.create_parameter(
|
||||
shape=[1, w, dim], default_initializer=zeros_)
|
||||
self.add_parameter("dec_pos_embed", self.dec_pos_embed)
|
||||
trunc_normal_(self.dec_pos_embed)
|
||||
|
||||
def forward(self,x):
|
||||
x = x + self.dec_pos_embed[:, :paddle.shape(x)[1], :]
|
||||
return x
|
||||
|
||||
|
||||
class MultiHead(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels_list, **kwargs):
|
||||
super().__init__()
|
||||
self.head_list = kwargs.pop('head_list')
|
||||
|
||||
self.use_pool = kwargs.get('use_pool', False)
|
||||
self.use_pos = kwargs.get('use_pos', False)
|
||||
self.in_channels = in_channels
|
||||
if self.use_pool:
|
||||
self.pool = nn.AvgPool2D(kernel_size=[3, 2], stride=[3, 2], padding=0)
|
||||
self.gtc_head = 'sar'
|
||||
assert len(self.head_list) >= 2
|
||||
for idx, head_name in enumerate(self.head_list):
|
||||
|
@ -61,8 +77,13 @@ class MultiHead(nn.Layer):
|
|||
max_text_length = gtc_args.get('max_text_length', 25)
|
||||
nrtr_dim = gtc_args.get('nrtr_dim', 256)
|
||||
num_decoder_layers = gtc_args.get('num_decoder_layers', 4)
|
||||
self.before_gtc = nn.Sequential(
|
||||
if self.use_pos:
|
||||
self.before_gtc = nn.Sequential(
|
||||
nn.Flatten(2), FCTranspose(in_channels, nrtr_dim), AddPos(nrtr_dim, 80))
|
||||
else:
|
||||
self.before_gtc = nn.Sequential(
|
||||
nn.Flatten(2), FCTranspose(in_channels, nrtr_dim))
|
||||
|
||||
self.gtc_head = Transformer(
|
||||
d_model=nrtr_dim,
|
||||
nhead=nrtr_dim // 32,
|
||||
|
@ -88,7 +109,8 @@ class MultiHead(nn.Layer):
|
|||
'{} is not supported in MultiHead yet'.format(name))
|
||||
|
||||
def forward(self, x, targets=None):
|
||||
|
||||
if self.use_pool:
|
||||
x = self.pool(x.reshape([0, 3, -1, self.in_channels]).transpose([0, 3, 1, 2]))
|
||||
ctc_encoder = self.ctc_encoder(x)
|
||||
ctc_out = self.ctc_head(ctc_encoder, targets)
|
||||
head_out = dict()
|
||||
|
|
|
@ -57,23 +57,35 @@ def _check_image_file(path):
|
|||
return any([path.lower().endswith(e) for e in img_end])
|
||||
|
||||
|
||||
def get_image_file_list(img_file):
|
||||
def get_image_file_list(img_file, infer_list=None):
|
||||
imgs_lists = []
|
||||
if img_file is None or not os.path.exists(img_file):
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
if infer_list and not os.path.exists(infer_list):
|
||||
raise Exception("not found infer list {}".format(infer_list))
|
||||
if infer_list:
|
||||
with open(infer_list, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
image_path = line.strip().split("\t")[0]
|
||||
image_path = os.path.join(img_file, image_path)
|
||||
imgs_lists.append(image_path)
|
||||
else:
|
||||
if img_file is None or not os.path.exists(img_file):
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
|
||||
if os.path.isfile(img_file) and _check_image_file(img_file):
|
||||
imgs_lists.append(img_file)
|
||||
elif os.path.isdir(img_file):
|
||||
for single_file in os.listdir(img_file):
|
||||
file_path = os.path.join(img_file, single_file)
|
||||
if os.path.isfile(file_path) and _check_image_file(file_path):
|
||||
imgs_lists.append(file_path)
|
||||
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'}
|
||||
if os.path.isfile(img_file) and _check_image_file(img_file):
|
||||
imgs_lists.append(img_file)
|
||||
elif os.path.isdir(img_file):
|
||||
for single_file in os.listdir(img_file):
|
||||
file_path = os.path.join(img_file, single_file)
|
||||
if os.path.isfile(file_path) and _check_image_file(file_path):
|
||||
imgs_lists.append(file_path)
|
||||
if len(imgs_lists) == 0:
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
imgs_lists = sorted(imgs_lists)
|
||||
return imgs_lists
|
||||
|
||||
|
||||
def binarize_img(img):
|
||||
if len(img.shape) == 3 and img.shape[2] == 3:
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # conversion to grayscale image
|
||||
|
|
|
@ -118,9 +118,11 @@ def main():
|
|||
os.makedirs(os.path.dirname(save_res_path))
|
||||
|
||||
model.eval()
|
||||
|
||||
|
||||
infer_imgs = config['Global']['infer_img']
|
||||
infer_list = config['Global'].get('infer_list', None)
|
||||
with open(save_res_path, "w") as fout:
|
||||
for file in get_image_file_list(config['Global']['infer_img']):
|
||||
for file in get_image_file_list(infer_imgs, infer_list=infer_list):
|
||||
logger.info("infer_img: {}".format(file))
|
||||
with open(file, 'rb') as f:
|
||||
img = f.read()
|
||||
|
|
Loading…
Reference in New Issue