Add PyUpgrade pre-commit hook

pull/1178/head
gaotongxiao 2022-05-12 11:19:18 +08:00
parent 593d7529a3
commit 536dfdd4bd
55 changed files with 117 additions and 112 deletions

2
.gitignore vendored
View File

@ -107,7 +107,7 @@ venv.bak/
# cython generated cpp
!data/dict
data/*
data
.vscode
.idea

View File

@ -43,6 +43,11 @@ repos:
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
- repo: https://github.com/asottile/pyupgrade
rev: v2.32.1
hooks:
- id: pyupgrade
args: ["--py36-plus"]
- repo: https://github.com/open-mmlab/pre-commit-hooks
rev: v0.2.0 # Use the ref you want to point at
hooks:

View File

@ -27,7 +27,7 @@ author = 'OpenMMLab'
# The full version, including alpha/beta/rc tags
version_file = '../../mmocr/version.py'
with open(version_file, 'r') as f:
with open(version_file) as f:
exec(compile(f.read(), version_file, 'exec'))
__version__ = locals()['__version__']
release = __version__

View File

@ -21,7 +21,7 @@ files = sorted(glob.glob('*_models.md'))
stats = []
for f in files:
with open(f, 'r') as content_file:
with open(f) as content_file:
content = content_file.read()
# Remove the blackquote notation from the paper link under the title
@ -39,9 +39,8 @@ for f in files:
exclude_expr = ''.join(f'(?!{s})' for s in exclude_papertype)
expr = rf'<!-- \[{exclude_expr}([A-Z]+?)\] -->'\
r'\s*\n.*?\btitle\s*=\s*{(.*?)}'
papers = set(
(papertype, titlecase.titlecase(paper.lower().strip()))
for (papertype, paper) in re.findall(expr, content, re.DOTALL))
papers = {(papertype, titlecase.titlecase(paper.lower().strip()))
for (papertype, paper) in re.findall(expr, content, re.DOTALL)}
print(papers)
# paper links
revcontent = '\n'.join(list(reversed(content.splitlines())))
@ -56,13 +55,17 @@ for f in files:
paperlist = '\n'.join(
sorted(f' - [{t}] {paperlinks[x]}' for t, x in papers))
# count configs
configs = set(x.lower().strip()
for x in re.findall(r'https.*configs/.*\.py', content))
configs = {
x.lower().strip()
for x in re.findall(r'https.*configs/.*\.py', content)
}
# count ckpts
ckpts = set(x.lower().strip()
for x in re.findall(r'https://download.*\.pth', content)
if 'mmocr' in x)
ckpts = {
x.lower().strip()
for x in re.findall(r'https://download.*\.pth', content)
if 'mmocr' in x
}
statsmsg = f"""
## [{title}]({f})

View File

@ -27,7 +27,7 @@ author = 'OpenMMLab'
# The full version, including alpha/beta/rc tags
version_file = '../../mmocr/version.py'
with open(version_file, 'r') as f:
with open(version_file) as f:
exec(compile(f.read(), version_file, 'exec'))
__version__ = locals()['__version__']
release = __version__

View File

@ -21,7 +21,7 @@ files = sorted(glob.glob('*_models.md'))
stats = []
for f in files:
with open(f, 'r') as content_file:
with open(f) as content_file:
content = content_file.read()
# Remove the blackquote notation from the paper link under the title
@ -39,9 +39,8 @@ for f in files:
exclude_expr = ''.join(f'(?!{s})' for s in exclude_papertype)
expr = rf'<!-- \[{exclude_expr}([A-Z]+?)\] -->'\
r'\s*\n.*?\btitle\s*=\s*{(.*?)}'
papers = set(
(papertype, titlecase.titlecase(paper.lower().strip()))
for (papertype, paper) in re.findall(expr, content, re.DOTALL))
papers = {(papertype, titlecase.titlecase(paper.lower().strip()))
for (papertype, paper) in re.findall(expr, content, re.DOTALL)}
print(papers)
# paper links
revcontent = '\n'.join(list(reversed(content.splitlines())))
@ -56,13 +55,17 @@ for f in files:
paperlist = '\n'.join(
sorted(f' - [{t}] {paperlinks[x]}' for t, x in papers))
# count configs
configs = set(x.lower().strip()
for x in re.findall(r'https.*configs/.*\.py', content))
configs = {
x.lower().strip()
for x in re.findall(r'https.*configs/.*\.py', content)
}
# count ckpts
ckpts = set(x.lower().strip()
for x in re.findall(r'https://download.*\.pth', content)
if 'mmocr' in x)
ckpts = {
x.lower().strip()
for x in re.findall(r'https://download.*\.pth', content)
if 'mmocr' in x
}
statsmsg = f"""
## [{title}]({f})

View File

@ -118,12 +118,12 @@ def eval_ocr_metric(pred_texts, gt_texts, metric='acc'):
'word_acc', 'word_acc_ignore_case', 'word_acc_ignore_case_symbol',
'char_recall', 'char_precision', 'one_minus_ned'
]
metric = set([metric]) if isinstance(metric, str) else set(metric)
metric = {metric} if isinstance(metric, str) else set(metric)
supported_metrics = set([
supported_metrics = {
'word_acc', 'word_acc_ignore_case', 'word_acc_ignore_case_symbol',
'char_recall', 'char_precision', 'one_minus_ned'
])
}
assert metric.issubset(supported_metrics)
match_res = count_matches(pred_texts, gt_texts)
@ -160,6 +160,6 @@ def eval_ocr_metric(pred_texts, gt_texts, metric='acc'):
eval_res['1-N.E.D'] = 1.0 - match_res['ned']
for key, value in eval_res.items():
eval_res[key] = float('{:.4f}'.format(value))
eval_res[key] = float(f'{value:.4f}')
return eval_res

View File

@ -406,14 +406,14 @@ def imshow_node(img,
True,
color=(255, 255, 0),
thickness=1)
x_min = int(min([point[0] for point in new_box]))
y_min = int(min([point[1] for point in new_box]))
x_min = int(min(point[0] for point in new_box))
y_min = int(min(point[1] for point in new_box))
# text
pred_label = str(node_pred_label[i])
if pred_label in idx_to_cls:
pred_label = idx_to_cls[pred_label]
pred_score = '{:.2f}'.format(node_pred_score[i])
pred_score = f'{node_pred_score[i]:.2f}'
text = pred_label + '(' + pred_score + ')'
texts.append(text)
@ -635,7 +635,7 @@ def is_contain_chinese(check_str):
Return True if contains Chinese, else False.
"""
for ch in check_str:
if u'\u4e00' <= ch <= u'\u9fff':
if '\u4e00' <= ch <= '\u9fff':
return True
return False
@ -800,7 +800,7 @@ def draw_edge_result(img, result, edge_thresh=0.5, keynode_thresh=0.5):
(pos_current[0] + bbox_x1 + dist_key_to_value - 5) / 2.)
score_pos_y = int((pos_current[1] + bbox_y1 + 10) / 2.)
# draw edge score
cv2.putText(pred_edge_img, '{:.2f}'.format(pair_score),
cv2.putText(pred_edge_img, f'{pair_score:.2f}',
(score_pos_x, score_pos_y), cv2.FONT_HERSHEY_COMPLEX, 0.4,
score_color)
# draw text for value

View File

@ -208,7 +208,7 @@ class KIEDataset(BaseDataset):
def pad_text_indices(self, text_inds):
"""Pad text index to same length."""
max_len = max([len(text_ind) for text_ind in text_inds])
max_len = max(len(text_ind) for text_ind in text_inds)
padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
for idx, text_ind in enumerate(text_inds):
padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)

View File

@ -138,7 +138,7 @@ class LoadImageFromNdarray(LoadImageFromFile):
@TRANSFORMS.register_module()
class LoadImageFromLMDB(object):
class LoadImageFromLMDB:
"""Load an image from lmdb file.
Similar with :obj:'LoadImageFromFile', but the image read from
@ -169,8 +169,8 @@ class LoadImageFromLMDB(object):
imgbuf = txn.get(img_key.encode('utf-8'))
try:
img = mmcv.imfrombytes(imgbuf, flag=self.color_type)
except IOError:
print('Corrupted image for {}'.format(img_key))
except OSError:
print(f'Corrupted image for {img_key}')
return None
results['filename'] = img_key

View File

@ -79,7 +79,7 @@ class UniformConcatDataset(ConcatDataset):
self.show_mean_scores = show_mean_scores
if show_mean_scores is True or show_mean_scores == 'auto' and len(
self.datasets) > 1:
if len(set([type(ds) for ds in self.datasets])) != 1:
if len({type(ds) for ds in self.datasets}) != 1:
raise NotImplementedError(
'To compute mean evaluation scores, all datasets'
'must have the same type')

View File

@ -34,8 +34,7 @@ class LmdbAnnFileBackend:
with env.begin(write=False) as txn:
try:
self.total_number = int(
txn.get('num-samples'.encode('utf-8')).decode(
self.encoding))
txn.get(b'num-samples').decode(self.encoding))
except AttributeError:
warnings.warn(
'DeprecationWarning: The lmdb dataset generated with '
@ -45,8 +44,7 @@ class LmdbAnnFileBackend:
'convert-text-recognition-dataset-to-lmdb-format for '
'details.')
self.total_number = int(
txn.get('total_number'.encode('utf-8')).decode(
self.encoding))
txn.get(b'total_number').decode(self.encoding))
self.deprecated_format = True
# The lmdb file may contain only the label, or it may contain both
# the label and the image, so we use image_key here for probing.

View File

@ -477,7 +477,7 @@ class UNet(BaseModule):
act_cfg=act_cfg,
dcn=None,
plugins=None))
self.encoder.append((nn.Sequential(*enc_conv_block)))
self.encoder.append(nn.Sequential(*enc_conv_block))
in_channels = base_channels * 2**i
def forward(self, x):

View File

@ -53,7 +53,7 @@ class SDMGRHead(BaseModule):
node_nums.append(text.size(0))
char_nums.append((text > 0).sum(-1))
max_num = max([char_num.max() for char_num in char_nums])
max_num = max(char_num.max() for char_num in char_nums)
all_nodes = torch.cat([
torch.cat(
[text,

View File

@ -43,7 +43,7 @@ class GCN(nn.Module):
"""
def __init__(self, feat_len):
super(GCN, self).__init__()
super().__init__()
self.bn0 = nn.BatchNorm1d(feat_len, affine=False).float()
self.conv1 = GraphConv(feat_len, 512)
self.conv2 = GraphConv(512, 256)

View File

@ -142,10 +142,9 @@ class LocalGraphs:
assert isinstance(knn_batch, list)
assert isinstance(sorted_dist_ind_batch, list)
num_max_nodes = max([
num_max_nodes = max(
len(pivot_local_graph) for pivot_local_graphs in local_graph_batch
for pivot_local_graph in pivot_local_graphs
])
for pivot_local_graph in pivot_local_graphs)
local_graphs_node_feat = []
adjacent_matrices = []

View File

@ -279,9 +279,8 @@ class ProposalLocalGraphs:
pivot_local_graphs.append(pivot_local_graph)
pivot_knns.append(pivot_knn)
num_max_nodes = max([
len(pivot_local_graph) for pivot_local_graph in pivot_local_graphs
])
num_max_nodes = max(
len(pivot_local_graph) for pivot_local_graph in pivot_local_graphs)
local_graphs_node_feat = []
adjacent_matrices = []

View File

@ -256,10 +256,11 @@ def connected_components(nodes, score_dict, link_thr):
node_queue = [node]
while node_queue:
node = node_queue.pop(0)
neighbors = set([
neighbor for neighbor in node.links if
neighbors = {
neighbor
for neighbor in node.links if
score_dict[tuple(sorted([node.ind, neighbor.ind]))] >= link_thr
])
}
neighbors.difference_update(cluster)
nodes.difference_update(neighbors)
cluster.update(neighbors)

View File

@ -140,6 +140,6 @@ class ResNet31OCR(BaseModule):
outs.append(x)
if self.out_indices is not None:
return tuple([outs[i] for i in self.out_indices])
return tuple(outs[i] for i in self.out_indices)
return x

View File

@ -37,28 +37,27 @@ class VeryDeepVgg(BaseModule):
def conv_relu(i, batch_normalization=False):
n_in = input_channels if i == 0 else nm[i - 1]
n_out = nm[i]
cnn.add_module('conv{0}'.format(i),
cnn.add_module(f'conv{i}',
nn.Conv2d(n_in, n_out, ks[i], ss[i], ps[i]))
if batch_normalization:
cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(n_out))
cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(n_out))
if leaky_relu:
cnn.add_module('relu{0}'.format(i),
nn.LeakyReLU(0.2, inplace=True))
cnn.add_module(f'relu{i}', nn.LeakyReLU(0.2, inplace=True))
else:
cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
cnn.add_module(f'relu{i}', nn.ReLU(True))
conv_relu(0)
cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64
cnn.add_module(f'pooling{0}', nn.MaxPool2d(2, 2)) # 64x16x64
conv_relu(1)
cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32
cnn.add_module(f'pooling{1}', nn.MaxPool2d(2, 2)) # 128x8x32
conv_relu(2, True)
conv_relu(3)
cnn.add_module('pooling{0}'.format(2),
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
cnn.add_module(f'pooling{2}', nn.MaxPool2d((2, 2), (2, 1),
(0, 1))) # 256x4x16
conv_relu(4, True)
conv_relu(5)
cnn.add_module('pooling{0}'.format(3),
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
cnn.add_module(f'pooling{3}', nn.MaxPool2d((2, 2), (2, 1),
(0, 1))) # 512x2x16
conv_relu(6, True) # 512x1x16
self.cnn = cnn

View File

@ -21,7 +21,7 @@ def clones(module, N):
class Embeddings(nn.Module):
def __init__(self, d_model, vocab):
super(Embeddings, self).__init__()
super().__init__()
self.lut = nn.Embedding(vocab, d_model)
self.d_model = d_model
@ -70,7 +70,7 @@ class MasterDecoder(BaseDecoder):
max_seq_len=30,
init_cfg=None,
):
super(MasterDecoder, self).__init__(init_cfg=init_cfg)
super().__init__(init_cfg=init_cfg)
operation_order = ('norm', 'self_attn', 'norm', 'cross_attn', 'norm',
'ffn')

View File

@ -31,7 +31,7 @@ class BasicBlock(nn.Module):
downsample=None,
use_conv1x1=False,
plugins=None):
super(BasicBlock, self).__init__()
super().__init__()
if use_conv1x1:
self.conv1 = conv1x1(inplanes, planes)

View File

@ -40,7 +40,7 @@ class ABILoss(nn.Module):
def _flatten(self, logits, target_lens):
flatten_logits = torch.cat(
[s[:target_lens[i]] for i, s in enumerate((logits))])
[s[:target_lens[i]] for i, s in enumerate(logits)])
return flatten_logits
def _ce_loss(self, logits, targets):

View File

@ -15,7 +15,7 @@ class Maxpool2d(nn.Module):
"""
def __init__(self, kernel_size, stride, padding=0, **kwargs):
super(Maxpool2d, self).__init__()
super().__init__()
self.model = nn.MaxPool2d(kernel_size, stride, padding)
def forward(self, x):
@ -53,7 +53,7 @@ class GCAModule(nn.Module):
scale_attn=False,
fusion_type='channel_add',
**kwargs):
super(GCAModule, self).__init__()
super().__init__()
assert pooling_type in ['avg', 'att']
assert fusion_type in ['channel_add', 'channel_mul', 'channel_concat']

View File

@ -32,7 +32,7 @@ def list_from_file(filename, encoding='utf-8'):
list[str]: A list of strings.
"""
item_list = []
with open(filename, 'r', encoding=encoding) as f:
with open(filename, encoding=encoding) as f:
for line in f:
item_list.append(line.rstrip('\n\r'))
return item_list

View File

@ -49,9 +49,10 @@ def generate_sample_dataloader(cfg, curr_dir, img_prefix='', ann_file=''):
dataset = DATASETS.build(cfg.data.test)
loader_cfg = {
**dict((k, cfg.data[k]) for k in [
'workers_per_gpu', 'samples_per_gpu'
] if k in cfg.data)
**{
k: cfg.data[k]
for k in ['workers_per_gpu', 'samples_per_gpu'] if k in cfg.data
}
}
test_loader_cfg = {
**loader_cfg,
@ -144,9 +145,10 @@ def gene_sdmgr_model_dataloader(cfg, dirname, curr_dir, empty_img=False):
dataset = DATASETS.build(cfg.data.test)
loader_cfg = {
**dict((k, cfg.data[k]) for k in [
'workers_per_gpu', 'samples_per_gpu'
] if k in cfg.data)
**{
k: cfg.data[k]
for k in ['workers_per_gpu', 'samples_per_gpu'] if k in cfg.data
}
}
test_loader_cfg = {
**loader_cfg,

View File

@ -47,7 +47,7 @@ def test_get_gt_mask():
def test_eval_hmean():
metrics = set(['hmean-iou', 'hmean-ic13'])
metrics = {'hmean-iou', 'hmean-ic13'}
results = [{
'boundary_result': [[50, 70, 80, 70, 80, 100, 50, 100, 1],
[120, 140, 200, 140, 200, 200, 120, 200, 1]]

View File

@ -63,7 +63,7 @@ def test_list_to_file():
list_to_file(filename, lines)
lines2 = [
line.rstrip('\r\n')
for line in open(filename, 'r', encoding='utf-8').readlines()
for line in open(filename, encoding='utf-8').readlines()
]
lines = list(map(str, lines))
assert len(lines) == len(lines2)
@ -74,7 +74,7 @@ def test_list_to_file():
list_to_file(filename, [json.dumps(line) for line in lines])
lines2 = [
json.loads(line.rstrip('\r\n'))['text']
for line in open(filename, 'r', encoding='utf-8').readlines()
for line in open(filename, encoding='utf-8').readlines()
][0]
lines = list(lines[0]['text'])

View File

@ -77,7 +77,7 @@ def add_mim_extension():
def get_version():
with open(version_file, 'r') as f:
with open(version_file) as f:
exec(compile(f.read(), version_file, 'exec'))
import sys
@ -137,12 +137,11 @@ def parse_requirements(fname='requirements.txt', with_version=True):
yield info
def parse_require_file(fpath):
with open(fpath, 'r') as f:
with open(fpath) as f:
for line in f.readlines():
line = line.strip()
if line and not line.startswith('#'):
for info in parse_line(line):
yield info
yield from parse_line(line)
def gen_packages_items():
if exists(require_fpath):

View File

@ -160,7 +160,7 @@ def load_json_logs(json_logs):
# value of sub dict is a list of corresponding values of all iterations
log_dicts = [dict() for _ in json_logs]
for json_log, log_dict in zip(json_logs, log_dicts):
with open(json_log, 'r') as log_file:
with open(json_log) as log_file:
for line in log_file:
log = json.loads(line.strip())
# skip lines without `epoch` field

View File

@ -106,7 +106,7 @@ def load_txt_info(gt_file, img_info):
Returns:
img_info (dict): The dict of the img and annotation information
"""
with open(gt_file, 'r', encoding='latin1') as f:
with open(gt_file, encoding='latin1') as f:
anno_info = []
for line in f:
line = line.strip('\n')

View File

@ -101,7 +101,7 @@ def load_txt_info(gt_file, img_info):
img_info (dict): The dict of the img and annotation information
"""
with open(gt_file, 'r') as f:
with open(gt_file) as f:
anno_info = []
annotations = f.readlines()
for ann in annotations:

View File

@ -92,7 +92,7 @@ def collect_hiertext_info(root_path, level, split, print_every=1000):
raise Exception(
f'{annotation_path} not exists, please check and try again.')
annotation = json.load(open(annotation_path, 'r'))['annotations']
annotation = json.load(open(annotation_path))['annotations']
img_infos = []
for i, img_annos in enumerate(annotation):
if i > 0 and i % print_every == 0:

View File

@ -119,7 +119,7 @@ def load_txt_info(gt_file, img_info):
img_info (dict): The dict of the img and annotation information
"""
anno_info = []
with open(gt_file, 'r') as f:
with open(gt_file) as f:
lines = f.readlines()
for line in lines:
xmin, ymin, xmax, ymax = line.split(',')[0:4]

View File

@ -112,7 +112,7 @@ def load_txt_info(gt_file, img_info, separator):
img_info (dict): The dict of the img and annotation information
"""
anno_info = []
with open(gt_file, 'r') as f:
with open(gt_file) as f:
lines = f.readlines()
for line in lines:
xmin, ymin, xmax, ymax = line.split(separator)[0:4]

View File

@ -136,7 +136,7 @@ def load_txt_info(gt_file, img_info):
"""
anno_info = []
with open(gt_file, 'r') as f:
with open(gt_file) as f:
lines = f.readlines()
for line in lines:
points = line.split(',')[0:8]

View File

@ -120,7 +120,7 @@ def load_txt_info(gt_file, img_info):
"""
anno_info = []
with open(gt_file, 'r', encoding='utf-8-sig') as f:
with open(gt_file, encoding='utf-8-sig') as f:
lines = f.readlines()
for line in lines:
points = line.split(',')[0:8]

View File

@ -107,7 +107,7 @@ def load_txt_info(gt_file, img_info):
img_info (list): The dict of the img and annotation information
"""
with open(gt_file, 'r', encoding='unicode_escape') as f:
with open(gt_file, encoding='unicode_escape') as f:
anno_info = []
for ann in f.readlines():

View File

@ -150,7 +150,7 @@ def convert_annotations(root_path, gt_name, lmdb_name):
img_json['annotations'].append(anno_info)
string = json.dumps(img_json)
txn.put(str(img_id).encode('utf8'), string.encode('utf8'))
key = 'total_number'.encode('utf8')
key = b'total_number'
value = str(img_num).encode('utf8')
txn.put(key, value)

View File

@ -41,7 +41,7 @@ def collect_files(img_dir, gt_dir):
imgs_list = sorted(imgs_list)
ann_list = sorted(
[osp.join(gt_dir, gt_file) for gt_file in os.listdir(gt_dir)])
osp.join(gt_dir, gt_file) for gt_file in os.listdir(gt_dir))
files = list(zip(imgs_list, ann_list))
assert len(files), f'No images found in {img_dir}'
@ -230,7 +230,7 @@ def get_contours_txt(gt_path):
contours = []
words = []
with open(gt_path, 'r') as f:
with open(gt_path) as f:
tmp_line = ''
for idx, line in enumerate(f):
line = line.strip()

View File

@ -109,7 +109,7 @@ def load_txt_info(gt_file, img_info):
img_info (dict): The dict of the img and annotation information
"""
with open(gt_file, 'r', encoding='utf-8') as f:
with open(gt_file, encoding='utf-8') as f:
anno_info = []
for line in f:
line = line.strip('\n')

View File

@ -108,7 +108,7 @@ def load_txt_info(gt_file, img_info):
Returns:
img_info (dict): The dict of the img and annotation information
"""
with open(gt_file, 'r', encoding='latin1') as f:
with open(gt_file, encoding='latin1') as f:
anno_info = []
for line in f:
line = line.strip('\n')

View File

@ -100,7 +100,7 @@ def load_txt_info(gt_file, img_info):
img_info (dict): The dict of the img and annotation information
"""
with open(gt_file, 'r') as f:
with open(gt_file) as f:
anno_info = []
annotations = f.readlines()
for ann in annotations:

View File

@ -170,7 +170,7 @@ def convert_hiertext(
raise Exception(
f'{annotation_path} not exists, please check and try again.')
annotation = json.load(open(annotation_path, 'r'))['annotations']
annotation = json.load(open(annotation_path))['annotations']
# outputs
dst_label_file = osp.join(root_path, f'{split}_label.{format}')
dst_image_root = osp.join(root_path, 'crops', split)

View File

@ -36,7 +36,6 @@ def convert_annotations(root_path, split, format):
with open(
osp.join(root_path, 'annotations',
f'Challenge1_{split}_Task3_GT.txt'),
'r',
encoding='"utf-8-sig') as f:
annos = f.readlines()
dst_image_root = osp.join(root_path, split.lower())

View File

@ -36,7 +36,6 @@ def convert_annotations(root_path, split, format):
with open(
osp.join(root_path, 'annotations',
f'Challenge2_{split}_Task3_GT.txt'),
'r',
encoding='"utf-8-sig') as f:
annos = f.readlines()
dst_image_root = osp.join(root_path, split.lower())

View File

@ -36,7 +36,6 @@ def convert_annotations(root_path, split, format):
lines = []
with open(
osp.join(root_path, f'{split}_label.txt'),
'r',
encoding='"utf-8-sig') as f:
annos = f.readlines()
for anno in annos:

View File

@ -138,7 +138,7 @@ def load_txt_info(gt_file, img_info):
"""
anno_info = []
with open(gt_file, 'r') as f:
with open(gt_file) as f:
lines = f.readlines()
for line in lines:
points = line.split(',')[0:8]

View File

@ -122,7 +122,7 @@ def load_txt_info(gt_file, img_info):
"""
anno_info = []
with open(gt_file, 'r', encoding='utf-8-sig') as f:
with open(gt_file, encoding='utf-8-sig') as f:
lines = f.readlines()
for line in lines:
points = line.split(',')[0:8]

View File

@ -112,7 +112,7 @@ def load_txt_info(gt_file, img_info):
img_info (list): The dict of the img and annotation information
"""
with open(gt_file, 'r', encoding='unicode_escape') as f:
with open(gt_file, encoding='unicode_escape') as f:
anno_info = []
for ann in f.readlines():
# skip invalid annotation line

View File

@ -41,7 +41,7 @@ def collect_files(img_dir, gt_dir):
imgs_list = sorted(imgs_list)
ann_list = sorted(
[osp.join(gt_dir, gt_file) for gt_file in os.listdir(gt_dir)])
osp.join(gt_dir, gt_file) for gt_file in os.listdir(gt_dir))
files = [(img_file, gt_file)
for (img_file, gt_file) in zip(imgs_list, ann_list)]
@ -218,7 +218,7 @@ def get_contours_txt(gt_path):
contours = []
words = []
with open(gt_path, 'r') as f:
with open(gt_path) as f:
tmp_line = ''
for idx, line in enumerate(f):
line = line.strip()

View File

@ -111,7 +111,7 @@ def load_txt_info(gt_file, img_info):
img_info (dict): The dict of the img and annotation information
"""
with open(gt_file, 'r', encoding='utf-8') as f:
with open(gt_file, encoding='utf-8') as f:
anno_info = []
for line in f:
line = line.strip('\n')

View File

@ -161,7 +161,7 @@ def onnx2tensorrt(onnx_file: str,
atol=1e-4):
same_diff = 'different'
break
print('The outputs are {} between TensorRT and ONNX'.format(same_diff))
print(f'The outputs are {same_diff} between TensorRT and ONNX')
if show:
onnx_img = onnx_model.show_result(
@ -259,7 +259,7 @@ if __name__ == '__main__':
assert osp.exists(args.model_config), 'Config {} not found.'.format(
args.model_config)
assert osp.exists(args.onnx_file), \
'ONNX model {} not found.'.format(args.onnx_file)
f'ONNX model {args.onnx_file} not found.'
assert args.workspace_size >= 0, 'Workspace size less than 0.'
for max_value, min_value in zip(args.max_shape, args.min_shape):
assert max_value >= min_value, \

View File

@ -259,7 +259,7 @@ def pytorch2onnx(model: nn.Module,
atol=1e-4):
same_diff = 'different'
break
print('The outputs are {} between PyTorch and ONNX'.format(same_diff))
print(f'The outputs are {same_diff} between PyTorch and ONNX')
if show:
onnx_img = onnx_model.show_result(

View File

@ -26,7 +26,7 @@ def process_checkpoint(in_file, out_file):
checkpoint['meta'] = {'CLASSES': 0}
torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
sha = subprocess.check_output(['sha256sum', out_file]).decode()
final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8])
final_file = out_file.rstrip('.pth') + f'-{sha[:8]}.pth'
subprocess.Popen(['mv', out_file, final_file])