fix rec post and visualizer
parent
3579f5a612
commit
a67cdaa1ee
|
@ -42,7 +42,7 @@ def split_datafile(data_file, image_root, delimiter="\t"):
|
|||
for i, line in enumerate(lines):
|
||||
line = line.strip().split(delimiter)
|
||||
image_file = os.path.join(image_root, line[0])
|
||||
image_doc = line[1]
|
||||
image_doc = line[1]
|
||||
gallery_images.append(image_file)
|
||||
gallery_docs.append(image_doc)
|
||||
|
||||
|
@ -57,28 +57,34 @@ class GalleryBuilder(object):
|
|||
assert 'IndexProcess' in config.keys(), "Index config not found ... "
|
||||
self.build(config['IndexProcess'])
|
||||
|
||||
|
||||
def build(self, config):
    '''
    Build the gallery index from scratch.

    Reads the gallery list with ``split_datafile``, extracts one feature
    vector per gallery image with ``self.rec_predictor``, then builds a
    ``Graph_Index`` over those vectors and stores it in ``self.Searcher``.

    Args:
        config: mapping that provides the keys ``data_file``,
            ``image_root``, ``delimiter``, ``embedding_size``,
            ``dist_type``, ``pq_size`` and ``index_path``.

    Raises:
        SystemExit: if a gallery image cannot be read.
    '''
    gallery_images, gallery_docs = split_datafile(
        config['data_file'], config['image_root'], config['delimiter'])

    # extract gallery features
    # NOTE(review): assumes rec_predictor.predict returns a vector of
    # length config['embedding_size'] -- confirm against the predictor.
    gallery_features = np.zeros(
        [len(gallery_images), config['embedding_size']], dtype=np.float32)

    for i, image_file in enumerate(tqdm(gallery_images)):
        # cv2.imread returns None (it does not raise) for a missing or
        # unreadable file, so check before slicing the array.
        img = cv2.imread(image_file)
        if img is None:
            logger.error("img empty, please check {}".format(image_file))
            # Exit with a non-zero status so callers can detect the failure.
            raise SystemExit(1)
        # reverse the channel axis (OpenCV loads images as BGR)
        img = img[:, :, ::-1]
        rec_feat = self.rec_predictor.predict(img)
        gallery_features[i, :] = rec_feat

    # train index
    self.Searcher = Graph_Index(dist_type=config['dist_type'])
    self.Searcher.build(
        gallery_vectors=gallery_features,
        gallery_docs=gallery_docs,
        pq_size=config['pq_size'],
        index_path=config['index_path'])
|
||||
|
||||
|
||||
def main(config):
|
||||
system_builder = GalleryBuilder(config)
|
||||
|
|
|
@ -46,22 +46,38 @@ class SystemPredictor(object):
|
|||
dist_type=config['IndexProcess']['dist_type'])
|
||||
self.Searcher.load(config['IndexProcess']['index_path'])
|
||||
|
||||
def append_self(self, results, shape):
    """Append a pseudo-detection that covers the whole image.

    Args:
        results: list of detection dicts; it is modified in place.
        shape: image shape indexed as ``(height, width, ...)``; only the
            first two entries are read.

    Returns:
        The same ``results`` list, with one extra entry whose ``bbox`` is
        ``[0, 0, width, height]``, ``class_id`` 0, ``score`` 1.0 and
        ``label_name`` ``"foreground"``.
    """
    height, width = shape[0], shape[1]
    results.append({
        "class_id": 0,
        "score": 1.0,
        "bbox": np.array([0, 0, width, height]),
        "label_name": "foreground",
    })
    return results
