Revert "support ShiTu"

This reverts commit 9beb154bc3.
pull/2701/head
Tingquan Gao 2023-03-14 16:16:40 +08:00
parent 7865207096
commit 8002ccf4b6
6 changed files with 176 additions and 197 deletions

View File

@ -88,15 +88,14 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
random.seed(worker_seed) random.seed(worker_seed)
def build_dataloader(config, *mode, seed=None): def build_dataloader(config, mode, seed=None):
dataloader_config = config["DataLoader"] assert mode in [
for m in mode:
assert m in [
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain' 'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain" ], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
assert m in dataloader_config.keys(), "{} config not in yaml".format(m) assert mode in config["DataLoader"].keys(), "{} config not in yaml".format(
dataloader_config = dataloader_config[m] mode)
dataloader_config = config["DataLoader"][mode]
class_num = config["Arch"].get("class_num", None) class_num = config["Arch"].get("class_num", None)
epochs = config["Global"]["epochs"] epochs = config["Global"]["epochs"]
use_dali = config["Global"].get("use_dali", False) use_dali = config["Global"].get("use_dali", False)

View File

@ -22,7 +22,6 @@ from paddle import nn
import numpy as np import numpy as np
import random import random
from ..utils.amp import AMPForwardDecorator
from ppcls.utils import logger from ppcls.utils import logger
from ppcls.utils.logger import init_logger from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config from ppcls.utils.config import print_config

View File

@ -13,17 +13,17 @@
# limitations under the License. # limitations under the License.
from .classification import ClassEval from .classification import ClassEval
from .retrieval import RetrievalEval from .retrieval import retrieval_eval
from .adaface import adaface_eval from .adaface import adaface_eval
def build_eval_func(config, mode, model): def build_eval_func(config, mode, model):
if mode not in ["eval", "train"]: if mode not in ["eval", "train"]:
return None return None
task = config["Global"].get("task", "classification") eval_mode = config["Global"].get("eval_mode", None)
if task == "classification": if eval_mode is None:
config["Global"]["eval_mode"] = "classification"
return ClassEval(config, mode, model) return ClassEval(config, mode, model)
elif task == "retrieval":
return RetrievalEval(config, mode, model)
else: else:
raise Exception() return getattr(sys.modules[__name__], eval_mode + "_eval")(config,
mode, model)

View File

@ -21,50 +21,25 @@ import numpy as np
import paddle import paddle
import scipy import scipy
from ...utils.misc import AverageMeter from ppcls.utils import all_gather, logger
from ...utils import all_gather, logger
from ...data import build_dataloader
from ...loss import build_loss
from ...metric import build_metrics
class RetrievalEval(object): def retrieval_eval(engine, epoch_id=0):
def __init__(self, config, mode, model): engine.model.eval()
self.config = config
self.model = model
self.print_batch_step = self.config["Global"]["print_batch_step"]
self.use_dali = self.config["Global"].get("use_dali", False)
self.eval_metric_func = build_metrics(self.config, "Eval")
self.eval_loss_func = build_loss(self.config, "Eval")
self.output_info = dict()
self.gallery_query_dataloader = None
if len(self.config["DataLoader"]["Eval"].keys()) == 1:
self.gallery_query_dataloader = build_dataloader(self.config,
"Eval")
else:
self.gallery_dataloader = build_dataloader(self.config, "Eval",
"Gallery")
self.query_dataloader = build_dataloader(self.config, "Eval",
"Query")
def __call__(self, epoch_id=0):
self.model.eval()
# step1. prepare query and gallery features # step1. prepare query and gallery features
if self.gallery_query_dataloader is not None: if engine.gallery_query_dataloader is not None:
gallery_feat, gallery_label, gallery_camera = self.compute_feature( gallery_feat, gallery_label, gallery_camera = compute_feature(
"gallery_query") engine, "gallery_query")
query_feat, query_label, query_camera = gallery_feat, gallery_label, gallery_camera query_feat, query_label, query_camera = gallery_feat, gallery_label, gallery_camera
else: else:
gallery_feat, gallery_label, gallery_camera = self.compute_feature( gallery_feat, gallery_label, gallery_camera = compute_feature(
"gallery") engine, "gallery")
query_feat, query_label, query_camera = self.compute_feature( query_feat, query_label, query_camera = compute_feature(engine,
"query") "query")
# step2. split features into feature blocks for saving memory # step2. split features into feature blocks for saving memory
num_query = len(query_feat) num_query = len(query_feat)
block_size = self.config["Global"].get("sim_block_size", 64) block_size = engine.config["Global"].get("sim_block_size", 64)
sections = [block_size] * (num_query // block_size) sections = [block_size] * (num_query // block_size)
if num_query % block_size > 0: if num_query % block_size > 0:
sections.append(num_query % block_size) sections.append(num_query % block_size)
@ -76,15 +51,15 @@ class RetrievalEval(object):
metric_key = None metric_key = None
# step3. compute metric # step3. compute metric
if self.eval_loss_func is None: if engine.eval_loss_func is None:
metric_dict = {metric_key: 0.0} metric_dict = {metric_key: 0.0}
else: else:
use_reranking = self.config["Global"].get("re_ranking", False) use_reranking = engine.config["Global"].get("re_ranking", False)
logger.info(f"re_ranking={use_reranking}") logger.info(f"re_ranking={use_reranking}")
if use_reranking: if use_reranking:
# compute distance matrix # compute distance matrix
distmat = compute_re_ranking_dist( distmat = compute_re_ranking_dist(
query_feat, gallery_feat, self.config["Global"].get( query_feat, gallery_feat, engine.config["Global"].get(
"feature_normalize", True), 20, 6, 0.3) "feature_normalize", True), 20, 6, 0.3)
# exclude illegal distance # exclude illegal distance
if query_camera is not None: if query_camera is not None:
@ -92,12 +67,11 @@ class RetrievalEval(object):
label_mask = query_label != gallery_label.t() label_mask = query_label != gallery_label.t()
keep_mask = label_mask | camera_mask keep_mask = label_mask | camera_mask
distmat = keep_mask.astype(query_feat.dtype) * distmat + ( distmat = keep_mask.astype(query_feat.dtype) * distmat + (
~keep_mask).astype(query_feat.dtype) * (distmat.max() + ~keep_mask).astype(query_feat.dtype) * (distmat.max() + 1)
1)
else: else:
keep_mask = None keep_mask = None
# compute metric with all samples # compute metric with all samples
metric_dict = self.eval_metric_func(-distmat, query_label, metric_dict = engine.eval_metric_func(-distmat, query_label,
gallery_label, keep_mask) gallery_label, keep_mask)
else: else:
metric_dict = defaultdict(float) metric_dict = defaultdict(float)
@ -116,13 +90,13 @@ class RetrievalEval(object):
else: else:
keep_mask = None keep_mask = None
# compute metric by block # compute metric by block
metric_block = self.eval_metric_func( metric_block = engine.eval_metric_func(
distmat, query_label_blocks[block_idx], gallery_label, distmat, query_label_blocks[block_idx], gallery_label,
keep_mask) keep_mask)
# accumulate metric # accumulate metric
for key in metric_block: for key in metric_block:
metric_dict[key] += metric_block[ metric_dict[key] += metric_block[key] * block_feat.shape[
key] * block_feat.shape[0] / num_query 0] / num_query
metric_info_list = [] metric_info_list = []
for key, value in metric_dict.items(): for key, value in metric_dict.items():
@ -134,13 +108,14 @@ class RetrievalEval(object):
return metric_dict[metric_key] return metric_dict[metric_key]
def compute_feature(self, name="gallery"):
def compute_feature(engine, name="gallery"):
if name == "gallery": if name == "gallery":
dataloader = self.gallery_dataloader dataloader = engine.gallery_dataloader
elif name == "query": elif name == "query":
dataloader = self.query_dataloader dataloader = engine.query_dataloader
elif name == "gallery_query": elif name == "gallery_query":
dataloader = self.gallery_query_dataloader dataloader = engine.gallery_query_dataloader
else: else:
raise ValueError( raise ValueError(
f"Only support gallery or query or gallery_query dataset, but got {name}" f"Only support gallery or query or gallery_query dataset, but got {name}"
@ -151,7 +126,7 @@ class RetrievalEval(object):
all_camera = [] all_camera = []
has_camera = False has_camera = False
for idx, batch in enumerate(dataloader): # load is very time-consuming for idx, batch in enumerate(dataloader): # load is very time-consuming
if idx % self.print_batch_step == 0: if idx % engine.config["Global"]["print_batch_step"] == 0:
logger.info( logger.info(
f"{name} feature calculation process: [{idx}/{len(dataloader)}]" f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
) )
@ -161,14 +136,20 @@ class RetrievalEval(object):
if len(batch) >= 3: if len(batch) >= 3:
has_camera = True has_camera = True
batch[2] = batch[2].reshape([-1, 1]).astype("int64") batch[2] = batch[2].reshape([-1, 1]).astype("int64")
if engine.amp and engine.amp_eval:
out = self.model(batch) with paddle.amp.auto_cast(
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=engine.amp_level):
out = engine.model(batch[0])
else:
out = engine.model(batch[0])
if "Student" in out: if "Student" in out:
out = out["Student"] out = out["Student"]
# get features # get features
if self.config["Global"].get("retrieval_feature_from", if engine.config["Global"].get("retrieval_feature_from",
"features") == "features": "features") == "features":
# use output from neck as feature # use output from neck as feature
batch_feat = out["features"] batch_feat = out["features"]
@ -177,14 +158,13 @@ class RetrievalEval(object):
batch_feat = out["backbone"] batch_feat = out["backbone"]
# do norm(optional) # do norm(optional)
if self.config["Global"].get("feature_normalize", True): if engine.config["Global"].get("feature_normalize", True):
batch_feat = paddle.nn.functional.normalize(batch_feat, p=2) batch_feat = paddle.nn.functional.normalize(batch_feat, p=2)
# do binarize(optional) # do binarize(optional)
if self.config["Global"].get("feature_binarize") == "round": if engine.config["Global"].get("feature_binarize") == "round":
batch_feat = paddle.round(batch_feat).astype( batch_feat = paddle.round(batch_feat).astype("float32") * 2.0 - 1.0
"float32") * 2.0 - 1.0 elif engine.config["Global"].get("feature_binarize") == "sign":
elif self.config["Global"].get("feature_binarize") == "sign":
batch_feat = paddle.sign(batch_feat).astype("float32") batch_feat = paddle.sign(batch_feat).astype("float32")
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
@ -198,7 +178,7 @@ class RetrievalEval(object):
if has_camera: if has_camera:
all_camera.append(batch[2]) all_camera.append(batch[2])
if self.use_dali: if engine.use_dali:
dataloader.reset() dataloader.reset()
all_feat = paddle.concat(all_feat) all_feat = paddle.concat(all_feat)
@ -208,7 +188,7 @@ class RetrievalEval(object):
else: else:
all_camera = None all_camera = None
# discard redundant padding sample(s) at the end # discard redundant padding sample(s) at the end
total_samples = dataloader.size if self.use_dali else len( total_samples = dataloader.size if engine.use_dali else len(
dataloader.dataset) dataloader.dataset)
all_feat = all_feat[:total_samples] all_feat = all_feat[:total_samples]
all_label = all_label[:total_samples] all_label = all_label[:total_samples]

View File

@ -22,8 +22,9 @@ from .train_progressive import train_epoch_progressive
def build_train_func(config, mode, model, eval_func): def build_train_func(config, mode, model, eval_func):
if mode != "train": if mode != "train":
return None return None
task = config["Global"].get("task", "classification") train_mode = config["Global"].get("task", None)
if task == "classification" or task == "retrieval": if train_mode is None:
config["Global"]["task"] = "classification"
return ClassTrainer(config, model, eval_func) return ClassTrainer(config, model, eval_func)
else: else:
return getattr(sys.modules[__name__], "train_epoch_" + train_mode)( return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(

View File

@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function
from ppcls.data import build_dataloader from ppcls.data import build_dataloader
from ppcls.utils import logger, type_name from ppcls.utils import logger, type_name
from .classification import ClassTrainer from .regular_train_epoch import regular_train_epoch
def train_epoch_progressive(engine, epoch_id, print_batch_step): def train_epoch_progressive(engine, epoch_id, print_batch_step):