pull/13510/merge
wangxc2006 2025-04-18 01:02:51 +02:00 committed by GitHub
commit 1a5504c41b
1 changed file with 91 additions and 1 deletion


@@ -99,6 +99,74 @@ RANK = int(os.getenv("RANK", -1))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
GIT_INFO = check_git_info()
import torch._dynamo
from compile.FxGraphConvertor import fx2mlir
from torch._functorch.aot_autograd import aot_export_module
from tpu_mlir.python.tools.train.tpu_mlir_jit import aot_backend
class JitNet(nn.Module):
    def __init__(self, net, loss_fn):
        super().__init__()
        self.net = net
        self.loss_fn = loss_fn

    def forward(self, x, y):
        # Run the network once and compute the loss on its output.
        pred = self.net(x)
        loss, loss_item = self.loss_fn(pred, y)
        return loss, loss_item.detach()
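# Usage sketch (not part of this change; the names below come from train()'s scope
# and are only assumptions): JitNet bundles the forward pass and loss so a single
# joint graph can be captured by torch.compile with the TPU-MLIR backend.
#   jit_net = JitNet(model, compute_loss)
#   jit_net_opt = torch.compile(jit_net, backend=aot_backend)
#   loss, loss_items = jit_net_opt(imgs, targets.to(device))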
def _get_disc_decomp():
    from torch._decomp import get_decompositions

    aten = torch.ops.aten
    decompositions_dict = get_decompositions(
        [
            aten.gelu,
            aten.gelu_backward,
            aten.native_group_norm_backward,
            # aten.native_layer_norm,
            aten.native_layer_norm_backward,
            # aten.std_mean.correction,
            # aten._softmax,
            aten._softmax_backward_data,
            aten.tanh_backward,
            aten.slice_backward,
            aten.select_backward,
            aten.embedding_dense_backward,
            aten.sigmoid_backward,
            aten.nll_loss_backward,
            aten._log_softmax_backward_data,
            aten.nll_loss_forward,
        ]
    )
    return decompositions_dict
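# Note (context, not from this diff): passing these decompositions to aot_export_module
# rewrites the listed aten ops into simpler primitives before conversion, e.g. aten.gelu
# is expanded into elementwise mul/add/erf ops, so fx2mlir only needs to handle a smaller
# operator set.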
def convert_module_fx(
    submodule_name: str,
    module: torch.fx.GraphModule,
    args={},
    bwd_graph: bool = False,
    para_shape: list = [],
):
    c = fx2mlir(submodule_name, args, bwd_graph, para_shape)
    return c.convert(module)
class SophonJointCompile:
    def __init__(self, model, example_inputs, trace_joint=True, output_loss_index=0, args=None):
        self.args = args
        # Export a joint forward/backward FX graph of the whole model.
        self.fx_g, self.signature = aot_export_module(
            model, example_inputs, trace_joint=trace_joint, output_loss_index=output_loss_index, decompositions=_get_disc_decomp()
        )
        self.fx_g.to_folder("yolov5sc", "joint")

    def fx_convert_bmodel(self):
        name = f"test_{self.args.model}_joint_{self.args.batch}"
        convert_module_fx(name, self.fx_g, self.args, False)
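# Usage sketch (assumption; this class is not exercised by the visible hunks):
#   joint = SophonJointCompile(model, [imgs], trace_joint=True, args=tpu_mlir_jit.args)
#   joint.fx_convert_bmodel()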
def train(hyp, opt, device, callbacks):
"""
@@ -384,6 +452,8 @@ def train(hyp, opt, device, callbacks):
        if RANK in {-1, 0}:
            pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT)  # progress bar
        optimizer.zero_grad()
        # Compile the model once per epoch with the TPU-MLIR AOT backend.
        model_opt = torch.compile(model, backend=aot_backend)
        # zwyjit = JitNet(model, compute_loss)
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            callbacks.run("on_train_batch_start")
            ni = i + nb * epoch  # number integrated batches (since train start)
@@ -410,8 +480,13 @@
            # Forward
            with torch.cuda.amp.autocast(amp):
                # pred = model(imgs)  # uncompiled forward (superseded by the compiled model below)
                pred = model_opt(imgs)  # forward
                loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                # fx_g, signature = aot_export_module(
                #     model, [imgs], trace_joint=False, output_loss_index=0, decompositions=_get_disc_decomp()
                # )
                # print('fx_g:', fx_g)
                if RANK != -1:
                    loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                if opt.quad:
@@ -983,4 +1058,19 @@ def run(**kwargs):
if __name__ == "__main__":
    opt = parse_opt()
    # Extra TPU-MLIR options use a separate parser; parse_known_args() ignores
    # YOLOv5's own flags, which were already handled by parse_opt() above.
    parser = argparse.ArgumentParser()
    parser.add_argument("--chip", default="bm1690", choices=["bm1684x", "bm1690", "sg2260"], help="chip name")
    parser.add_argument("--debug", default="print_ori_fx_graph", help="debug")
    parser.add_argument("--cmp", action="store_true", help="enable cmp")
    parser.add_argument("--fast_test", action="store_true", help="fast_test")
    parser.add_argument("--skip_module_num", default=0, type=int, help="skip_module_num")
    parser.add_argument("--exit_at", default=-1, type=int, help="exit_at")
    parser.add_argument("--num_core", default=1, type=int, help="The number of TPU cores used for parallel computation")
    parser.add_argument("--opt", default=2, type=int, help="layer group opt")
    parser.add_argument("--fp", default="", help="fp")

    import tpu_mlir.python.tools.train.tpu_mlir_jit as tpu_mlir_jit

    tpu_mlir_jit.args = parser.parse_known_args()[0]
    main(opt)
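# Example invocation (hypothetical values; the standard YOLOv5 flags are consumed by
# parse_opt(), while the extra flags above configure the TPU-MLIR side):
#   python train.py --data coco128.yaml --weights yolov5s.pt --img 640 --chip bm1690 --num_core 8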