Merge 42c4ae3350 into fe1d4d9947
commit 1a5504c41b

train.py | 92
@@ -99,6 +99,74 @@ RANK = int(os.getenv("RANK", -1))
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
 GIT_INFO = check_git_info()
 
+import torch._dynamo
+from compile.FxGraphConvertor import fx2mlir
+from torch._functorch.aot_autograd import aot_export_module
+from tpu_mlir.python.tools.train.tpu_mlir_jit import aot_backend
+
+
+class JitNet(nn.Module):
+    def __init__(self, net, loss_fn):
+        super().__init__()
+        self.net = net
+        self.loss_fn = loss_fn
+
+    def forward(self, x, y):
+        pred = self.net(x)  # run the network once and reuse the prediction for the loss
+        loss, loss_item = self.loss_fn(pred, y)
+        return loss, loss_item.detach()
+
+
+def _get_disc_decomp():
+    from torch._decomp import get_decompositions
+
+    aten = torch.ops.aten
+    decompositions_dict = get_decompositions(
+        [
+            aten.gelu,
+            aten.gelu_backward,
+            aten.native_group_norm_backward,
+            # aten.native_layer_norm,
+            aten.native_layer_norm_backward,
+            # aten.std_mean.correction,
+            # aten._softmax,
+            aten._softmax_backward_data,
+            aten.tanh_backward,
+            aten.slice_backward,
+            aten.select_backward,
+            aten.embedding_dense_backward,
+            aten.sigmoid_backward,
+            aten.nll_loss_backward,
+            aten._log_softmax_backward_data,
+            aten.nll_loss_forward,
+        ]
+    )
+    return decompositions_dict
+
+
+def convert_module_fx(
+    submodule_name: str,
+    module: torch.fx.GraphModule,
+    args={},
+    bwd_graph: bool = False,
+    para_shape: list = [],
+):
+    c = fx2mlir(submodule_name, args, bwd_graph, para_shape)
+    return c.convert(module)
+
+
+class SophonJointCompile:
+    def __init__(self, model, example_inputs, trace_joint=True, output_loss_index=0, args=None):
+        self.args = args
+        self.fx_g, self.signature = aot_export_module(
+            model, example_inputs, trace_joint=trace_joint, output_loss_index=output_loss_index, decompositions=_get_disc_decomp()
+        )
+        self.fx_g.to_folder("yolov5sc", "joint")  # dump the joint graph as editable Python source
+
+    def fx_convert_bmodel(self):
+        name = f"test_{self.args.model}_joint_{self.args.batch}"
+        convert_module_fx(name, self.fx_g, self.args, False)
+
 
 def train(hyp, opt, device, callbacks):
     """
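Note on the hunk above: _get_disc_decomp tells the tracer to rewrite composite ATen ops into primitives that the TPU lowering has to support, and SophonJointCompile uses aot_export_module to capture forward and backward as one joint FX graph for fx2mlir. A minimal, self-contained sketch of that capture step (TinyNet, its shapes, and the op choices are illustrative stand-ins for the YOLO model; torch._functorch is a private torch 2.x API and may change between releases):

import torch
import torch.nn as nn
from torch._decomp import get_decompositions
from torch._functorch.aot_autograd import aot_export_module


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def forward(self, x, y):
        pred = self.fc(x)
        loss = nn.functional.mse_loss(pred, y)
        return loss, pred.detach()  # loss first, so output_loss_index=0; other outputs must not require grad


decomps = get_decompositions([torch.ops.aten.gelu, torch.ops.aten.tanh_backward])
fx_g, signature = aot_export_module(
    TinyNet(),
    (torch.randn(2, 8), torch.randn(2, 4)),
    trace_joint=True,     # capture forward + backward in a single graph
    output_loss_index=0,  # which output to differentiate
    decompositions=decomps,
)
print(fx_g.graph)  # the kind of joint graph the patch hands to fx2mlir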
@@ -384,6 +452,8 @@ def train(hyp, opt, device, callbacks):
         if RANK in {-1, 0}:
             pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT)  # progress bar
         optimizer.zero_grad()
+        model_opt = torch.compile(model, backend=aot_backend)
+        # zwyjit = JitNet(model, compute_loss)
         for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
             callbacks.run("on_train_batch_start")
             ni = i + nb * epoch  # number integrated batches (since train start)
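torch.compile(model, backend=aot_backend) is where the TPU flow hooks into Dynamo: a backend is just a callable that receives the captured torch.fx.GraphModule plus example inputs and returns a callable. A sketch of that contract with a hypothetical my_backend that only inspects the graph and falls back to eager execution (the real aot_backend lowers the graph through tpu_mlir instead):

import torch


def my_backend(gm: torch.fx.GraphModule, example_inputs):
    print(gm.graph)    # inspect the captured ops
    return gm.forward  # run the unmodified graph eagerly


model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
compiled = torch.compile(model, backend=my_backend)
compiled(torch.randn(1, 4))  # first call triggers capture + backend compilation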
@@ -410,8 +480,13 @@ def train(hyp, opt, device, callbacks):
             # Forward
             with torch.cuda.amp.autocast(amp):
-                pred = model(imgs)  # forward
+                pred = model_opt(imgs)  # forward
                 loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
+                # fx_g, signature = aot_export_module(
+                #     model, [imgs], trace_joint=False, output_loss_index=0, decompositions=_get_disc_decomp()
+                # )
+                # print('fx_g:', fx_g)
 
                 if RANK != -1:
                     loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                 if opt.quad:
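For reference, the autocast context in this hunk is standard torch mixed precision: fp16-safe ops run in half precision, and YOLOv5's GradScaler rescales the loss so small gradients survive the reduced fp16 range. A standalone sketch assuming CUDA is available (the toy model and shapes are illustrative; the real loop uses compute_loss and its existing scaler):

import torch

model = torch.nn.Linear(8, 4).cuda()
scaler = torch.cuda.amp.GradScaler(enabled=True)
x = torch.randn(2, 8, device="cuda")
y = torch.randn(2, 4, device="cuda")

with torch.cuda.amp.autocast(True):  # fp16 forward where safe
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow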
@@ -983,4 +1058,19 @@ def run(**kwargs):
 
 if __name__ == "__main__":
     opt = parse_opt()
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--chip", default="bm1690", choices=["bm1684x", "bm1690", "sg2260"], help="chip name")
+    parser.add_argument("--debug", default="print_ori_fx_graph", help="debug")
+    parser.add_argument("--cmp", action="store_true", help="enable cmp")
+    parser.add_argument("--fast_test", action="store_true", help="fast_test")
+    parser.add_argument("--skip_module_num", default=0, type=int, help="skip_module_num")
+    parser.add_argument("--exit_at", default=-1, type=int, help="exit_at")
+    parser.add_argument("--num_core", default=1, type=int, help="The number of TPU cores used for parallel computation")
+    parser.add_argument("--opt", default=2, type=int, help="layer group opt")
+    parser.add_argument("--fp", default="", help="fp")
+    import tpu_mlir.python.tools.train.tpu_mlir_jit as tpu_mlir_jit
+
+    tpu_mlir_jit.args = parser.parse_known_args()[0]
+
     main(opt)
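The new flags use parse_known_args() rather than parse_args() for a reason: the TPU options share argv with YOLOv5's own parse_opt(), so each parser must pass through the other's flags unrecognized instead of erroring out. A small illustration (the argv values are made up):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--chip", default="bm1690")
args, unknown = parser.parse_known_args(["--chip", "bm1684x", "--epochs", "3"])
print(args.chip)  # bm1684x
print(unknown)    # ['--epochs', '3'] -- left for the training argument parser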