# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os

import numpy as np
import torch
from mmengine.model import Sequential
from tensorflow.python.training import py_checkpoint_reader

from mmpretrain.models.backbones.efficientnet import EfficientNet


def tf2pth(v):
    """Convert a TensorFlow tensor to the PyTorch memory layout.

    Conv kernels (HWIO) become OIHW and dense kernels are transposed;
    anything else is returned unchanged.
    """
    if v.ndim == 4:
        return np.ascontiguousarray(v.transpose(3, 2, 0, 1))
    elif v.ndim == 2:
        return np.ascontiguousarray(v.transpose())
    return v


def read_ckpt(ckpt):
    """Read every variable of a TensorFlow checkpoint into torch tensors."""
    reader = py_checkpoint_reader.NewCheckpointReader(ckpt)
    weights = {
        n: torch.as_tensor(tf2pth(reader.get_tensor(n)))
        for (n, _) in reader.get_variable_to_shape_map().items()
    }
    return weights


def map_key(weight, l2_flag):
    """Map TensorFlow variable names to MMPretrain state-dict keys."""
    m = dict()
    has_expand_conv = set()
    is_MBConv = set()
    max_idx = 0
    name = None

    # First pass: infer the arch name and record, per block, whether it is an
    # MBConv block and whether it contains an expand conv.
    for k, v in weight.items():
        seg = k.split('/')
        if len(seg) == 1:
            continue
        if 'edgetpu' in seg[0]:
            name = 'e' + seg[0][21:].lower()
        else:
            name = seg[0][13:]
        if seg[2] == 'tpu_batch_normalization_2':
            has_expand_conv.add(seg[1])
        if seg[1].startswith('blocks_'):
            idx = int(seg[1][7:]) + 1
            max_idx = max(max_idx, idx)
        if 'depthwise' in k:
            is_MBConv.add(seg[1])

    # Build the flat index ('layer' or 'layer.sub') of every module in the
    # target backbone, so TF block indices can be mapped onto it.
    model = EfficientNet(name)
    idx2key = []
    for idx, module in enumerate(model.layers):
        if isinstance(module, Sequential):
            for j in range(len(module)):
                idx2key.append('{}.{}'.format(idx, j))
        else:
            idx2key.append('{}'.format(idx))

    # Second pass: translate every variable name into a state-dict key.
    for k, v in weight.items():
        if l2_flag:
            k = k.replace('/ExponentialMovingAverage', '')
        if 'Exponential' in k or 'RMS' in k:
            continue

        seg = k.split('/')
        if len(seg) == 1:
            continue
        if seg[2] == 'depthwise_conv2d':
            # Depthwise kernels need their first two dims swapped after
            # tf2pth.
            v = v.transpose(1, 0)

        if seg[1] == 'stem':
            prefix = 'backbone.layers.{}'.format(idx2key[0])
            mapping = {
                'conv2d/kernel': 'conv.weight',
                'tpu_batch_normalization/beta': 'bn.bias',
                'tpu_batch_normalization/gamma': 'bn.weight',
                'tpu_batch_normalization/moving_mean': 'bn.running_mean',
                'tpu_batch_normalization/moving_variance': 'bn.running_var',
            }
            suffix = mapping['/'.join(seg[2:])]
            m[prefix + '.' + suffix] = v
        elif seg[1].startswith('blocks_'):
            idx = int(seg[1][7:]) + 1
            prefix = '.'.join(['backbone', 'layers', idx2key[idx]])
            if seg[1] not in is_MBConv:
                # Block without a depthwise conv: plain conv1/conv2.
                mapping = {
                    'conv2d/kernel': 'conv1.conv.weight',
                    'tpu_batch_normalization/gamma': 'conv1.bn.weight',
                    'tpu_batch_normalization/beta': 'conv1.bn.bias',
                    'tpu_batch_normalization/moving_mean':
                    'conv1.bn.running_mean',
                    'tpu_batch_normalization/moving_variance':
                    'conv1.bn.running_var',
                    'conv2d_1/kernel': 'conv2.conv.weight',
                    'tpu_batch_normalization_1/gamma': 'conv2.bn.weight',
                    'tpu_batch_normalization_1/beta': 'conv2.bn.bias',
                    'tpu_batch_normalization_1/moving_mean':
                    'conv2.bn.running_mean',
                    'tpu_batch_normalization_1/moving_variance':
                    'conv2.bn.running_var',
                }
            else:
                # MBConv block: depthwise conv + SE module, optionally
                # preceded by an expand conv.
                base_mapping = {
                    'depthwise_conv2d/depthwise_kernel':
                    'depthwise_conv.conv.weight',
                    'se/conv2d/kernel': 'se.conv1.conv.weight',
                    'se/conv2d/bias': 'se.conv1.conv.bias',
                    'se/conv2d_1/kernel': 'se.conv2.conv.weight',
                    'se/conv2d_1/bias': 'se.conv2.conv.bias'
                }
                if seg[1] not in has_expand_conv:
                    mapping = {
                        'conv2d/kernel': 'linear_conv.conv.weight',
                        'tpu_batch_normalization/beta':
                        'depthwise_conv.bn.bias',
                        'tpu_batch_normalization/gamma':
                        'depthwise_conv.bn.weight',
                        'tpu_batch_normalization/moving_mean':
                        'depthwise_conv.bn.running_mean',
                        'tpu_batch_normalization/moving_variance':
                        'depthwise_conv.bn.running_var',
                        'tpu_batch_normalization_1/beta':
                        'linear_conv.bn.bias',
                        'tpu_batch_normalization_1/gamma':
                        'linear_conv.bn.weight',
                        'tpu_batch_normalization_1/moving_mean':
                        'linear_conv.bn.running_mean',
                        'tpu_batch_normalization_1/moving_variance':
                        'linear_conv.bn.running_var',
                    }
                else:
                    mapping = {
                        'depthwise_conv2d/depthwise_kernel':
                        'depthwise_conv.conv.weight',
                        'conv2d/kernel': 'expand_conv.conv.weight',
                        'conv2d_1/kernel': 'linear_conv.conv.weight',
                        'tpu_batch_normalization/beta':
                        'expand_conv.bn.bias',
                        'tpu_batch_normalization/gamma':
                        'expand_conv.bn.weight',
                        'tpu_batch_normalization/moving_mean':
                        'expand_conv.bn.running_mean',
                        'tpu_batch_normalization/moving_variance':
                        'expand_conv.bn.running_var',
                        'tpu_batch_normalization_1/beta':
                        'depthwise_conv.bn.bias',
                        'tpu_batch_normalization_1/gamma':
                        'depthwise_conv.bn.weight',
                        'tpu_batch_normalization_1/moving_mean':
                        'depthwise_conv.bn.running_mean',
                        'tpu_batch_normalization_1/moving_variance':
                        'depthwise_conv.bn.running_var',
                        'tpu_batch_normalization_2/beta':
                        'linear_conv.bn.bias',
                        'tpu_batch_normalization_2/gamma':
                        'linear_conv.bn.weight',
                        'tpu_batch_normalization_2/moving_mean':
                        'linear_conv.bn.running_mean',
                        'tpu_batch_normalization_2/moving_variance':
                        'linear_conv.bn.running_var',
                    }
                mapping.update(base_mapping)
            suffix = mapping['/'.join(seg[2:])]
            m[prefix + '.' + suffix] = v
        elif seg[1] == 'head':
            seq_key = idx2key[max_idx + 1]
            mapping = {
                'conv2d/kernel':
                'backbone.layers.{}.conv.weight'.format(seq_key),
                'tpu_batch_normalization/beta':
                'backbone.layers.{}.bn.bias'.format(seq_key),
                'tpu_batch_normalization/gamma':
                'backbone.layers.{}.bn.weight'.format(seq_key),
                'tpu_batch_normalization/moving_mean':
                'backbone.layers.{}.bn.running_mean'.format(seq_key),
                'tpu_batch_normalization/moving_variance':
                'backbone.layers.{}.bn.running_var'.format(seq_key),
                'dense/kernel': 'head.fc.weight',
                'dense/bias': 'head.fc.bias'
            }
            key = mapping['/'.join(seg[2:])]
            if name.startswith('e') and 'fc' in key:
                # The edgetpu checkpoints keep an extra leading class in the
                # classifier; drop it.
                v = v[1:]
            m[key] = v
    return m


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('infile', type=str, help='Path to the ckpt.')
    parser.add_argument('outfile', type=str, help='Output file.')
    parser.add_argument(
        '--l2',
        action='store_true',
        help='Convert the ExponentialMovingAverage weights. '
        'The l2 arch should use it.')
    args = parser.parse_args()
    assert args.outfile
    outdir = os.path.dirname(os.path.abspath(args.outfile))
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    weights = read_ckpt(args.infile)
    weights = map_key(weights, args.l2)
    torch.save(weights, args.outfile)
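
# Example invocation (a sketch: the checkpoint path and output filename below
# are placeholders, and the script filename assumes it is saved as
# efficientnet_to_mmpretrain.py):
#
#   python efficientnet_to_mmpretrain.py \
#       path/to/efficientnet-b0/model.ckpt efficientnet-b0_converted.pth
#
# Add --l2 when converting an l2 checkpoint so that the
# ExponentialMovingAverage weights are used instead of the raw weights.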