[Enhancement] Replace MMCV with MMEngine in convert model scripts (#1798)

* replace mmcv with mmengine

* remove transforms
pull/1835/head
谢昕辰 2022-07-27 17:54:37 +08:00 committed by GitHub
parent 6873f9ece8
commit 1ded0a4278
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 20 additions and 20 deletions

View File

@ -3,9 +3,9 @@ import argparse
import os.path as osp
from collections import OrderedDict
import mmcv
import mmengine
import torch
from mmcv.runner import CheckpointLoader
from mmengine.runner import CheckpointLoader
def convert_beit(ckpt):
@ -48,7 +48,7 @@ def main():
else:
state_dict = checkpoint
weight = convert_beit(state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)

View File

@ -3,9 +3,9 @@ import argparse
import os.path as osp
from collections import OrderedDict
import mmcv
import mmengine
import torch
from mmcv.runner import CheckpointLoader
from mmengine.runner import CheckpointLoader
def convert_mit(ckpt):
@ -74,7 +74,7 @@ def main():
else:
state_dict = checkpoint
weight = convert_mit(state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)

View File

@ -2,9 +2,9 @@
import argparse
import os.path as osp
import mmcv
import mmengine
import torch
from mmcv.runner import CheckpointLoader
from mmengine.runner import CheckpointLoader
def convert_stdc(ckpt, stdc_type):
@ -63,7 +63,7 @@ def main():
assert args.type in ['STDC1',
'STDC2'], 'STD type should be STDC1 or STDC2!'
weight = convert_stdc(state_dict, args.type)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)

View File

@ -3,9 +3,9 @@ import argparse
import os.path as osp
from collections import OrderedDict
import mmcv
import mmengine
import torch
from mmcv.runner import CheckpointLoader
from mmengine.runner import CheckpointLoader
def convert_swin(ckpt):
@ -79,7 +79,7 @@ def main():
else:
state_dict = checkpoint
weight = convert_swin(state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)

View File

@ -3,9 +3,9 @@ import argparse
import os.path as osp
from collections import OrderedDict
import mmcv
import mmengine
import torch
from mmcv.runner import CheckpointLoader
from mmengine.runner import CheckpointLoader
def convert_twins(args, ckpt):
@ -79,7 +79,7 @@ def main():
state_dict = checkpoint
weight = convert_twins(args, state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)

View File

@ -3,9 +3,9 @@ import argparse
import os.path as osp
from collections import OrderedDict
import mmcv
import mmengine
import torch
from mmcv.runner import CheckpointLoader
from mmengine.runner import CheckpointLoader
def convert_vit(ckpt):
@ -62,7 +62,7 @@ def main():
else:
state_dict = checkpoint
weight = convert_vit(state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)

View File

@ -2,7 +2,7 @@
import argparse
import os.path as osp
import mmcv
import mmengine
import numpy as np
import torch
@ -115,7 +115,7 @@ def main():
else:
num_layer = 12
torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(torch_weights, args.dst)