mirror of https://github.com/open-mmlab/mmocr.git
Add PyUpgrade pre-commit hook
parent
593d7529a3
commit
536dfdd4bd
|
@ -107,7 +107,7 @@ venv.bak/
|
|||
|
||||
# cython generated cpp
|
||||
!data/dict
|
||||
data/*
|
||||
data
|
||||
.vscode
|
||||
.idea
|
||||
|
||||
|
|
|
@ -43,6 +43,11 @@ repos:
|
|||
hooks:
|
||||
- id: docformatter
|
||||
args: ["--in-place", "--wrap-descriptions", "79"]
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v2.32.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: ["--py36-plus"]
|
||||
- repo: https://github.com/open-mmlab/pre-commit-hooks
|
||||
rev: v0.2.0 # Use the ref you want to point at
|
||||
hooks:
|
||||
|
|
|
@ -27,7 +27,7 @@ author = 'OpenMMLab'
|
|||
|
||||
# The full version, including alpha/beta/rc tags
|
||||
version_file = '../../mmocr/version.py'
|
||||
with open(version_file, 'r') as f:
|
||||
with open(version_file) as f:
|
||||
exec(compile(f.read(), version_file, 'exec'))
|
||||
__version__ = locals()['__version__']
|
||||
release = __version__
|
||||
|
|
|
@ -21,7 +21,7 @@ files = sorted(glob.glob('*_models.md'))
|
|||
stats = []
|
||||
|
||||
for f in files:
|
||||
with open(f, 'r') as content_file:
|
||||
with open(f) as content_file:
|
||||
content = content_file.read()
|
||||
|
||||
# Remove the blackquote notation from the paper link under the title
|
||||
|
@ -39,9 +39,8 @@ for f in files:
|
|||
exclude_expr = ''.join(f'(?!{s})' for s in exclude_papertype)
|
||||
expr = rf'<!-- \[{exclude_expr}([A-Z]+?)\] -->'\
|
||||
r'\s*\n.*?\btitle\s*=\s*{(.*?)}'
|
||||
papers = set(
|
||||
(papertype, titlecase.titlecase(paper.lower().strip()))
|
||||
for (papertype, paper) in re.findall(expr, content, re.DOTALL))
|
||||
papers = {(papertype, titlecase.titlecase(paper.lower().strip()))
|
||||
for (papertype, paper) in re.findall(expr, content, re.DOTALL)}
|
||||
print(papers)
|
||||
# paper links
|
||||
revcontent = '\n'.join(list(reversed(content.splitlines())))
|
||||
|
@ -56,13 +55,17 @@ for f in files:
|
|||
paperlist = '\n'.join(
|
||||
sorted(f' - [{t}] {paperlinks[x]}' for t, x in papers))
|
||||
# count configs
|
||||
configs = set(x.lower().strip()
|
||||
for x in re.findall(r'https.*configs/.*\.py', content))
|
||||
configs = {
|
||||
x.lower().strip()
|
||||
for x in re.findall(r'https.*configs/.*\.py', content)
|
||||
}
|
||||
|
||||
# count ckpts
|
||||
ckpts = set(x.lower().strip()
|
||||
for x in re.findall(r'https://download.*\.pth', content)
|
||||
if 'mmocr' in x)
|
||||
ckpts = {
|
||||
x.lower().strip()
|
||||
for x in re.findall(r'https://download.*\.pth', content)
|
||||
if 'mmocr' in x
|
||||
}
|
||||
|
||||
statsmsg = f"""
|
||||
## [{title}]({f})
|
||||
|
|
|
@ -27,7 +27,7 @@ author = 'OpenMMLab'
|
|||
|
||||
# The full version, including alpha/beta/rc tags
|
||||
version_file = '../../mmocr/version.py'
|
||||
with open(version_file, 'r') as f:
|
||||
with open(version_file) as f:
|
||||
exec(compile(f.read(), version_file, 'exec'))
|
||||
__version__ = locals()['__version__']
|
||||
release = __version__
|
||||
|
|
|
@ -21,7 +21,7 @@ files = sorted(glob.glob('*_models.md'))
|
|||
stats = []
|
||||
|
||||
for f in files:
|
||||
with open(f, 'r') as content_file:
|
||||
with open(f) as content_file:
|
||||
content = content_file.read()
|
||||
|
||||
# Remove the blackquote notation from the paper link under the title
|
||||
|
@ -39,9 +39,8 @@ for f in files:
|
|||
exclude_expr = ''.join(f'(?!{s})' for s in exclude_papertype)
|
||||
expr = rf'<!-- \[{exclude_expr}([A-Z]+?)\] -->'\
|
||||
r'\s*\n.*?\btitle\s*=\s*{(.*?)}'
|
||||
papers = set(
|
||||
(papertype, titlecase.titlecase(paper.lower().strip()))
|
||||
for (papertype, paper) in re.findall(expr, content, re.DOTALL))
|
||||
papers = {(papertype, titlecase.titlecase(paper.lower().strip()))
|
||||
for (papertype, paper) in re.findall(expr, content, re.DOTALL)}
|
||||
print(papers)
|
||||
# paper links
|
||||
revcontent = '\n'.join(list(reversed(content.splitlines())))
|
||||
|
@ -56,13 +55,17 @@ for f in files:
|
|||
paperlist = '\n'.join(
|
||||
sorted(f' - [{t}] {paperlinks[x]}' for t, x in papers))
|
||||
# count configs
|
||||
configs = set(x.lower().strip()
|
||||
for x in re.findall(r'https.*configs/.*\.py', content))
|
||||
configs = {
|
||||
x.lower().strip()
|
||||
for x in re.findall(r'https.*configs/.*\.py', content)
|
||||
}
|
||||
|
||||
# count ckpts
|
||||
ckpts = set(x.lower().strip()
|
||||
for x in re.findall(r'https://download.*\.pth', content)
|
||||
if 'mmocr' in x)
|
||||
ckpts = {
|
||||
x.lower().strip()
|
||||
for x in re.findall(r'https://download.*\.pth', content)
|
||||
if 'mmocr' in x
|
||||
}
|
||||
|
||||
statsmsg = f"""
|
||||
## [{title}]({f})
|
||||
|
|
|
@ -118,12 +118,12 @@ def eval_ocr_metric(pred_texts, gt_texts, metric='acc'):
|
|||
'word_acc', 'word_acc_ignore_case', 'word_acc_ignore_case_symbol',
|
||||
'char_recall', 'char_precision', 'one_minus_ned'
|
||||
]
|
||||
metric = set([metric]) if isinstance(metric, str) else set(metric)
|
||||
metric = {metric} if isinstance(metric, str) else set(metric)
|
||||
|
||||
supported_metrics = set([
|
||||
supported_metrics = {
|
||||
'word_acc', 'word_acc_ignore_case', 'word_acc_ignore_case_symbol',
|
||||
'char_recall', 'char_precision', 'one_minus_ned'
|
||||
])
|
||||
}
|
||||
assert metric.issubset(supported_metrics)
|
||||
|
||||
match_res = count_matches(pred_texts, gt_texts)
|
||||
|
@ -160,6 +160,6 @@ def eval_ocr_metric(pred_texts, gt_texts, metric='acc'):
|
|||
eval_res['1-N.E.D'] = 1.0 - match_res['ned']
|
||||
|
||||
for key, value in eval_res.items():
|
||||
eval_res[key] = float('{:.4f}'.format(value))
|
||||
eval_res[key] = float(f'{value:.4f}')
|
||||
|
||||
return eval_res
|
||||
|
|
|
@ -406,14 +406,14 @@ def imshow_node(img,
|
|||
True,
|
||||
color=(255, 255, 0),
|
||||
thickness=1)
|
||||
x_min = int(min([point[0] for point in new_box]))
|
||||
y_min = int(min([point[1] for point in new_box]))
|
||||
x_min = int(min(point[0] for point in new_box))
|
||||
y_min = int(min(point[1] for point in new_box))
|
||||
|
||||
# text
|
||||
pred_label = str(node_pred_label[i])
|
||||
if pred_label in idx_to_cls:
|
||||
pred_label = idx_to_cls[pred_label]
|
||||
pred_score = '{:.2f}'.format(node_pred_score[i])
|
||||
pred_score = f'{node_pred_score[i]:.2f}'
|
||||
text = pred_label + '(' + pred_score + ')'
|
||||
texts.append(text)
|
||||
|
||||
|
@ -635,7 +635,7 @@ def is_contain_chinese(check_str):
|
|||
Return True if contains Chinese, else False.
|
||||
"""
|
||||
for ch in check_str:
|
||||
if u'\u4e00' <= ch <= u'\u9fff':
|
||||
if '\u4e00' <= ch <= '\u9fff':
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -800,7 +800,7 @@ def draw_edge_result(img, result, edge_thresh=0.5, keynode_thresh=0.5):
|
|||
(pos_current[0] + bbox_x1 + dist_key_to_value - 5) / 2.)
|
||||
score_pos_y = int((pos_current[1] + bbox_y1 + 10) / 2.)
|
||||
# draw edge score
|
||||
cv2.putText(pred_edge_img, '{:.2f}'.format(pair_score),
|
||||
cv2.putText(pred_edge_img, f'{pair_score:.2f}',
|
||||
(score_pos_x, score_pos_y), cv2.FONT_HERSHEY_COMPLEX, 0.4,
|
||||
score_color)
|
||||
# draw text for value
|
||||
|
|
|
@ -208,7 +208,7 @@ class KIEDataset(BaseDataset):
|
|||
|
||||
def pad_text_indices(self, text_inds):
|
||||
"""Pad text index to same length."""
|
||||
max_len = max([len(text_ind) for text_ind in text_inds])
|
||||
max_len = max(len(text_ind) for text_ind in text_inds)
|
||||
padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
|
||||
for idx, text_ind in enumerate(text_inds):
|
||||
padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
|
||||
|
|
|
@ -138,7 +138,7 @@ class LoadImageFromNdarray(LoadImageFromFile):
|
|||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadImageFromLMDB(object):
|
||||
class LoadImageFromLMDB:
|
||||
"""Load an image from lmdb file.
|
||||
|
||||
Similar with :obj:'LoadImageFromFile', but the image read from
|
||||
|
@ -169,8 +169,8 @@ class LoadImageFromLMDB(object):
|
|||
imgbuf = txn.get(img_key.encode('utf-8'))
|
||||
try:
|
||||
img = mmcv.imfrombytes(imgbuf, flag=self.color_type)
|
||||
except IOError:
|
||||
print('Corrupted image for {}'.format(img_key))
|
||||
except OSError:
|
||||
print(f'Corrupted image for {img_key}')
|
||||
return None
|
||||
|
||||
results['filename'] = img_key
|
||||
|
|
|
@ -79,7 +79,7 @@ class UniformConcatDataset(ConcatDataset):
|
|||
self.show_mean_scores = show_mean_scores
|
||||
if show_mean_scores is True or show_mean_scores == 'auto' and len(
|
||||
self.datasets) > 1:
|
||||
if len(set([type(ds) for ds in self.datasets])) != 1:
|
||||
if len({type(ds) for ds in self.datasets}) != 1:
|
||||
raise NotImplementedError(
|
||||
'To compute mean evaluation scores, all datasets'
|
||||
'must have the same type')
|
||||
|
|
|
@ -34,8 +34,7 @@ class LmdbAnnFileBackend:
|
|||
with env.begin(write=False) as txn:
|
||||
try:
|
||||
self.total_number = int(
|
||||
txn.get('num-samples'.encode('utf-8')).decode(
|
||||
self.encoding))
|
||||
txn.get(b'num-samples').decode(self.encoding))
|
||||
except AttributeError:
|
||||
warnings.warn(
|
||||
'DeprecationWarning: The lmdb dataset generated with '
|
||||
|
@ -45,8 +44,7 @@ class LmdbAnnFileBackend:
|
|||
'convert-text-recognition-dataset-to-lmdb-format for '
|
||||
'details.')
|
||||
self.total_number = int(
|
||||
txn.get('total_number'.encode('utf-8')).decode(
|
||||
self.encoding))
|
||||
txn.get(b'total_number').decode(self.encoding))
|
||||
self.deprecated_format = True
|
||||
# The lmdb file may contain only the label, or it may contain both
|
||||
# the label and the image, so we use image_key here for probing.
|
||||
|
|
|
@ -477,7 +477,7 @@ class UNet(BaseModule):
|
|||
act_cfg=act_cfg,
|
||||
dcn=None,
|
||||
plugins=None))
|
||||
self.encoder.append((nn.Sequential(*enc_conv_block)))
|
||||
self.encoder.append(nn.Sequential(*enc_conv_block))
|
||||
in_channels = base_channels * 2**i
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -53,7 +53,7 @@ class SDMGRHead(BaseModule):
|
|||
node_nums.append(text.size(0))
|
||||
char_nums.append((text > 0).sum(-1))
|
||||
|
||||
max_num = max([char_num.max() for char_num in char_nums])
|
||||
max_num = max(char_num.max() for char_num in char_nums)
|
||||
all_nodes = torch.cat([
|
||||
torch.cat(
|
||||
[text,
|
||||
|
|
|
@ -43,7 +43,7 @@ class GCN(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, feat_len):
|
||||
super(GCN, self).__init__()
|
||||
super().__init__()
|
||||
self.bn0 = nn.BatchNorm1d(feat_len, affine=False).float()
|
||||
self.conv1 = GraphConv(feat_len, 512)
|
||||
self.conv2 = GraphConv(512, 256)
|
||||
|
|
|
@ -142,10 +142,9 @@ class LocalGraphs:
|
|||
assert isinstance(knn_batch, list)
|
||||
assert isinstance(sorted_dist_ind_batch, list)
|
||||
|
||||
num_max_nodes = max([
|
||||
num_max_nodes = max(
|
||||
len(pivot_local_graph) for pivot_local_graphs in local_graph_batch
|
||||
for pivot_local_graph in pivot_local_graphs
|
||||
])
|
||||
for pivot_local_graph in pivot_local_graphs)
|
||||
|
||||
local_graphs_node_feat = []
|
||||
adjacent_matrices = []
|
||||
|
|
|
@ -279,9 +279,8 @@ class ProposalLocalGraphs:
|
|||
pivot_local_graphs.append(pivot_local_graph)
|
||||
pivot_knns.append(pivot_knn)
|
||||
|
||||
num_max_nodes = max([
|
||||
len(pivot_local_graph) for pivot_local_graph in pivot_local_graphs
|
||||
])
|
||||
num_max_nodes = max(
|
||||
len(pivot_local_graph) for pivot_local_graph in pivot_local_graphs)
|
||||
|
||||
local_graphs_node_feat = []
|
||||
adjacent_matrices = []
|
||||
|
|
|
@ -256,10 +256,11 @@ def connected_components(nodes, score_dict, link_thr):
|
|||
node_queue = [node]
|
||||
while node_queue:
|
||||
node = node_queue.pop(0)
|
||||
neighbors = set([
|
||||
neighbor for neighbor in node.links if
|
||||
neighbors = {
|
||||
neighbor
|
||||
for neighbor in node.links if
|
||||
score_dict[tuple(sorted([node.ind, neighbor.ind]))] >= link_thr
|
||||
])
|
||||
}
|
||||
neighbors.difference_update(cluster)
|
||||
nodes.difference_update(neighbors)
|
||||
cluster.update(neighbors)
|
||||
|
|
|
@ -140,6 +140,6 @@ class ResNet31OCR(BaseModule):
|
|||
outs.append(x)
|
||||
|
||||
if self.out_indices is not None:
|
||||
return tuple([outs[i] for i in self.out_indices])
|
||||
return tuple(outs[i] for i in self.out_indices)
|
||||
|
||||
return x
|
||||
|
|
|
@ -37,28 +37,27 @@ class VeryDeepVgg(BaseModule):
|
|||
def conv_relu(i, batch_normalization=False):
|
||||
n_in = input_channels if i == 0 else nm[i - 1]
|
||||
n_out = nm[i]
|
||||
cnn.add_module('conv{0}'.format(i),
|
||||
cnn.add_module(f'conv{i}',
|
||||
nn.Conv2d(n_in, n_out, ks[i], ss[i], ps[i]))
|
||||
if batch_normalization:
|
||||
cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(n_out))
|
||||
cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(n_out))
|
||||
if leaky_relu:
|
||||
cnn.add_module('relu{0}'.format(i),
|
||||
nn.LeakyReLU(0.2, inplace=True))
|
||||
cnn.add_module(f'relu{i}', nn.LeakyReLU(0.2, inplace=True))
|
||||
else:
|
||||
cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
|
||||
cnn.add_module(f'relu{i}', nn.ReLU(True))
|
||||
|
||||
conv_relu(0)
|
||||
cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64
|
||||
cnn.add_module(f'pooling{0}', nn.MaxPool2d(2, 2)) # 64x16x64
|
||||
conv_relu(1)
|
||||
cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32
|
||||
cnn.add_module(f'pooling{1}', nn.MaxPool2d(2, 2)) # 128x8x32
|
||||
conv_relu(2, True)
|
||||
conv_relu(3)
|
||||
cnn.add_module('pooling{0}'.format(2),
|
||||
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
|
||||
cnn.add_module(f'pooling{2}', nn.MaxPool2d((2, 2), (2, 1),
|
||||
(0, 1))) # 256x4x16
|
||||
conv_relu(4, True)
|
||||
conv_relu(5)
|
||||
cnn.add_module('pooling{0}'.format(3),
|
||||
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
|
||||
cnn.add_module(f'pooling{3}', nn.MaxPool2d((2, 2), (2, 1),
|
||||
(0, 1))) # 512x2x16
|
||||
conv_relu(6, True) # 512x1x16
|
||||
|
||||
self.cnn = cnn
|
||||
|
|
|
@ -21,7 +21,7 @@ def clones(module, N):
|
|||
class Embeddings(nn.Module):
|
||||
|
||||
def __init__(self, d_model, vocab):
|
||||
super(Embeddings, self).__init__()
|
||||
super().__init__()
|
||||
self.lut = nn.Embedding(vocab, d_model)
|
||||
self.d_model = d_model
|
||||
|
||||
|
@ -70,7 +70,7 @@ class MasterDecoder(BaseDecoder):
|
|||
max_seq_len=30,
|
||||
init_cfg=None,
|
||||
):
|
||||
super(MasterDecoder, self).__init__(init_cfg=init_cfg)
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
operation_order = ('norm', 'self_attn', 'norm', 'cross_attn', 'norm',
|
||||
'ffn')
|
||||
|
|
|
@ -31,7 +31,7 @@ class BasicBlock(nn.Module):
|
|||
downsample=None,
|
||||
use_conv1x1=False,
|
||||
plugins=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
if use_conv1x1:
|
||||
self.conv1 = conv1x1(inplanes, planes)
|
||||
|
|
|
@ -40,7 +40,7 @@ class ABILoss(nn.Module):
|
|||
|
||||
def _flatten(self, logits, target_lens):
|
||||
flatten_logits = torch.cat(
|
||||
[s[:target_lens[i]] for i, s in enumerate((logits))])
|
||||
[s[:target_lens[i]] for i, s in enumerate(logits)])
|
||||
return flatten_logits
|
||||
|
||||
def _ce_loss(self, logits, targets):
|
||||
|
|
|
@ -15,7 +15,7 @@ class Maxpool2d(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, kernel_size, stride, padding=0, **kwargs):
|
||||
super(Maxpool2d, self).__init__()
|
||||
super().__init__()
|
||||
self.model = nn.MaxPool2d(kernel_size, stride, padding)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -53,7 +53,7 @@ class GCAModule(nn.Module):
|
|||
scale_attn=False,
|
||||
fusion_type='channel_add',
|
||||
**kwargs):
|
||||
super(GCAModule, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
assert pooling_type in ['avg', 'att']
|
||||
assert fusion_type in ['channel_add', 'channel_mul', 'channel_concat']
|
||||
|
|
|
@ -32,7 +32,7 @@ def list_from_file(filename, encoding='utf-8'):
|
|||
list[str]: A list of strings.
|
||||
"""
|
||||
item_list = []
|
||||
with open(filename, 'r', encoding=encoding) as f:
|
||||
with open(filename, encoding=encoding) as f:
|
||||
for line in f:
|
||||
item_list.append(line.rstrip('\n\r'))
|
||||
return item_list
|
||||
|
|
|
@ -49,9 +49,10 @@ def generate_sample_dataloader(cfg, curr_dir, img_prefix='', ann_file=''):
|
|||
dataset = DATASETS.build(cfg.data.test)
|
||||
|
||||
loader_cfg = {
|
||||
**dict((k, cfg.data[k]) for k in [
|
||||
'workers_per_gpu', 'samples_per_gpu'
|
||||
] if k in cfg.data)
|
||||
**{
|
||||
k: cfg.data[k]
|
||||
for k in ['workers_per_gpu', 'samples_per_gpu'] if k in cfg.data
|
||||
}
|
||||
}
|
||||
test_loader_cfg = {
|
||||
**loader_cfg,
|
||||
|
@ -144,9 +145,10 @@ def gene_sdmgr_model_dataloader(cfg, dirname, curr_dir, empty_img=False):
|
|||
dataset = DATASETS.build(cfg.data.test)
|
||||
|
||||
loader_cfg = {
|
||||
**dict((k, cfg.data[k]) for k in [
|
||||
'workers_per_gpu', 'samples_per_gpu'
|
||||
] if k in cfg.data)
|
||||
**{
|
||||
k: cfg.data[k]
|
||||
for k in ['workers_per_gpu', 'samples_per_gpu'] if k in cfg.data
|
||||
}
|
||||
}
|
||||
test_loader_cfg = {
|
||||
**loader_cfg,
|
||||
|
|
|
@ -47,7 +47,7 @@ def test_get_gt_mask():
|
|||
|
||||
|
||||
def test_eval_hmean():
|
||||
metrics = set(['hmean-iou', 'hmean-ic13'])
|
||||
metrics = {'hmean-iou', 'hmean-ic13'}
|
||||
results = [{
|
||||
'boundary_result': [[50, 70, 80, 70, 80, 100, 50, 100, 1],
|
||||
[120, 140, 200, 140, 200, 200, 120, 200, 1]]
|
||||
|
|
|
@ -63,7 +63,7 @@ def test_list_to_file():
|
|||
list_to_file(filename, lines)
|
||||
lines2 = [
|
||||
line.rstrip('\r\n')
|
||||
for line in open(filename, 'r', encoding='utf-8').readlines()
|
||||
for line in open(filename, encoding='utf-8').readlines()
|
||||
]
|
||||
lines = list(map(str, lines))
|
||||
assert len(lines) == len(lines2)
|
||||
|
@ -74,7 +74,7 @@ def test_list_to_file():
|
|||
list_to_file(filename, [json.dumps(line) for line in lines])
|
||||
lines2 = [
|
||||
json.loads(line.rstrip('\r\n'))['text']
|
||||
for line in open(filename, 'r', encoding='utf-8').readlines()
|
||||
for line in open(filename, encoding='utf-8').readlines()
|
||||
][0]
|
||||
|
||||
lines = list(lines[0]['text'])
|
||||
|
|
7
setup.py
7
setup.py
|
@ -77,7 +77,7 @@ def add_mim_extension():
|
|||
|
||||
|
||||
def get_version():
|
||||
with open(version_file, 'r') as f:
|
||||
with open(version_file) as f:
|
||||
exec(compile(f.read(), version_file, 'exec'))
|
||||
import sys
|
||||
|
||||
|
@ -137,12 +137,11 @@ def parse_requirements(fname='requirements.txt', with_version=True):
|
|||
yield info
|
||||
|
||||
def parse_require_file(fpath):
|
||||
with open(fpath, 'r') as f:
|
||||
with open(fpath) as f:
|
||||
for line in f.readlines():
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
for info in parse_line(line):
|
||||
yield info
|
||||
yield from parse_line(line)
|
||||
|
||||
def gen_packages_items():
|
||||
if exists(require_fpath):
|
||||
|
|
|
@ -160,7 +160,7 @@ def load_json_logs(json_logs):
|
|||
# value of sub dict is a list of corresponding values of all iterations
|
||||
log_dicts = [dict() for _ in json_logs]
|
||||
for json_log, log_dict in zip(json_logs, log_dicts):
|
||||
with open(json_log, 'r') as log_file:
|
||||
with open(json_log) as log_file:
|
||||
for line in log_file:
|
||||
log = json.loads(line.strip())
|
||||
# skip lines without `epoch` field
|
||||
|
|
|
@ -106,7 +106,7 @@ def load_txt_info(gt_file, img_info):
|
|||
Returns:
|
||||
img_info (dict): The dict of the img and annotation information
|
||||
"""
|
||||
with open(gt_file, 'r', encoding='latin1') as f:
|
||||
with open(gt_file, encoding='latin1') as f:
|
||||
anno_info = []
|
||||
for line in f:
|
||||
line = line.strip('\n')
|
||||
|
|
|
@ -101,7 +101,7 @@ def load_txt_info(gt_file, img_info):
|
|||
img_info (dict): The dict of the img and annotation information
|
||||
"""
|
||||
|
||||
with open(gt_file, 'r') as f:
|
||||
with open(gt_file) as f:
|
||||
anno_info = []
|
||||
annotations = f.readlines()
|
||||
for ann in annotations:
|
||||
|
|
|
@ -92,7 +92,7 @@ def collect_hiertext_info(root_path, level, split, print_every=1000):
|
|||
raise Exception(
|
||||
f'{annotation_path} not exists, please check and try again.')
|
||||
|
||||
annotation = json.load(open(annotation_path, 'r'))['annotations']
|
||||
annotation = json.load(open(annotation_path))['annotations']
|
||||
img_infos = []
|
||||
for i, img_annos in enumerate(annotation):
|
||||
if i > 0 and i % print_every == 0:
|
||||
|
|
|
@ -119,7 +119,7 @@ def load_txt_info(gt_file, img_info):
|
|||
img_info (dict): The dict of the img and annotation information
|
||||
"""
|
||||
anno_info = []
|
||||
with open(gt_file, 'r') as f:
|
||||
with open(gt_file) as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
xmin, ymin, xmax, ymax = line.split(',')[0:4]
|
||||
|
|
|
@ -112,7 +112,7 @@ def load_txt_info(gt_file, img_info, separator):
|
|||
img_info (dict): The dict of the img and annotation information
|
||||
"""
|
||||
anno_info = []
|
||||
with open(gt_file, 'r') as f:
|
||||
with open(gt_file) as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
xmin, ymin, xmax, ymax = line.split(separator)[0:4]
|
||||
|
|
|
@ -136,7 +136,7 @@ def load_txt_info(gt_file, img_info):
|
|||
"""
|
||||
|
||||
anno_info = []
|
||||
with open(gt_file, 'r') as f:
|
||||
with open(gt_file) as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
points = line.split(',')[0:8]
|
||||
|
|
|
@ -120,7 +120,7 @@ def load_txt_info(gt_file, img_info):
|
|||
"""
|
||||
|
||||
anno_info = []
|
||||
with open(gt_file, 'r', encoding='utf-8-sig') as f:
|
||||
with open(gt_file, encoding='utf-8-sig') as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
points = line.split(',')[0:8]
|
||||
|
|
|
@ -107,7 +107,7 @@ def load_txt_info(gt_file, img_info):
|
|||
img_info (list): The dict of the img and annotation information
|
||||
"""
|
||||
|
||||
with open(gt_file, 'r', encoding='unicode_escape') as f:
|
||||
with open(gt_file, encoding='unicode_escape') as f:
|
||||
anno_info = []
|
||||
for ann in f.readlines():
|
||||
|
||||
|
|
|
@ -150,7 +150,7 @@ def convert_annotations(root_path, gt_name, lmdb_name):
|
|||
img_json['annotations'].append(anno_info)
|
||||
string = json.dumps(img_json)
|
||||
txn.put(str(img_id).encode('utf8'), string.encode('utf8'))
|
||||
key = 'total_number'.encode('utf8')
|
||||
key = b'total_number'
|
||||
value = str(img_num).encode('utf8')
|
||||
txn.put(key, value)
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ def collect_files(img_dir, gt_dir):
|
|||
|
||||
imgs_list = sorted(imgs_list)
|
||||
ann_list = sorted(
|
||||
[osp.join(gt_dir, gt_file) for gt_file in os.listdir(gt_dir)])
|
||||
osp.join(gt_dir, gt_file) for gt_file in os.listdir(gt_dir))
|
||||
|
||||
files = list(zip(imgs_list, ann_list))
|
||||
assert len(files), f'No images found in {img_dir}'
|
||||
|
@ -230,7 +230,7 @@ def get_contours_txt(gt_path):
|
|||
contours = []
|
||||
words = []
|
||||
|
||||
with open(gt_path, 'r') as f:
|
||||
with open(gt_path) as f:
|
||||
tmp_line = ''
|
||||
for idx, line in enumerate(f):
|
||||
line = line.strip()
|
||||
|
|
|
@ -109,7 +109,7 @@ def load_txt_info(gt_file, img_info):
|
|||
img_info (dict): The dict of the img and annotation information
|
||||
"""
|
||||
|
||||
with open(gt_file, 'r', encoding='utf-8') as f:
|
||||
with open(gt_file, encoding='utf-8') as f:
|
||||
anno_info = []
|
||||
for line in f:
|
||||
line = line.strip('\n')
|
||||
|
|
|
@ -108,7 +108,7 @@ def load_txt_info(gt_file, img_info):
|
|||
Returns:
|
||||
img_info (dict): The dict of the img and annotation information
|
||||
"""
|
||||
with open(gt_file, 'r', encoding='latin1') as f:
|
||||
with open(gt_file, encoding='latin1') as f:
|
||||
anno_info = []
|
||||
for line in f:
|
||||
line = line.strip('\n')
|
||||
|
|
|
@ -100,7 +100,7 @@ def load_txt_info(gt_file, img_info):
|
|||
img_info (dict): The dict of the img and annotation information
|
||||
"""
|
||||
|
||||
with open(gt_file, 'r') as f:
|
||||
with open(gt_file) as f:
|
||||
anno_info = []
|
||||
annotations = f.readlines()
|
||||
for ann in annotations:
|
||||
|
|
|
@ -170,7 +170,7 @@ def convert_hiertext(
|
|||
raise Exception(
|
||||
f'{annotation_path} not exists, please check and try again.')
|
||||
|
||||
annotation = json.load(open(annotation_path, 'r'))['annotations']
|
||||
annotation = json.load(open(annotation_path))['annotations']
|
||||
# outputs
|
||||
dst_label_file = osp.join(root_path, f'{split}_label.{format}')
|
||||
dst_image_root = osp.join(root_path, 'crops', split)
|
||||
|
|
|
@ -36,7 +36,6 @@ def convert_annotations(root_path, split, format):
|
|||
with open(
|
||||
osp.join(root_path, 'annotations',
|
||||
f'Challenge1_{split}_Task3_GT.txt'),
|
||||
'r',
|
||||
encoding='"utf-8-sig') as f:
|
||||
annos = f.readlines()
|
||||
dst_image_root = osp.join(root_path, split.lower())
|
||||
|
|
|
@ -36,7 +36,6 @@ def convert_annotations(root_path, split, format):
|
|||
with open(
|
||||
osp.join(root_path, 'annotations',
|
||||
f'Challenge2_{split}_Task3_GT.txt'),
|
||||
'r',
|
||||
encoding='"utf-8-sig') as f:
|
||||
annos = f.readlines()
|
||||
dst_image_root = osp.join(root_path, split.lower())
|
||||
|
|
|
@ -36,7 +36,6 @@ def convert_annotations(root_path, split, format):
|
|||
lines = []
|
||||
with open(
|
||||
osp.join(root_path, f'{split}_label.txt'),
|
||||
'r',
|
||||
encoding='"utf-8-sig') as f:
|
||||
annos = f.readlines()
|
||||
for anno in annos:
|
||||
|
|
|
@ -138,7 +138,7 @@ def load_txt_info(gt_file, img_info):
|
|||
"""
|
||||
|
||||
anno_info = []
|
||||
with open(gt_file, 'r') as f:
|
||||
with open(gt_file) as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
points = line.split(',')[0:8]
|
||||
|
|
|
@ -122,7 +122,7 @@ def load_txt_info(gt_file, img_info):
|
|||
"""
|
||||
|
||||
anno_info = []
|
||||
with open(gt_file, 'r', encoding='utf-8-sig') as f:
|
||||
with open(gt_file, encoding='utf-8-sig') as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
points = line.split(',')[0:8]
|
||||
|
|
|
@ -112,7 +112,7 @@ def load_txt_info(gt_file, img_info):
|
|||
img_info (list): The dict of the img and annotation information
|
||||
"""
|
||||
|
||||
with open(gt_file, 'r', encoding='unicode_escape') as f:
|
||||
with open(gt_file, encoding='unicode_escape') as f:
|
||||
anno_info = []
|
||||
for ann in f.readlines():
|
||||
# skip invalid annotation line
|
||||
|
|
|
@ -41,7 +41,7 @@ def collect_files(img_dir, gt_dir):
|
|||
|
||||
imgs_list = sorted(imgs_list)
|
||||
ann_list = sorted(
|
||||
[osp.join(gt_dir, gt_file) for gt_file in os.listdir(gt_dir)])
|
||||
osp.join(gt_dir, gt_file) for gt_file in os.listdir(gt_dir))
|
||||
|
||||
files = [(img_file, gt_file)
|
||||
for (img_file, gt_file) in zip(imgs_list, ann_list)]
|
||||
|
@ -218,7 +218,7 @@ def get_contours_txt(gt_path):
|
|||
contours = []
|
||||
words = []
|
||||
|
||||
with open(gt_path, 'r') as f:
|
||||
with open(gt_path) as f:
|
||||
tmp_line = ''
|
||||
for idx, line in enumerate(f):
|
||||
line = line.strip()
|
||||
|
|
|
@ -111,7 +111,7 @@ def load_txt_info(gt_file, img_info):
|
|||
img_info (dict): The dict of the img and annotation information
|
||||
"""
|
||||
|
||||
with open(gt_file, 'r', encoding='utf-8') as f:
|
||||
with open(gt_file, encoding='utf-8') as f:
|
||||
anno_info = []
|
||||
for line in f:
|
||||
line = line.strip('\n')
|
||||
|
|
|
@ -161,7 +161,7 @@ def onnx2tensorrt(onnx_file: str,
|
|||
atol=1e-4):
|
||||
same_diff = 'different'
|
||||
break
|
||||
print('The outputs are {} between TensorRT and ONNX'.format(same_diff))
|
||||
print(f'The outputs are {same_diff} between TensorRT and ONNX')
|
||||
|
||||
if show:
|
||||
onnx_img = onnx_model.show_result(
|
||||
|
@ -259,7 +259,7 @@ if __name__ == '__main__':
|
|||
assert osp.exists(args.model_config), 'Config {} not found.'.format(
|
||||
args.model_config)
|
||||
assert osp.exists(args.onnx_file), \
|
||||
'ONNX model {} not found.'.format(args.onnx_file)
|
||||
f'ONNX model {args.onnx_file} not found.'
|
||||
assert args.workspace_size >= 0, 'Workspace size less than 0.'
|
||||
for max_value, min_value in zip(args.max_shape, args.min_shape):
|
||||
assert max_value >= min_value, \
|
||||
|
|
|
@ -259,7 +259,7 @@ def pytorch2onnx(model: nn.Module,
|
|||
atol=1e-4):
|
||||
same_diff = 'different'
|
||||
break
|
||||
print('The outputs are {} between PyTorch and ONNX'.format(same_diff))
|
||||
print(f'The outputs are {same_diff} between PyTorch and ONNX')
|
||||
|
||||
if show:
|
||||
onnx_img = onnx_model.show_result(
|
||||
|
|
|
@ -26,7 +26,7 @@ def process_checkpoint(in_file, out_file):
|
|||
checkpoint['meta'] = {'CLASSES': 0}
|
||||
torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
|
||||
sha = subprocess.check_output(['sha256sum', out_file]).decode()
|
||||
final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8])
|
||||
final_file = out_file.rstrip('.pth') + f'-{sha[:8]}.pth'
|
||||
subprocess.Popen(['mv', out_file, final_file])
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue