parent
7865207096
commit
8002ccf4b6
|
@ -88,15 +88,14 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
|
||||||
random.seed(worker_seed)
|
random.seed(worker_seed)
|
||||||
|
|
||||||
|
|
||||||
def build_dataloader(config, *mode, seed=None):
|
def build_dataloader(config, mode, seed=None):
|
||||||
dataloader_config = config["DataLoader"]
|
assert mode in [
|
||||||
for m in mode:
|
|
||||||
assert m in [
|
|
||||||
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
|
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
|
||||||
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
|
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
|
||||||
assert m in dataloader_config.keys(), "{} config not in yaml".format(m)
|
assert mode in config["DataLoader"].keys(), "{} config not in yaml".format(
|
||||||
dataloader_config = dataloader_config[m]
|
mode)
|
||||||
|
|
||||||
|
dataloader_config = config["DataLoader"][mode]
|
||||||
class_num = config["Arch"].get("class_num", None)
|
class_num = config["Arch"].get("class_num", None)
|
||||||
epochs = config["Global"]["epochs"]
|
epochs = config["Global"]["epochs"]
|
||||||
use_dali = config["Global"].get("use_dali", False)
|
use_dali = config["Global"].get("use_dali", False)
|
||||||
|
|
|
@ -22,7 +22,6 @@ from paddle import nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from ..utils.amp import AMPForwardDecorator
|
|
||||||
from ppcls.utils import logger
|
from ppcls.utils import logger
|
||||||
from ppcls.utils.logger import init_logger
|
from ppcls.utils.logger import init_logger
|
||||||
from ppcls.utils.config import print_config
|
from ppcls.utils.config import print_config
|
||||||
|
|
|
@ -13,17 +13,17 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .classification import ClassEval
|
from .classification import ClassEval
|
||||||
from .retrieval import RetrievalEval
|
from .retrieval import retrieval_eval
|
||||||
from .adaface import adaface_eval
|
from .adaface import adaface_eval
|
||||||
|
|
||||||
|
|
||||||
def build_eval_func(config, mode, model):
|
def build_eval_func(config, mode, model):
|
||||||
if mode not in ["eval", "train"]:
|
if mode not in ["eval", "train"]:
|
||||||
return None
|
return None
|
||||||
task = config["Global"].get("task", "classification")
|
eval_mode = config["Global"].get("eval_mode", None)
|
||||||
if task == "classification":
|
if eval_mode is None:
|
||||||
|
config["Global"]["eval_mode"] = "classification"
|
||||||
return ClassEval(config, mode, model)
|
return ClassEval(config, mode, model)
|
||||||
elif task == "retrieval":
|
|
||||||
return RetrievalEval(config, mode, model)
|
|
||||||
else:
|
else:
|
||||||
raise Exception()
|
return getattr(sys.modules[__name__], eval_mode + "_eval")(config,
|
||||||
|
mode, model)
|
||||||
|
|
|
@ -21,50 +21,25 @@ import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
import scipy
|
import scipy
|
||||||
|
|
||||||
from ...utils.misc import AverageMeter
|
from ppcls.utils import all_gather, logger
|
||||||
from ...utils import all_gather, logger
|
|
||||||
from ...data import build_dataloader
|
|
||||||
from ...loss import build_loss
|
|
||||||
from ...metric import build_metrics
|
|
||||||
|
|
||||||
|
|
||||||
class RetrievalEval(object):
|
def retrieval_eval(engine, epoch_id=0):
|
||||||
def __init__(self, config, mode, model):
|
engine.model.eval()
|
||||||
self.config = config
|
|
||||||
self.model = model
|
|
||||||
self.print_batch_step = self.config["Global"]["print_batch_step"]
|
|
||||||
self.use_dali = self.config["Global"].get("use_dali", False)
|
|
||||||
self.eval_metric_func = build_metrics(self.config, "Eval")
|
|
||||||
self.eval_loss_func = build_loss(self.config, "Eval")
|
|
||||||
self.output_info = dict()
|
|
||||||
|
|
||||||
self.gallery_query_dataloader = None
|
|
||||||
if len(self.config["DataLoader"]["Eval"].keys()) == 1:
|
|
||||||
self.gallery_query_dataloader = build_dataloader(self.config,
|
|
||||||
"Eval")
|
|
||||||
else:
|
|
||||||
self.gallery_dataloader = build_dataloader(self.config, "Eval",
|
|
||||||
"Gallery")
|
|
||||||
self.query_dataloader = build_dataloader(self.config, "Eval",
|
|
||||||
"Query")
|
|
||||||
|
|
||||||
def __call__(self, epoch_id=0):
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
# step1. prepare query and gallery features
|
# step1. prepare query and gallery features
|
||||||
if self.gallery_query_dataloader is not None:
|
if engine.gallery_query_dataloader is not None:
|
||||||
gallery_feat, gallery_label, gallery_camera = self.compute_feature(
|
gallery_feat, gallery_label, gallery_camera = compute_feature(
|
||||||
"gallery_query")
|
engine, "gallery_query")
|
||||||
query_feat, query_label, query_camera = gallery_feat, gallery_label, gallery_camera
|
query_feat, query_label, query_camera = gallery_feat, gallery_label, gallery_camera
|
||||||
else:
|
else:
|
||||||
gallery_feat, gallery_label, gallery_camera = self.compute_feature(
|
gallery_feat, gallery_label, gallery_camera = compute_feature(
|
||||||
"gallery")
|
engine, "gallery")
|
||||||
query_feat, query_label, query_camera = self.compute_feature(
|
query_feat, query_label, query_camera = compute_feature(engine,
|
||||||
"query")
|
"query")
|
||||||
|
|
||||||
# step2. split features into feature blocks for saving memory
|
# step2. split features into feature blocks for saving memory
|
||||||
num_query = len(query_feat)
|
num_query = len(query_feat)
|
||||||
block_size = self.config["Global"].get("sim_block_size", 64)
|
block_size = engine.config["Global"].get("sim_block_size", 64)
|
||||||
sections = [block_size] * (num_query // block_size)
|
sections = [block_size] * (num_query // block_size)
|
||||||
if num_query % block_size > 0:
|
if num_query % block_size > 0:
|
||||||
sections.append(num_query % block_size)
|
sections.append(num_query % block_size)
|
||||||
|
@ -76,15 +51,15 @@ class RetrievalEval(object):
|
||||||
metric_key = None
|
metric_key = None
|
||||||
|
|
||||||
# step3. compute metric
|
# step3. compute metric
|
||||||
if self.eval_loss_func is None:
|
if engine.eval_loss_func is None:
|
||||||
metric_dict = {metric_key: 0.0}
|
metric_dict = {metric_key: 0.0}
|
||||||
else:
|
else:
|
||||||
use_reranking = self.config["Global"].get("re_ranking", False)
|
use_reranking = engine.config["Global"].get("re_ranking", False)
|
||||||
logger.info(f"re_ranking={use_reranking}")
|
logger.info(f"re_ranking={use_reranking}")
|
||||||
if use_reranking:
|
if use_reranking:
|
||||||
# compute distance matrix
|
# compute distance matrix
|
||||||
distmat = compute_re_ranking_dist(
|
distmat = compute_re_ranking_dist(
|
||||||
query_feat, gallery_feat, self.config["Global"].get(
|
query_feat, gallery_feat, engine.config["Global"].get(
|
||||||
"feature_normalize", True), 20, 6, 0.3)
|
"feature_normalize", True), 20, 6, 0.3)
|
||||||
# exclude illegal distance
|
# exclude illegal distance
|
||||||
if query_camera is not None:
|
if query_camera is not None:
|
||||||
|
@ -92,12 +67,11 @@ class RetrievalEval(object):
|
||||||
label_mask = query_label != gallery_label.t()
|
label_mask = query_label != gallery_label.t()
|
||||||
keep_mask = label_mask | camera_mask
|
keep_mask = label_mask | camera_mask
|
||||||
distmat = keep_mask.astype(query_feat.dtype) * distmat + (
|
distmat = keep_mask.astype(query_feat.dtype) * distmat + (
|
||||||
~keep_mask).astype(query_feat.dtype) * (distmat.max() +
|
~keep_mask).astype(query_feat.dtype) * (distmat.max() + 1)
|
||||||
1)
|
|
||||||
else:
|
else:
|
||||||
keep_mask = None
|
keep_mask = None
|
||||||
# compute metric with all samples
|
# compute metric with all samples
|
||||||
metric_dict = self.eval_metric_func(-distmat, query_label,
|
metric_dict = engine.eval_metric_func(-distmat, query_label,
|
||||||
gallery_label, keep_mask)
|
gallery_label, keep_mask)
|
||||||
else:
|
else:
|
||||||
metric_dict = defaultdict(float)
|
metric_dict = defaultdict(float)
|
||||||
|
@ -116,13 +90,13 @@ class RetrievalEval(object):
|
||||||
else:
|
else:
|
||||||
keep_mask = None
|
keep_mask = None
|
||||||
# compute metric by block
|
# compute metric by block
|
||||||
metric_block = self.eval_metric_func(
|
metric_block = engine.eval_metric_func(
|
||||||
distmat, query_label_blocks[block_idx], gallery_label,
|
distmat, query_label_blocks[block_idx], gallery_label,
|
||||||
keep_mask)
|
keep_mask)
|
||||||
# accumulate metric
|
# accumulate metric
|
||||||
for key in metric_block:
|
for key in metric_block:
|
||||||
metric_dict[key] += metric_block[
|
metric_dict[key] += metric_block[key] * block_feat.shape[
|
||||||
key] * block_feat.shape[0] / num_query
|
0] / num_query
|
||||||
|
|
||||||
metric_info_list = []
|
metric_info_list = []
|
||||||
for key, value in metric_dict.items():
|
for key, value in metric_dict.items():
|
||||||
|
@ -134,13 +108,14 @@ class RetrievalEval(object):
|
||||||
|
|
||||||
return metric_dict[metric_key]
|
return metric_dict[metric_key]
|
||||||
|
|
||||||
def compute_feature(self, name="gallery"):
|
|
||||||
|
def compute_feature(engine, name="gallery"):
|
||||||
if name == "gallery":
|
if name == "gallery":
|
||||||
dataloader = self.gallery_dataloader
|
dataloader = engine.gallery_dataloader
|
||||||
elif name == "query":
|
elif name == "query":
|
||||||
dataloader = self.query_dataloader
|
dataloader = engine.query_dataloader
|
||||||
elif name == "gallery_query":
|
elif name == "gallery_query":
|
||||||
dataloader = self.gallery_query_dataloader
|
dataloader = engine.gallery_query_dataloader
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Only support gallery or query or gallery_query dataset, but got {name}"
|
f"Only support gallery or query or gallery_query dataset, but got {name}"
|
||||||
|
@ -151,7 +126,7 @@ class RetrievalEval(object):
|
||||||
all_camera = []
|
all_camera = []
|
||||||
has_camera = False
|
has_camera = False
|
||||||
for idx, batch in enumerate(dataloader): # load is very time-consuming
|
for idx, batch in enumerate(dataloader): # load is very time-consuming
|
||||||
if idx % self.print_batch_step == 0:
|
if idx % engine.config["Global"]["print_batch_step"] == 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
|
f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
|
||||||
)
|
)
|
||||||
|
@ -161,14 +136,20 @@ class RetrievalEval(object):
|
||||||
if len(batch) >= 3:
|
if len(batch) >= 3:
|
||||||
has_camera = True
|
has_camera = True
|
||||||
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
|
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
|
||||||
|
if engine.amp and engine.amp_eval:
|
||||||
out = self.model(batch)
|
with paddle.amp.auto_cast(
|
||||||
|
custom_black_list={
|
||||||
|
"flatten_contiguous_range", "greater_than"
|
||||||
|
},
|
||||||
|
level=engine.amp_level):
|
||||||
|
out = engine.model(batch[0])
|
||||||
|
else:
|
||||||
|
out = engine.model(batch[0])
|
||||||
if "Student" in out:
|
if "Student" in out:
|
||||||
out = out["Student"]
|
out = out["Student"]
|
||||||
|
|
||||||
# get features
|
# get features
|
||||||
if self.config["Global"].get("retrieval_feature_from",
|
if engine.config["Global"].get("retrieval_feature_from",
|
||||||
"features") == "features":
|
"features") == "features":
|
||||||
# use output from neck as feature
|
# use output from neck as feature
|
||||||
batch_feat = out["features"]
|
batch_feat = out["features"]
|
||||||
|
@ -177,14 +158,13 @@ class RetrievalEval(object):
|
||||||
batch_feat = out["backbone"]
|
batch_feat = out["backbone"]
|
||||||
|
|
||||||
# do norm(optional)
|
# do norm(optional)
|
||||||
if self.config["Global"].get("feature_normalize", True):
|
if engine.config["Global"].get("feature_normalize", True):
|
||||||
batch_feat = paddle.nn.functional.normalize(batch_feat, p=2)
|
batch_feat = paddle.nn.functional.normalize(batch_feat, p=2)
|
||||||
|
|
||||||
# do binarize(optional)
|
# do binarize(optional)
|
||||||
if self.config["Global"].get("feature_binarize") == "round":
|
if engine.config["Global"].get("feature_binarize") == "round":
|
||||||
batch_feat = paddle.round(batch_feat).astype(
|
batch_feat = paddle.round(batch_feat).astype("float32") * 2.0 - 1.0
|
||||||
"float32") * 2.0 - 1.0
|
elif engine.config["Global"].get("feature_binarize") == "sign":
|
||||||
elif self.config["Global"].get("feature_binarize") == "sign":
|
|
||||||
batch_feat = paddle.sign(batch_feat).astype("float32")
|
batch_feat = paddle.sign(batch_feat).astype("float32")
|
||||||
|
|
||||||
if paddle.distributed.get_world_size() > 1:
|
if paddle.distributed.get_world_size() > 1:
|
||||||
|
@ -198,7 +178,7 @@ class RetrievalEval(object):
|
||||||
if has_camera:
|
if has_camera:
|
||||||
all_camera.append(batch[2])
|
all_camera.append(batch[2])
|
||||||
|
|
||||||
if self.use_dali:
|
if engine.use_dali:
|
||||||
dataloader.reset()
|
dataloader.reset()
|
||||||
|
|
||||||
all_feat = paddle.concat(all_feat)
|
all_feat = paddle.concat(all_feat)
|
||||||
|
@ -208,7 +188,7 @@ class RetrievalEval(object):
|
||||||
else:
|
else:
|
||||||
all_camera = None
|
all_camera = None
|
||||||
# discard redundant padding sample(s) at the end
|
# discard redundant padding sample(s) at the end
|
||||||
total_samples = dataloader.size if self.use_dali else len(
|
total_samples = dataloader.size if engine.use_dali else len(
|
||||||
dataloader.dataset)
|
dataloader.dataset)
|
||||||
all_feat = all_feat[:total_samples]
|
all_feat = all_feat[:total_samples]
|
||||||
all_label = all_label[:total_samples]
|
all_label = all_label[:total_samples]
|
||||||
|
|
|
@ -22,8 +22,9 @@ from .train_progressive import train_epoch_progressive
|
||||||
def build_train_func(config, mode, model, eval_func):
|
def build_train_func(config, mode, model, eval_func):
|
||||||
if mode != "train":
|
if mode != "train":
|
||||||
return None
|
return None
|
||||||
task = config["Global"].get("task", "classification")
|
train_mode = config["Global"].get("task", None)
|
||||||
if task == "classification" or task == "retrieval":
|
if train_mode is None:
|
||||||
|
config["Global"]["task"] = "classification"
|
||||||
return ClassTrainer(config, model, eval_func)
|
return ClassTrainer(config, model, eval_func)
|
||||||
else:
|
else:
|
||||||
return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(
|
return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(
|
||||||
|
|
|
@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
from ppcls.data import build_dataloader
|
from ppcls.data import build_dataloader
|
||||||
from ppcls.utils import logger, type_name
|
from ppcls.utils import logger, type_name
|
||||||
from .classification import ClassTrainer
|
from .regular_train_epoch import regular_train_epoch
|
||||||
|
|
||||||
|
|
||||||
def train_epoch_progressive(engine, epoch_id, print_batch_step):
|
def train_epoch_progressive(engine, epoch_id, print_batch_step):
|
||||||
|
|
Loading…
Reference in New Issue