parent
8312dc3677
commit
f4de50226a
|
@ -36,11 +36,15 @@ class Topk(object):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
class_id_map = {}
|
class_id_map = {}
|
||||||
with open(class_id_map_file, "r") as fin:
|
try:
|
||||||
lines = fin.readlines()
|
with open(class_id_map_file, "r", encoding='utf-8') as fin:
|
||||||
for line in lines:
|
lines = fin.readlines()
|
||||||
partition = line.split("\n")[0].partition(self.delimiter)
|
except Exception as e:
|
||||||
class_id_map[int(partition[0])] = str(partition[-1])
|
with open(class_id_map_file, "r", encoding='gbk') as fin:
|
||||||
|
lines = fin.readlines()
|
||||||
|
for line in lines:
|
||||||
|
partition = line.split("\n")[0].partition(self.delimiter)
|
||||||
|
class_id_map[int(partition[0])] = str(partition[-1])
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
print(ex)
|
print(ex)
|
||||||
class_id_map = None
|
class_id_map = None
|
||||||
|
|
|
@ -18,19 +18,28 @@ import base64
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def get_image_list(img_file):
|
def get_image_list(img_file, infer_list=None):
|
||||||
imgs_lists = []
|
imgs_lists = []
|
||||||
if img_file is None or not os.path.exists(img_file):
|
if infer_list and not os.path.exists(infer_list):
|
||||||
raise Exception("not found any img file in {}".format(img_file))
|
raise Exception("not found infer list {}".format(infer_list))
|
||||||
|
if infer_list:
|
||||||
img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
|
with open(infer_list, "r") as f:
|
||||||
if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
|
lines = f.readlines()
|
||||||
imgs_lists.append(img_file)
|
for line in lines:
|
||||||
elif os.path.isdir(img_file):
|
image_path = line.strip(" ").split()[0]
|
||||||
for root, dirs, files in os.walk(img_file):
|
image_path = os.path.join(img_file, image_path)
|
||||||
for single_file in files:
|
imgs_lists.append(image_path)
|
||||||
if single_file.split('.')[-1] in img_end:
|
else:
|
||||||
imgs_lists.append(os.path.join(root, single_file))
|
if img_file is None or not os.path.exists(img_file):
|
||||||
|
raise Exception("not found any img file in {}".format(img_file))
|
||||||
|
img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
|
||||||
|
if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
|
||||||
|
imgs_lists.append(img_file)
|
||||||
|
elif os.path.isdir(img_file):
|
||||||
|
for root, dirs, files in os.walk(img_file):
|
||||||
|
for single_file in files:
|
||||||
|
if single_file.split('.')[-1] in img_end:
|
||||||
|
imgs_lists.append(os.path.join(root, single_file))
|
||||||
if len(imgs_lists) == 0:
|
if len(imgs_lists) == 0:
|
||||||
raise Exception("not found any img file in {}".format(img_file))
|
raise Exception("not found any img file in {}".format(img_file))
|
||||||
imgs_lists = sorted(imgs_lists)
|
imgs_lists = sorted(imgs_lists)
|
||||||
|
|
|
@ -442,7 +442,9 @@ class Engine(object):
|
||||||
results = []
|
results = []
|
||||||
total_trainer = dist.get_world_size()
|
total_trainer = dist.get_world_size()
|
||||||
local_rank = dist.get_rank()
|
local_rank = dist.get_rank()
|
||||||
image_list = get_image_list(self.config["Infer"]["infer_imgs"])
|
infer_imgs = self.config["Infer"]["infer_imgs"]
|
||||||
|
infer_list = self.config["Infer"].get("infer_list", None)
|
||||||
|
image_list = get_image_list(infer_imgs, infer_list=infer_list)
|
||||||
# data split
|
# data split
|
||||||
image_list = image_list[local_rank::total_trainer]
|
image_list = image_list[local_rank::total_trainer]
|
||||||
|
|
||||||
|
@ -450,6 +452,7 @@ class Engine(object):
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
batch_data = []
|
batch_data = []
|
||||||
image_file_list = []
|
image_file_list = []
|
||||||
|
save_path = self.config["Infer"].get("save_dir", None)
|
||||||
for idx, image_file in enumerate(image_list):
|
for idx, image_file in enumerate(image_list):
|
||||||
with open(image_file, 'rb') as f:
|
with open(image_file, 'rb') as f:
|
||||||
x = f.read()
|
x = f.read()
|
||||||
|
@ -473,11 +476,11 @@ class Engine(object):
|
||||||
out = out["output"]
|
out = out["output"]
|
||||||
|
|
||||||
result = self.postprocess_func(out, image_file_list)
|
result = self.postprocess_func(out, image_file_list)
|
||||||
logger.info(result)
|
if not save_path:
|
||||||
|
logger.info(result)
|
||||||
results.extend(result)
|
results.extend(result)
|
||||||
batch_data.clear()
|
batch_data.clear()
|
||||||
image_file_list.clear()
|
image_file_list.clear()
|
||||||
save_path = self.config["Infer"].get("save_dir", None)
|
|
||||||
if save_path:
|
if save_path:
|
||||||
save_predict_result(save_path, results)
|
save_predict_result(save_path, results)
|
||||||
return results
|
return results
|
||||||
|
|
|
@ -24,12 +24,9 @@ def save_predict_result(save_path, result):
|
||||||
elif os.path.splitext(save_path)[-1] == '.json':
|
elif os.path.splitext(save_path)[-1] == '.json':
|
||||||
save_path = save_path
|
save_path = save_path
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
raise Exception(f"{save_path} is invalid input path, only files in json format are supported.")
|
||||||
f"{save_path} is invalid input path, only files in json format are supported."
|
|
||||||
)
|
|
||||||
if os.path.exists(save_path):
|
if os.path.exists(save_path):
|
||||||
logger.warning(
|
logger.warning(f"The file {save_path} will be overwritten.")
|
||||||
f"The file {save_path} will be overwritten."
|
|
||||||
)
|
|
||||||
with open(save_path, 'w', encoding='utf-8') as f:
|
with open(save_path, 'w', encoding='utf-8') as f:
|
||||||
json.dump(result, f)
|
json.dump(result, f)
|
||||||
|
|
Loading…
Reference in New Issue