fix gbk (#2941)
parent b3fcc98610
commit 74a33b7f50
@@ -36,8 +36,12 @@ class Topk(object):
         try:
             class_id_map = {}
-            with open(class_id_map_file, "r") as fin:
-                lines = fin.readlines()
+            try:
+                with open(class_id_map_file, "r", encoding='utf-8') as fin:
+                    lines = fin.readlines()
+            except Exception as e:
+                with open(class_id_map_file, "r", encoding='gbk') as fin:
+                    lines = fin.readlines()
             for line in lines:
                 partition = line.split("\n")[0].partition(self.delimiter)
                 class_id_map[int(partition[0])] = str(partition[-1])
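Note: the hunk above keeps the outer try/except and only swaps the single file read for a two-step attempt, utf-8 first and gbk as a fallback, so a class-id map file saved with the Chinese Windows default encoding no longer breaks Topk. A minimal sketch of the same pattern in isolation (read_label_map, the path and the delimiter are illustrative names, not code from this repo):

# Sketch of the utf-8 -> gbk fallback; names here are examples only.
def read_label_map(path, delimiter=" "):
    try:
        with open(path, "r", encoding="utf-8") as fin:
            lines = fin.readlines()
    except UnicodeDecodeError:
        # Label maps written on Chinese-locale Windows are commonly GBK-encoded.
        with open(path, "r", encoding="gbk") as fin:
            lines = fin.readlines()
    label_map = {}
    for line in lines:
        class_id, _, name = line.rstrip("\n").partition(delimiter)
        label_map[int(class_id)] = name
    return label_map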
@@ -18,19 +18,28 @@ import base64
 import numpy as np
 
 
-def get_image_list(img_file):
+def get_image_list(img_file, infer_list=None):
     imgs_lists = []
-    if img_file is None or not os.path.exists(img_file):
-        raise Exception("not found any img file in {}".format(img_file))
-
-    img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
-    if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
-        imgs_lists.append(img_file)
-    elif os.path.isdir(img_file):
-        for root, dirs, files in os.walk(img_file):
-            for single_file in files:
-                if single_file.split('.')[-1] in img_end:
-                    imgs_lists.append(os.path.join(root, single_file))
+    if infer_list and not os.path.exists(infer_list):
+        raise Exception("not found infer list {}".format(infer_list))
+    if infer_list:
+        with open(infer_list, "r") as f:
+            lines = f.readlines()
+            for line in lines:
+                image_path = line.strip(" ").split()[0]
+                image_path = os.path.join(img_file, image_path)
+                imgs_lists.append(image_path)
+    else:
+        if img_file is None or not os.path.exists(img_file):
+            raise Exception("not found any img file in {}".format(img_file))
+        img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
+        if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
+            imgs_lists.append(img_file)
+        elif os.path.isdir(img_file):
+            for root, dirs, files in os.walk(img_file):
+                for single_file in files:
+                    if single_file.split('.')[-1] in img_end:
+                        imgs_lists.append(os.path.join(root, single_file))
     if len(imgs_lists) == 0:
         raise Exception("not found any img file in {}".format(img_file))
     imgs_lists = sorted(imgs_lists)
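With the new infer_list argument, get_image_list can take image paths from a list file instead of walking img_file: only the first whitespace-separated token of each line is used, and it is joined onto img_file. A hypothetical usage sketch (the paths and list contents below are made up):

# Contents of a hypothetical val_list.txt (anything after the path is ignored):
#   val/0001.jpg 65
#   val/0002.jpg 3
image_list = get_image_list("dataset/ILSVRC2012",
                            infer_list="dataset/ILSVRC2012/val_list.txt")

# Without infer_list, the previous behaviour (single image file or directory walk) is kept:
image_list = get_image_list("deploy/images")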
@@ -442,7 +442,9 @@ class Engine(object):
         results = []
         total_trainer = dist.get_world_size()
         local_rank = dist.get_rank()
-        image_list = get_image_list(self.config["Infer"]["infer_imgs"])
+        infer_imgs = self.config["Infer"]["infer_imgs"]
+        infer_list = self.config["Infer"].get("infer_list", None)
+        image_list = get_image_list(infer_imgs, infer_list=infer_list)
         # data split
         image_list = image_list[local_rank::total_trainer]
 
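The existing `image_list[local_rank::total_trainer]` slice then splits the (now possibly list-driven) images round-robin across trainers. A small self-contained illustration, assuming two trainers and four images (names and values invented):

# Illustration of the round-robin data split; values are examples only.
image_list = ["a.jpg", "b.jpg", "c.jpg", "d.jpg"]
total_trainer = 2
for local_rank in range(total_trainer):
    print(local_rank, image_list[local_rank::total_trainer])
# 0 ['a.jpg', 'c.jpg']
# 1 ['b.jpg', 'd.jpg']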
@@ -450,6 +452,7 @@ class Engine(object):
         self.model.eval()
         batch_data = []
         image_file_list = []
+        save_path = self.config["Infer"].get("save_dir", None)
         for idx, image_file in enumerate(image_list):
             with open(image_file, 'rb') as f:
                 x = f.read()
@@ -473,11 +476,11 @@ class Engine(object):
                     out = out["output"]
 
                 result = self.postprocess_func(out, image_file_list)
-                logger.info(result)
+                if not save_path:
+                    logger.info(result)
                 results.extend(result)
                 batch_data.clear()
                 image_file_list.clear()
-        save_path = self.config["Infer"].get("save_dir", None)
         if save_path:
             save_predict_result(save_path, results)
         return results
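Together, the two Engine hunks read save_dir once before the batch loop and reuse it both to silence the per-batch logger.info call and to trigger the final dump. A condensed sketch of that control flow (config, logger and the batch iterable are stand-ins for the engine's real objects):

# Condensed sketch of the save/log control flow; all names are placeholders.
save_path = config["Infer"].get("save_dir", None)
results = []
for batch_result in batched_results:      # stand-in for the per-batch postprocess output
    if not save_path:
        logger.info(batch_result)         # log to console only when not saving to a file
    results.extend(batch_result)
if save_path:
    save_predict_result(save_path, results)   # one JSON dump after all batches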
@@ -24,12 +24,9 @@ def save_predict_result(save_path, result):
     elif os.path.splitext(save_path)[-1] == '.json':
         save_path = save_path
     else:
-        logger.warning(
-            f"{save_path} is invalid input path, only files in json format are supported."
-        )
+        raise Exception(f"{save_path} is invalid input path, only files in json format are supported.")
 
     if os.path.exists(save_path):
-        logger.warning(
-            f"The file {save_path} will be overwritten."
-        )
+        logger.warning(f"The file {save_path} will be overwritten.")
     with open(save_path, 'w', encoding='utf-8') as f:
         json.dump(result, f)
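save_predict_result now raises on a non-.json target instead of merely warning, and the overwrite warning is collapsed to one line. A hypothetical call site (the paths and result payload below are invented for illustration):

# Hypothetical usage; the result structure and paths are examples only.
results = [{"file_name": "demo.jpg", "class_ids": [8], "scores": [0.97]}]
save_predict_result("./output/infer_result.json", results)  # accepted: .json suffix
save_predict_result("./output/infer_result.txt", results)   # now raises instead of only warning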