|
||||
|
||||
def predict(self, img):
    """Detect regions in ``img`` and look each one up in the gallery index.

    The detector output is extended with a whole-image box (via
    ``self.append_self``), each box is cropped and passed to
    ``self.rec_predictor``, and the resulting feature is searched in
    ``self.Searcher``.

    Args:
        img: image array indexed as ``[row, col, channel]``.

    Returns:
        A list with one dict per box, holding:
            ``bbox``: ``[xmin, ymin, xmax, ymax]`` as integers,
            ``rec_docs``: the top-1 doc, or ``None`` when the search is
                empty or the top score is below
                ``config["IndexProcess"]["score_thres"]``,
            ``rec_scores``: the top-1 score, or ``0.0`` in that same case.
    """
    score_thres = self.config["IndexProcess"]["score_thres"]
    output = []
    results = self.det_predictor.predict(img)
    # add the whole image for recognition
    results = self.append_self(results, img.shape)

    for result in results:
        preds = {}
        xmin, ymin, xmax, ymax = result["bbox"].astype("int")
        crop_img = img[ymin:ymax, xmin:xmax, :].copy()
        rec_results = self.rec_predictor.predict(crop_img)
        preds["bbox"] = [xmin, ymin, xmax, ymax]
        scores, docs = self.Searcher.search(
            query=rec_results,
            return_k=self.return_k,
            search_budget=self.search_budget)
        # just top-1 result will be returned for the final; an empty
        # search result is treated the same as a below-threshold match
        if len(scores) > 0 and scores[0] >= score_thres:
            preds["rec_docs"] = docs[0]
            preds["rec_scores"] = scores[0]
        else:
            preds["rec_docs"] = None
            preds["rec_scores"] = 0.0

        output.append(preds)
    return output
|
||||
|
@ -75,7 +91,7 @@ def main(config):
|
|||
for idx, image_file in enumerate(image_list):
|
||||
img = cv2.imread(image_file)[:, :, ::-1]
|
||||
output = system_predictor.predict(img)
|
||||
draw_bbox_results(img[:, :, ::-1], output, image_file)
|
||||
draw_bbox_results(img, output, image_file)
|
||||
print(output)
|
||||
return
|
||||
|
||||
|
|
|
@ -15,18 +15,45 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
|
||||
def draw_bbox_results(image,
                      results,
                      input_path,
                      font_path="./utils/simfang.ttf",
                      save_dir=None):
    """Draw recognition results on an image, save it and return it.

    Args:
        image: a ``numpy.ndarray`` (converted with ``Image.fromarray``) or
            an object already accepted by ``ImageDraw.Draw``.
        results: iterable of dicts with keys ``bbox``
            (``xmin, ymin, xmax, ymax``), ``rec_docs`` and ``rec_scores``;
            entries whose ``rec_docs`` is ``None`` are skipped.
        input_path: path of the source image; only its basename is used,
            as the output file name.
        font_path: TrueType font file used for the label text.
        save_dir: output directory; defaults to ``"output"`` and is
            created if missing.

    Returns:
        The annotated image as a ``numpy.ndarray``.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    draw = ImageDraw.Draw(image)
    font = ImageFont.truetype(font_path, 20, encoding="utf-8")

    color = (0, 255, 0)
    # label box height; matches the font size requested above
    th = 20

    for result in results:
        # empty results
        if result["rec_docs"] is None:
            continue

        # coerce to plain ints: the values may arrive as numpy integers
        xmin, ymin, xmax, ymax = (int(v) for v in result["bbox"])
        text = "{}, {:.2f}".format(result["rec_docs"], result["rec_scores"])
        # rough label width: 20 px per doc character plus room for the score
        tw = int(len(result["rec_docs"]) * 20) + 60
        # keep the label inside the image when the box touches the top edge
        start_y = max(0, ymin - th)
        draw.rectangle(
            [(xmin + 1, start_y), (xmin + tw + 1, start_y + th)],
            outline=color)

        draw.text((xmin + 1, start_y), text, fill=color, font=font)

        draw.rectangle(
            [(xmin, ymin), (xmax, ymax)], outline=(255, 0, 0), width=2)

    image_name = os.path.basename(input_path)
    if save_dir is None:
        save_dir = "output"
    os.makedirs(save_dir, exist_ok=True)
    output_path = os.path.join(save_dir, image_name)

    image.save(output_path, quality=95)
    return np.array(image)
|
||||
|
|
Binary file not shown.
Loading…
Reference in New Issue