PaddleClas/ppcls/arch/__init__.py

65 lines
1.9 KiB
Python
Raw Normal View History

2021-06-04 16:44:24 +08:00
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2020-04-09 02:16:30 +08:00
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import copy
import importlib
import paddle.nn as nn
2021-06-04 16:44:24 +08:00
from . import backbone, gears
2021-05-24 11:42:24 +08:00
from .backbone import *
2021-06-04 14:54:34 +08:00
from .gears import build_gear
2020-04-13 18:53:03 +08:00
from .utils import *
2021-06-01 11:30:26 +08:00
# Names exported by a wildcard import of this module.
__all__ = ["build_model", "RecModel"]
def build_model(config):
    """Instantiate the architecture described by ``config``.

    The ``"name"`` entry selects an attribute of this module; every other
    entry is forwarded to that attribute as a keyword argument. ``config``
    is deep-copied first, so the caller's object is left untouched.

    Args:
        config: mapping containing a ``"name"`` key plus constructor kwargs.

    Returns:
        Whatever calling the selected attribute returns.

    Raises:
        KeyError: if ``config`` has no ``"name"`` entry.
        AttributeError: if this module has no attribute with that name.
    """
    params = copy.deepcopy(config)
    arch_name = params.pop("name")
    # Resolve the name against this very module's namespace.
    this_module = importlib.import_module(__name__)
    arch_factory = getattr(this_module, arch_name)
    return arch_factory(**params)
class RecModel(nn.Layer):
    """Model composed of a backbone followed by an optional neck and head.

    Keyword Args:
        Backbone (dict): required. Its ``"name"`` entry selects the backbone
            constructor; the remaining entries are passed to that
            constructor as keyword arguments.
        BackboneStopLayer (dict, optional): its ``"name"`` entry is handed
            to ``self.backbone.stop_after``.
        Neck (dict, optional): passed to ``build_gear`` to create the neck.
        Head (dict, optional): passed to ``build_gear`` to create the head.

    Raises:
        KeyError: if ``Backbone`` or its ``"name"`` entry is missing.
    """

    def __init__(self, **config):
        super().__init__()
        # Pop "name" from a shallow copy. ``**config`` only copies the top
        # level, so popping from config["Backbone"] directly would strip the
        # key from the caller's own nested dict and make a second
        # construction from the same config fail with KeyError.
        backbone_config = dict(config["Backbone"])
        backbone_name = backbone_config.pop("name")
        # NOTE(review): eval() executes an arbitrary expression taken from
        # the config. Kept for backward compatibility with existing configs;
        # only use with trusted configuration, and consider an explicit
        # name lookup instead.
        self.backbone = eval(backbone_name)(**backbone_config)

        if "BackboneStopLayer" in config:
            backbone_stop_layer = config["BackboneStopLayer"]["name"]
            self.backbone.stop_after(backbone_stop_layer)

        # Neck and head are optional; None means the stage is skipped in
        # forward().
        self.neck = build_gear(config["Neck"]) if "Neck" in config else None
        self.head = build_gear(config["Head"]) if "Head" in config else None

    def forward(self, x, label=None):
        """Run backbone, then neck and head when they are configured.

        Args:
            x: input passed to the backbone.
            label: forwarded to the head as its second argument
                (default ``None``).

        Returns:
            dict: ``{"features": <output of backbone/neck>,
            "logits": <head output, or None when there is no head>}``.
        """
        x = self.backbone(x)
        if self.neck is not None:
            x = self.neck(x)
        y = self.head(x, label) if self.head is not None else None
        return {"features": x, "logits": y}