2021-08-22 15:10:23 +00:00
|
|
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2023-02-22 07:28:28 +00:00
|
|
|
|
|
|
|
from .train_metabin import train_epoch_metabin
|
2023-03-05 12:53:36 +00:00
|
|
|
from .classification import ClassTrainer
|
2023-02-22 07:28:28 +00:00
|
|
|
from .train_fixmatch import train_epoch_fixmatch
|
|
|
|
from .train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
|
|
|
|
from .train_progressive import train_epoch_progressive
|
|
|
|
|
|
|
|
|
2023-03-05 12:53:36 +00:00
|
|
|
def build_train_func(config, mode, model, eval_func):
    """Build the training entry point for the configured task.

    Args:
        config: Parsed configuration mapping. ``config["Global"]["task"]``
            selects the trainer; it defaults to ``"classification"``.
        mode: Run mode. Only ``"train"`` builds anything.
        model: Model forwarded to the trainer / epoch function.
        eval_func: Evaluation callable forwarded to the trainer / epoch
            function.

    Returns:
        ``None`` when ``mode != "train"``; a ``ClassTrainer`` for the
        ``"classification"`` and ``"retrieval"`` tasks; otherwise the result
        of calling the module-level ``train_epoch_<task>`` function.

    Raises:
        ValueError: If the task is not classification/retrieval and no
            ``train_epoch_<task>`` function exists in this module.
    """
    if mode != "train":
        return None

    task = config["Global"].get("task", "classification")
    if task in ("classification", "retrieval"):
        return ClassTrainer(config, model, eval_func)

    # Other tasks dispatch by name to a train_epoch_<task> function imported
    # at module level (e.g. train_epoch_fixmatch). globals() is this module's
    # namespace, so no `sys` import is required for the lookup.
    func_name = f"train_epoch_{task}"
    train_func = globals().get(func_name)
    if train_func is None:
        raise ValueError(
            f"Unsupported task '{task}': no function named '{func_name}' "
            f"in module '{__name__}'.")
    # NOTE(review): the train_epoch_* helpers are invoked with
    # (config, mode, model, eval_func); confirm this matches their signatures.
    return train_func(config, mode, model, eval_func)
|