[Enhancement] Replace MMCV with MMEngine in convert model scripts (#1798)
* replace mmcv with mmengine * remove transformspull/1835/head
parent
6873f9ece8
commit
1ded0a4278
|
@ -3,9 +3,9 @@ import argparse
|
|||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import torch
|
||||
from mmcv.runner import CheckpointLoader
|
||||
from mmengine.runner import CheckpointLoader
|
||||
|
||||
|
||||
def convert_beit(ckpt):
|
||||
|
@ -48,7 +48,7 @@ def main():
|
|||
else:
|
||||
state_dict = checkpoint
|
||||
weight = convert_beit(state_dict)
|
||||
mmcv.mkdir_or_exist(osp.dirname(args.dst))
|
||||
mmengine.mkdir_or_exist(osp.dirname(args.dst))
|
||||
torch.save(weight, args.dst)
|
||||
|
||||
|
||||
|
|
|
@ -3,9 +3,9 @@ import argparse
|
|||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import torch
|
||||
from mmcv.runner import CheckpointLoader
|
||||
from mmengine.runner import CheckpointLoader
|
||||
|
||||
|
||||
def convert_mit(ckpt):
|
||||
|
@ -74,7 +74,7 @@ def main():
|
|||
else:
|
||||
state_dict = checkpoint
|
||||
weight = convert_mit(state_dict)
|
||||
mmcv.mkdir_or_exist(osp.dirname(args.dst))
|
||||
mmengine.mkdir_or_exist(osp.dirname(args.dst))
|
||||
torch.save(weight, args.dst)
|
||||
|
||||
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
import argparse
|
||||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import torch
|
||||
from mmcv.runner import CheckpointLoader
|
||||
from mmengine.runner import CheckpointLoader
|
||||
|
||||
|
||||
def convert_stdc(ckpt, stdc_type):
|
||||
|
@ -63,7 +63,7 @@ def main():
|
|||
assert args.type in ['STDC1',
|
||||
'STDC2'], 'STD type should be STDC1 or STDC2!'
|
||||
weight = convert_stdc(state_dict, args.type)
|
||||
mmcv.mkdir_or_exist(osp.dirname(args.dst))
|
||||
mmengine.mkdir_or_exist(osp.dirname(args.dst))
|
||||
torch.save(weight, args.dst)
|
||||
|
||||
|
||||
|
|
|
@ -3,9 +3,9 @@ import argparse
|
|||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import torch
|
||||
from mmcv.runner import CheckpointLoader
|
||||
from mmengine.runner import CheckpointLoader
|
||||
|
||||
|
||||
def convert_swin(ckpt):
|
||||
|
@ -79,7 +79,7 @@ def main():
|
|||
else:
|
||||
state_dict = checkpoint
|
||||
weight = convert_swin(state_dict)
|
||||
mmcv.mkdir_or_exist(osp.dirname(args.dst))
|
||||
mmengine.mkdir_or_exist(osp.dirname(args.dst))
|
||||
torch.save(weight, args.dst)
|
||||
|
||||
|
||||
|
|
|
@ -3,9 +3,9 @@ import argparse
|
|||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import torch
|
||||
from mmcv.runner import CheckpointLoader
|
||||
from mmengine.runner import CheckpointLoader
|
||||
|
||||
|
||||
def convert_twins(args, ckpt):
|
||||
|
@ -79,7 +79,7 @@ def main():
|
|||
state_dict = checkpoint
|
||||
|
||||
weight = convert_twins(args, state_dict)
|
||||
mmcv.mkdir_or_exist(osp.dirname(args.dst))
|
||||
mmengine.mkdir_or_exist(osp.dirname(args.dst))
|
||||
torch.save(weight, args.dst)
|
||||
|
||||
|
||||
|
|
|
@ -3,9 +3,9 @@ import argparse
|
|||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import torch
|
||||
from mmcv.runner import CheckpointLoader
|
||||
from mmengine.runner import CheckpointLoader
|
||||
|
||||
|
||||
def convert_vit(ckpt):
|
||||
|
@ -62,7 +62,7 @@ def main():
|
|||
else:
|
||||
state_dict = checkpoint
|
||||
weight = convert_vit(state_dict)
|
||||
mmcv.mkdir_or_exist(osp.dirname(args.dst))
|
||||
mmengine.mkdir_or_exist(osp.dirname(args.dst))
|
||||
torch.save(weight, args.dst)
|
||||
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import argparse
|
||||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
@ -115,7 +115,7 @@ def main():
|
|||
else:
|
||||
num_layer = 12
|
||||
torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer)
|
||||
mmcv.mkdir_or_exist(osp.dirname(args.dst))
|
||||
mmengine.mkdir_or_exist(osp.dirname(args.dst))
|
||||
torch.save(torch_weights, args.dst)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue