Auto-format by https://ultralytics.com

parent 4fb0c63f54
commit 42c4ae3350

train.py (56 changed lines)
@@ -99,15 +99,11 @@ RANK = int(os.getenv("RANK", -1))
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
 GIT_INFO = check_git_info()
 
-from tpu_mlir.python.tools.train.tpu_mlir_jit import device, aot_backend
-from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
-import torch.optim as optim
-from compile.FxGraphConvertor import fx2mlir
-import torchvision.models as models
-import argparse
-import numpy as np
-from torch.fx import Interpreter
 import torch._dynamo
+from compile.FxGraphConvertor import fx2mlir
+from torch._functorch.aot_autograd import aot_export_module
+from tpu_mlir.python.tools.train.tpu_mlir_jit import aot_backend
 
 
 class JitNet(nn.Module):
     def __init__(self, net, loss_fn):
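The surviving imports (`import torch._dynamo` plus `aot_backend` from `tpu_mlir_jit`) indicate the training graph is captured by Dynamo and lowered through the TPU-MLIR backend. A minimal sketch of how a callable Dynamo backend is typically attached, assuming `aot_backend` follows the standard backend convention; the model and input are hypothetical stand-ins:

    import torch
    from tpu_mlir.python.tools.train.tpu_mlir_jit import aot_backend  # imported above

    net = torch.nn.Linear(8, 4)   # hypothetical stand-in model
    x = torch.randn(2, 8)         # hypothetical example input

    # torch.compile hands each captured FX graph to the given backend,
    # which returns a compiled callable to run in place of the subgraph
    compiled = torch.compile(net, backend=aot_backend)
    out = compiled(x)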
@@ -116,12 +112,14 @@ class JitNet(nn.Module):
         self.loss_fn = loss_fn
 
     def forward(self, x, y):
-        predict = self.net(x)
-        loss,loss_item = self.loss_fn(self.net(x), y)
+        self.net(x)
+        loss, loss_item = self.loss_fn(self.net(x), y)
         return loss, loss_item.detach()
 
 
 def _get_disc_decomp():
     from torch._decomp import get_decompositions
 
     aten = torch.ops.aten
     decompositions_dict = get_decompositions(
         [
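Even after this change, `forward` still invokes `self.net(x)` twice: once on the now-discarded first line and once inside the loss call. The wrapper's purpose is to give the tracer a single module whose outputs are the scalar loss plus a detached copy of the per-component losses. A hedged usage sketch, with a toy loss standing in for YOLOv5's `ComputeLoss`:

    import torch
    import torch.nn as nn

    def toy_loss(pred, target):
        # hypothetical loss returning (scalar loss, per-component items),
        # mirroring the (loss, loss_items) pair that ComputeLoss returns
        diff = (pred - target) ** 2
        return diff.mean(), diff.mean(dim=0)

    net = nn.Linear(8, 4)                        # hypothetical model
    jit_net = JitNet(net, toy_loss)              # JitNet as defined in this file
    x, y = torch.randn(2, 8), torch.randn(2, 4)
    loss, loss_items = jit_net(x, y)             # loss_items arrives detached
    loss.backward()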
@@ -150,19 +148,19 @@ def convert_module_fx(
     submodule_name: str,
     module: torch.fx.GraphModule,
     args={},
-    bwd_graph:bool=False,
-    para_shape: list=[],
-) :
+    bwd_graph: bool = False,
+    para_shape: list = [],
+):
     c = fx2mlir(submodule_name, args, bwd_graph, para_shape)
     return c.convert(module)
 
-class SophonJointCompile:
+
+class SophonJointCompile:
     def __init__(self, model, example_inputs, trace_joint=True, output_loss_index=0, args=None):
         fx_g, signature = aot_export_module(
             model, example_inputs, trace_joint=trace_joint, output_loss_index=0, decompositions=_get_disc_decomp()
         )
-        fx_g.to_folder("yolov5sc","joint")
+        fx_g.to_folder("yolov5sc", "joint")
         breakpoint()
 
     def fx_convert_bmodel(self):
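`aot_export_module` with `trace_joint=True` exports one joint forward-plus-backward FX graph, and `output_loss_index=0` marks which output is the scalar loss to differentiate; `fx_g.to_folder(...)` then dumps that graph as editable Python source (the committed `breakpoint()` halts execution right after, presumably for inspection). A minimal sketch of the export call under those assumptions; the model and input are hypothetical, and the keyword signature matches `torch._functorch.aot_autograd` in PyTorch 2.x:

    import torch
    from torch._functorch.aot_autograd import aot_export_module

    class WithLoss(torch.nn.Module):
        # hypothetical module whose first (and only) output is a scalar loss,
        # as joint tracing requires
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Linear(8, 1)

        def forward(self, x):
            return self.net(x).sum()

    x = torch.randn(2, 8)
    fx_g, signature = aot_export_module(WithLoss(), [x], trace_joint=True, output_loss_index=0)
    print(fx_g.graph)  # joint graph: forward ops followed by their gradients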
@@ -170,7 +168,6 @@ class SophonJointCompile:
         convert_module_fx(name, self.fx_g, self.args, False)
 
 
-
 def train(hyp, opt, device, callbacks):
     """
     Train a YOLOv5 model on a custom dataset using specified hyperparameters, options, and device, managing datasets,
@@ -1063,24 +1060,17 @@ if __name__ == "__main__":
     opt = parse_opt()
 
     parser = argparse.ArgumentParser()
-    parser.add_argument("--chip", default="bm1690", choices=['bm1684x', 'bm1690','sg2260'],
-                        help="chip name")
-    parser.add_argument("--debug", default="print_ori_fx_graph",
-                        help="debug")
-    parser.add_argument("--cmp", action='store_true',
-                        help="enable cmp")
-    parser.add_argument("--fast_test", action='store_true',
-                        help="fast_test")
-    parser.add_argument("--skip_module_num", default=0, type=int,
-                        help='skip_module_num')
-    parser.add_argument("--exit_at", default=-1, type=int,
-                        help='exit_at')
-    parser.add_argument("--num_core", default=1, type=int,
-                        help='The numer of TPU cores used for parallel computation')
-    parser.add_argument("--opt", default=2, type=int,
-                        help='layer group opt')
-    parser.add_argument("--fp", default="",help="fp")
+    parser.add_argument("--chip", default="bm1690", choices=["bm1684x", "bm1690", "sg2260"], help="chip name")
+    parser.add_argument("--debug", default="print_ori_fx_graph", help="debug")
+    parser.add_argument("--cmp", action="store_true", help="enable cmp")
+    parser.add_argument("--fast_test", action="store_true", help="fast_test")
+    parser.add_argument("--skip_module_num", default=0, type=int, help="skip_module_num")
+    parser.add_argument("--exit_at", default=-1, type=int, help="exit_at")
+    parser.add_argument("--num_core", default=1, type=int, help="The number of TPU cores used for parallel computation")
+    parser.add_argument("--opt", default=2, type=int, help="layer group opt")
+    parser.add_argument("--fp", default="", help="fp")
     import tpu_mlir.python.tools.train.tpu_mlir_jit as tpu_mlir_jit
 
     tpu_mlir_jit.args = parser.parse_known_args()[0]
 
     main(opt)
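These flags are parsed with `parse_known_args`, so they can sit on the normal train.py command line next to YOLOv5's own options without conflicting. A small illustration of that behavior; the argument list is illustrative:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--chip", default="bm1690", choices=["bm1684x", "bm1690", "sg2260"], help="chip name")
    parser.add_argument("--num_core", default=1, type=int, help="number of TPU cores")

    # parse_known_args keeps unrecognized flags (e.g. YOLOv5's --data) instead of erroring
    args, unknown = parser.parse_known_args(["--chip", "bm1684x", "--num_core", "8", "--data", "coco.yaml"])
    print(args.chip, args.num_core)  # bm1684x 8
    print(unknown)                   # ['--data', 'coco.yaml']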