commit dfb8e26923

@@ -243,7 +243,7 @@ class VisionTransformer(nn.Layer):
                  drop_path_rate=0.,
                  norm_layer='nn.LayerNorm',
                  epsilon=1e-5,
-                 **args):
+                 **kwargs):
         super().__init__()
         self.class_num = class_num
 
@@ -331,9 +331,7 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False):
         )


-def ViT_small_patch16_224(pretrained,
-                          model,
-                          model_url,
+def ViT_small_patch16_224(pretrained=False,
                           use_ssld=False,
                           **kwargs):
     model = VisionTransformer(
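Note: the same signature cleanup is applied to each ViT_* factory below; the positional pretrained, model, and model_url parameters are replaced by pretrained=False, so the pretrained-weight URL is resolved inside the module instead of being passed in by the caller. A hedged usage sketch of the new call style follows (the import path and the forwarding of extra keyword arguments such as class_num are assumptions based on the usual PaddleClas layout, not something shown in this diff):

    # Usage sketch only; the import path and kwarg forwarding are assumptions.
    import paddle
    from ppcls.arch.backbone.model_zoo.vision_transformer import ViT_small_patch16_224

    model = ViT_small_patch16_224(pretrained=False, class_num=1000)
    x = paddle.randn([1, 3, 224, 224])
    logits = model(x)
    print(logits.shape)  # expected: [1, 1000]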
@@ -352,9 +350,7 @@ def ViT_small_patch16_224(pretrained,
     return model


-def ViT_base_patch16_224(pretrained,
-                         model,
-                         model_url,
+def ViT_base_patch16_224(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
@@ -374,9 +370,7 @@ def ViT_base_patch16_224(pretrained,
     return model


-def ViT_base_patch16_384(pretrained,
-                         model,
-                         model_url,
+def ViT_base_patch16_384(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
@@ -397,9 +391,7 @@ def ViT_base_patch16_384(pretrained,
     return model


-def ViT_base_patch32_384(pretrained,
-                         model,
-                         model_url,
+def ViT_base_patch32_384(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
@@ -420,9 +412,7 @@ def ViT_base_patch32_384(pretrained,
     return model


-def ViT_large_patch16_224(pretrained,
-                          model,
-                          model_url,
+def ViT_large_patch16_224(pretrained=False,
                           use_ssld=False,
                           **kwargs):
     model = VisionTransformer(
@@ -442,9 +432,7 @@ def ViT_large_patch16_224(pretrained,
     return model


-def ViT_large_patch16_384(pretrained,
-                          model,
-                          model_url,
+def ViT_large_patch16_384(pretrained=False,
                           use_ssld=False,
                           **kwargs):
     model = VisionTransformer(
@@ -465,9 +453,7 @@ def ViT_large_patch16_384(pretrained,
     return model


-def ViT_large_patch32_384(pretrained,
-                          model,
-                          model_url,
+def ViT_large_patch32_384(pretrained=False,
                           use_ssld=False,
                           **kwargs):
     model = VisionTransformer(
@@ -488,9 +474,7 @@ def ViT_large_patch32_384(pretrained,
     return model


-def ViT_huge_patch16_224(pretrained,
-                         model,
-                         model_url,
+def ViT_huge_patch16_224(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
@@ -508,9 +492,7 @@ def ViT_huge_patch16_224(pretrained,
     return model


-def ViT_huge_patch32_384(pretrained,
-                         model,
-                         model_url,
+def ViT_huge_patch32_384(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
@@ -574,6 +574,8 @@ class Trainer(object):
             if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                 batch_tensor = paddle.to_tensor(batch_data)
                 out = self.model(batch_tensor)
+                if isinstance(out, list):
+                    out = out[0]
                 result = postprocess_func(out, image_file_list)
                 print(result)
                 batch_data.clear()
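Note: the two added lines make the inference loop tolerant of architectures whose forward returns a list of outputs instead of a single tensor; only the first element is handed to postprocessing. Below is a minimal, self-contained sketch of that pattern with a hypothetical stand-in layer (DummyMultiHead is not a PaddleClas class):

    import paddle
    import paddle.nn as nn

    class DummyMultiHead(nn.Layer):
        # Stand-in for a model whose forward returns several outputs as a list.
        def __init__(self, class_num=10):
            super().__init__()
            self.head = nn.Linear(8, class_num)

        def forward(self, x):
            logits = self.head(x)
            return [logits, x]  # e.g. logits plus intermediate features

    model = DummyMultiHead()
    out = model(paddle.randn([2, 8]))
    if isinstance(out, list):  # the same guard added here and in ExportModel.forward
        out = out[0]
    print(out.shape)  # [2, 10]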
@@ -63,6 +63,8 @@ class ExportModel(nn.Layer):

     def forward(self, x):
         x = self.base_model(x)
+        if isinstance(x, list):
+            x = x[0]
         if self.infer_model_name is not None:
             x = x[self.infer_model_name]
         if self.infer_output_key is not None:
@@ -76,7 +78,6 @@ if __name__ == "__main__":
     args = config.parse_args()
     config = config.get_config(
         args.config, overrides=args.override, show=False)
-
     log_file = os.path.join(config['Global']['output_dir'],
                             config["Arch"]["name"], "export.log")
     init_logger(name='root', log_file=log_file)
@@ -86,7 +87,6 @@ if __name__ == "__main__":
     assert config["Global"]["device"] in ["cpu", "gpu", "xpu"]
     device = paddle.set_device(config["Global"]["device"])
     model = ExportModel(config["Arch"])
-
     if config["Global"]["pretrained_model"] is not None:
         load_dygraph_pretrain(model.base_model,
                               config["Global"]["pretrained_model"])
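Note: for context, the ExportModel assembled in this script is normally converted to a static graph and saved for inference once the pretrained weights are loaded. A rough, hedged sketch of that step using a stand-in network (the input shape and save path below are placeholders, not values taken from this diff):

    import paddle
    import paddle.nn as nn
    from paddle.static import InputSpec

    # net stands in for the ExportModel instance built above.
    net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 1000))
    net.eval()
    static_net = paddle.jit.to_static(
        net, input_spec=[InputSpec(shape=[None, 3, 224, 224], dtype="float32")])
    paddle.jit.save(static_net, "./inference/inference")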