diff --git a/train.py b/train.py
index 1401ccb96..c3e7172c3 100644
--- a/train.py
+++ b/train.py
@@ -99,6 +99,96 @@ RANK = int(os.getenv("RANK", -1))
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
 GIT_INFO = check_git_info()
 
+import torch._dynamo
+from compile.FxGraphConvertor import fx2mlir
+from torch._functorch.aot_autograd import aot_export_module
+from tpu_mlir.python.tools.train.tpu_mlir_jit import aot_backend
+
+
+class JitNet(nn.Module):
+    """Bundle a network and its loss function so a single forward call yields the loss."""
+
+    def __init__(self, net, loss_fn):
+        super().__init__()
+        self.net = net
+        self.loss_fn = loss_fn
+
+    def forward(self, x, y):
+        """Run ``net`` on ``x`` once and return ``(loss, loss_item.detach())``."""
+        # The network is evaluated exactly once; an extra call whose result is
+        # discarded would only double the forward cost.
+        loss, loss_item = self.loss_fn(self.net(x), y)
+        return loss, loss_item.detach()
+
+
+def _get_disc_decomp():
+    """Return the ``get_decompositions`` table for the aten ops listed below."""
+    from torch._decomp import get_decompositions
+
+    aten = torch.ops.aten
+    decompositions_dict = get_decompositions(
+        [
+            aten.gelu,
+            aten.gelu_backward,
+            aten.native_group_norm_backward,
+            # aten.native_layer_norm,
+            aten.native_layer_norm_backward,
+            # aten.std_mean.correction,
+            # aten._softmax,
+            aten._softmax_backward_data,
+            aten.tanh_backward,
+            aten.slice_backward,
+            aten.select_backward,
+            aten.embedding_dense_backward,
+            aten.sigmoid_backward,
+            aten.nll_loss_backward,
+            aten._log_softmax_backward_data,
+            aten.nll_loss_forward,
+        ]
+    )
+    return decompositions_dict
+
+
+def convert_module_fx(
+    submodule_name: str,
+    module: torch.fx.GraphModule,
+    args=None,
+    bwd_graph: bool = False,
+    para_shape: list = None,
+):
+    """Build an ``fx2mlir`` converter and return the result of converting ``module``.
+
+    ``args`` and ``para_shape`` fall back to a fresh ``{}`` / ``[]`` on every call,
+    so no mutable default object is shared between calls.
+    """
+    args = {} if args is None else args
+    para_shape = [] if para_shape is None else para_shape
+    c = fx2mlir(submodule_name, args, bwd_graph, para_shape)
+    return c.convert(module)
+
+
+class SophonJointCompile:
+    """Export a model with ``aot_export_module`` and convert the resulting FX graph."""
+
+    def __init__(self, model, example_inputs, trace_joint=True, output_loss_index=0, args=None):
+        fx_g, signature = aot_export_module(
+            model,
+            example_inputs,
+            trace_joint=trace_joint,
+            output_loss_index=output_loss_index,
+            decompositions=_get_disc_decomp(),
+        )
+        # Keep the export results and options on the instance; fx_convert_bmodel reads them.
+        self.fx_g = fx_g
+        self.signature = signature
+        self.args = args
+        fx_g.to_folder("yolov5sc", "joint")
+
+    def fx_convert_bmodel(self):
+        """Convert the exported graph; ``self.args`` must provide ``model`` and ``batch``."""
+        name = f"test_{self.args.model}_joint_{self.args.batch}"
+        convert_module_fx(name, self.fx_g, self.args, False)
+
 
 def train(hyp, opt, device, callbacks):
     """
@@ -384,6 +474,8 @@ def train(hyp, opt, device, callbacks):
         if RANK in {-1, 0}:
             pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT)  # progress bar
         optimizer.zero_grad()
+        model_opt = torch.compile(model, backend=aot_backend)
+        # zwyjit = JitNet(model, compute_loss)
         for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
             callbacks.run("on_train_batch_start")
             ni = i + nb * epoch  # number integrated batches (since train start)
@@ -410,8 +502,13 @@ def train(hyp, opt, device, callbacks):
 
             # Forward
             with torch.cuda.amp.autocast(amp):
-                pred = model(imgs)  # forward
+                pred = model_opt(imgs)  # forward
                 loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
+                # fx_g, signature = aot_export_module(
+                #     model, [imgs], trace_joint=False, output_loss_index=0, decompositions=_get_disc_decomp()
+                # )
+                # print('fx_g:', fx_g)
+
                 if RANK != -1:
                     loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                 if opt.quad:
@@ -983,4 +1080,19 @@ def run(**kwargs):
 
 if __name__ == "__main__":
     opt = parse_opt()
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--chip", default="bm1690", choices=["bm1684x", "bm1690", "sg2260"], help="chip name")
+    parser.add_argument("--debug", default="print_ori_fx_graph", help="debug")
+    parser.add_argument("--cmp", action="store_true", help="enable cmp")
+    parser.add_argument("--fast_test", action="store_true", help="fast_test")
+    parser.add_argument("--skip_module_num", default=0, type=int, help="skip_module_num")
+    parser.add_argument("--exit_at", default=-1, type=int, help="exit_at")
+    parser.add_argument("--num_core", default=1, type=int, help="The number of TPU cores used for parallel computation")
+    parser.add_argument("--opt", default=2, type=int, help="layer group opt")
+    parser.add_argument("--fp", default="", help="fp")
+    import tpu_mlir.python.tools.train.tpu_mlir_jit as tpu_mlir_jit
+
+    tpu_mlir_jit.args = parser.parse_known_args()[0]
+
     main(opt)