parent
7243f1429b
commit
f42719afbb
|
@ -220,21 +220,23 @@ class DataIterator(object):
|
|||
return batch
|
||||
|
||||
|
||||
def build_dataloader(config, mode):
|
||||
if "class_num" in config["Global"]:
|
||||
global_class_num = config["Global"]["class_num"]
|
||||
def build_dataloader(engine):
|
||||
if "class_num" in engine.config["Global"]:
|
||||
global_class_num = engine.config["Global"]["class_num"]
|
||||
if "class_num" not in config["Arch"]:
|
||||
config["Arch"]["class_num"] = global_class_num
|
||||
engine.config["Arch"]["class_num"] = global_class_num
|
||||
msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
|
||||
else:
|
||||
msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
|
||||
logger.warning(msg)
|
||||
|
||||
class_num = config["Arch"].get("class_num", None)
|
||||
config["DataLoader"].update({"class_num": class_num})
|
||||
config["DataLoader"].update({"epochs": config["Global"]["epochs"]})
|
||||
class_num = engine.config["Arch"].get("class_num", None)
|
||||
engine.config["DataLoader"].update({"class_num": class_num})
|
||||
engine.config["DataLoader"].update({
|
||||
"epochs": engine.config["Global"]["epochs"]
|
||||
})
|
||||
|
||||
use_dali = config["Global"].get("use_dali", False)
|
||||
use_dali = engine.use_dali
|
||||
dataloader_dict = {
|
||||
"Train": None,
|
||||
"UnLabelTrain": None,
|
||||
|
@ -243,37 +245,37 @@ def build_dataloader(config, mode):
|
|||
"Gallery": None,
|
||||
"GalleryQuery": None
|
||||
}
|
||||
if mode == 'train':
|
||||
if engine.mode == 'train':
|
||||
train_dataloader = build(
|
||||
config["DataLoader"], "Train", use_dali, seed=None)
|
||||
engine.config["DataLoader"], "Train", use_dali, seed=None)
|
||||
|
||||
if config["DataLoader"]["Train"].get("max_iter", None):
|
||||
if engine.config["DataLoader"]["Train"].get("max_iter", None):
|
||||
# set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
|
||||
max_iter = config["Train"].get("max_iter")
|
||||
update_freq = config["Global"].get("update_freq", 1)
|
||||
max_iter = train_dataloader.max_iter // update_freq * update_freq
|
||||
max_iter = engine.config["Train"].get("max_iter")
|
||||
max_iter = train_dataloader.max_iter // engine.update_freq * engine.update_freq
|
||||
train_dataloader.max_iter = max_iter
|
||||
if config["DataLoader"]["Train"].get("convert_iterator", True):
|
||||
if engine.config["DataLoader"]["Train"].get("convert_iterator", True):
|
||||
train_dataloader = DataIterator(train_dataloader, use_dali)
|
||||
dataloader_dict["Train"] = train_dataloader
|
||||
|
||||
if config["DataLoader"].get('UnLabelTrain', None) is not None:
|
||||
if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
|
||||
dataloader_dict["UnLabelTrain"] = build(
|
||||
config["DataLoader"], "UnLabelTrain", use_dali, seed=None)
|
||||
engine.config["DataLoader"], "UnLabelTrain", use_dali, seed=None)
|
||||
|
||||
if mode == "eval" or (mode == "train" and
|
||||
config["Global"]["eval_during_train"]):
|
||||
if config["Global"]["eval_mode"] in ["classification", "adaface"]:
|
||||
if engine.mode == "eval" or (engine.mode == "train" and
|
||||
engine.config["Global"]["eval_during_train"]):
|
||||
if engine.config["Global"][
|
||||
"eval_mode"] in ["classification", "adaface"]:
|
||||
dataloader_dict["Eval"] = build(
|
||||
config["DataLoader"], "Eval", use_dali, seed=None)
|
||||
elif config["Global"]["eval_mode"] == "retrieval":
|
||||
if len(config["DataLoader"]["Eval"].keys()) == 1:
|
||||
key = list(config["DataLoader"]["Eval"].keys())[0]
|
||||
engine.config["DataLoader"], "Eval", use_dali, seed=None)
|
||||
elif engine.config["Global"]["eval_mode"] == "retrieval":
|
||||
if len(engine.config["DataLoader"]["Eval"].keys()) == 1:
|
||||
key = list(engine.config["DataLoader"]["Eval"].keys())[0]
|
||||
dataloader_dict["GalleryQuery"] = build(
|
||||
config["DataLoader"]["Eval"], key, use_dali)
|
||||
engine.config["DataLoader"]["Eval"], key, use_dali)
|
||||
else:
|
||||
dataloader_dict["Gallery"] = build(
|
||||
config["DataLoader"]["Eval"], "Gallery", use_dali)
|
||||
dataloader_dict["Query"] = build(config["DataLoader"]["Eval"],
|
||||
"Query", use_dali)
|
||||
engine.config["DataLoader"]["Eval"], "Gallery", use_dali)
|
||||
dataloader_dict["Query"] = build(
|
||||
engine.config["DataLoader"]["Eval"], "Query", use_dali)
|
||||
return dataloader_dict
|
||||
|
|
|
@ -76,7 +76,7 @@ class Engine(object):
|
|||
|
||||
# build dataloader
|
||||
self.use_dali = self.config["Global"].get("use_dali", False)
|
||||
self.dataloader_dict = build_dataloader(self.config, mode)
|
||||
self.dataloader_dict = build_dataloader(self)
|
||||
|
||||
# build loss
|
||||
self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(
|
||||
|
|
Loading…
Reference in New Issue