[Enhancement] Replace MMCV with MMEngine in convert model scripts (#1798)

* replace mmcv with mmengine

* remove transforms
This commit is contained in:
谢昕辰 2022-07-27 17:54:37 +08:00 committed by GitHub
parent 6873f9ece8
commit 1ded0a4278
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 20 additions and 20 deletions

View File

@ -3,9 +3,9 @@ import argparse
import os.path as osp import os.path as osp
from collections import OrderedDict from collections import OrderedDict
import mmcv import mmengine
import torch import torch
from mmcv.runner import CheckpointLoader from mmengine.runner import CheckpointLoader
def convert_beit(ckpt): def convert_beit(ckpt):
@ -48,7 +48,7 @@ def main():
else: else:
state_dict = checkpoint state_dict = checkpoint
weight = convert_beit(state_dict) weight = convert_beit(state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst)) mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst) torch.save(weight, args.dst)

View File

@ -3,9 +3,9 @@ import argparse
import os.path as osp import os.path as osp
from collections import OrderedDict from collections import OrderedDict
import mmcv import mmengine
import torch import torch
from mmcv.runner import CheckpointLoader from mmengine.runner import CheckpointLoader
def convert_mit(ckpt): def convert_mit(ckpt):
@ -74,7 +74,7 @@ def main():
else: else:
state_dict = checkpoint state_dict = checkpoint
weight = convert_mit(state_dict) weight = convert_mit(state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst)) mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst) torch.save(weight, args.dst)

View File

@ -2,9 +2,9 @@
import argparse import argparse
import os.path as osp import os.path as osp
import mmcv import mmengine
import torch import torch
from mmcv.runner import CheckpointLoader from mmengine.runner import CheckpointLoader
def convert_stdc(ckpt, stdc_type): def convert_stdc(ckpt, stdc_type):
@ -63,7 +63,7 @@ def main():
assert args.type in ['STDC1', assert args.type in ['STDC1',
'STDC2'], 'STD type should be STDC1 or STDC2!' 'STDC2'], 'STD type should be STDC1 or STDC2!'
weight = convert_stdc(state_dict, args.type) weight = convert_stdc(state_dict, args.type)
mmcv.mkdir_or_exist(osp.dirname(args.dst)) mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst) torch.save(weight, args.dst)

View File

@ -3,9 +3,9 @@ import argparse
import os.path as osp import os.path as osp
from collections import OrderedDict from collections import OrderedDict
import mmcv import mmengine
import torch import torch
from mmcv.runner import CheckpointLoader from mmengine.runner import CheckpointLoader
def convert_swin(ckpt): def convert_swin(ckpt):
@ -79,7 +79,7 @@ def main():
else: else:
state_dict = checkpoint state_dict = checkpoint
weight = convert_swin(state_dict) weight = convert_swin(state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst)) mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst) torch.save(weight, args.dst)

View File

@ -3,9 +3,9 @@ import argparse
import os.path as osp import os.path as osp
from collections import OrderedDict from collections import OrderedDict
import mmcv import mmengine
import torch import torch
from mmcv.runner import CheckpointLoader from mmengine.runner import CheckpointLoader
def convert_twins(args, ckpt): def convert_twins(args, ckpt):
@ -79,7 +79,7 @@ def main():
state_dict = checkpoint state_dict = checkpoint
weight = convert_twins(args, state_dict) weight = convert_twins(args, state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst)) mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst) torch.save(weight, args.dst)

View File

@ -3,9 +3,9 @@ import argparse
import os.path as osp import os.path as osp
from collections import OrderedDict from collections import OrderedDict
import mmcv import mmengine
import torch import torch
from mmcv.runner import CheckpointLoader from mmengine.runner import CheckpointLoader
def convert_vit(ckpt): def convert_vit(ckpt):
@ -62,7 +62,7 @@ def main():
else: else:
state_dict = checkpoint state_dict = checkpoint
weight = convert_vit(state_dict) weight = convert_vit(state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst)) mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst) torch.save(weight, args.dst)

View File

@ -2,7 +2,7 @@
import argparse import argparse
import os.path as osp import os.path as osp
import mmcv import mmengine
import numpy as np import numpy as np
import torch import torch
@ -115,7 +115,7 @@ def main():
else: else:
num_layer = 12 num_layer = 12
torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer) torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer)
mmcv.mkdir_or_exist(osp.dirname(args.dst)) mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(torch_weights, args.dst) torch.save(torch_weights, args.dst)