mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
yolov5_train_sample
This commit is contained in:
parent
597ff168f4
commit
4fb0c63f54
42
train.py
42
train.py
@ -99,6 +99,7 @@ RANK = int(os.getenv("RANK", -1))
|
|||||||
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
|
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
|
||||||
GIT_INFO = check_git_info()
|
GIT_INFO = check_git_info()
|
||||||
|
|
||||||
|
from tpu_mlir.python.tools.train.tpu_mlir_jit import device, aot_backend
|
||||||
from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
|
from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from compile.FxGraphConvertor import fx2mlir
|
from compile.FxGraphConvertor import fx2mlir
|
||||||
@ -454,7 +455,8 @@ def train(hyp, opt, device, callbacks):
|
|||||||
if RANK in {-1, 0}:
|
if RANK in {-1, 0}:
|
||||||
pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT) # progress bar
|
pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT) # progress bar
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
zwyjit = JitNet(model, compute_loss)
|
model_opt = torch.compile(model, backend=aot_backend)
|
||||||
|
# zwyjit = JitNet(model, compute_loss)
|
||||||
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
|
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
|
||||||
callbacks.run("on_train_batch_start")
|
callbacks.run("on_train_batch_start")
|
||||||
ni = i + nb * epoch # number integrated batches (since train start)
|
ni = i + nb * epoch # number integrated batches (since train start)
|
||||||
@ -481,15 +483,13 @@ def train(hyp, opt, device, callbacks):
|
|||||||
|
|
||||||
# Forward
|
# Forward
|
||||||
with torch.cuda.amp.autocast(amp):
|
with torch.cuda.amp.autocast(amp):
|
||||||
# print(1)
|
pred = model_opt(imgs) # forward
|
||||||
# zwy = SophonJointCompile(model, [imgs, targets], trace_joint=True, output_loss_index=0, args=None)
|
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
|
||||||
# pred = model(imgs) # forward
|
# fx_g, signature = aot_export_module(
|
||||||
# loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
|
# model, [imgs], trace_joint=False, output_loss_index=0, decompositions=_get_disc_decomp()
|
||||||
# loss, loss_items = zwyjit(imgs, targets.to(device))
|
# )
|
||||||
fx_g, signature = aot_export_module(
|
# print('fx_g:', fx_g)
|
||||||
zwyjit, [imgs, targets], trace_joint=True, output_loss_index=0, decompositions=_get_disc_decomp()
|
|
||||||
)
|
|
||||||
print(fx_g)
|
|
||||||
if RANK != -1:
|
if RANK != -1:
|
||||||
loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
|
loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
|
||||||
if opt.quad:
|
if opt.quad:
|
||||||
@ -1061,4 +1061,26 @@ def run(**kwargs):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
opt = parse_opt()
|
opt = parse_opt()
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--chip", default="bm1690", choices=['bm1684x', 'bm1690','sg2260'],
|
||||||
|
help="chip name")
|
||||||
|
parser.add_argument("--debug", default="print_ori_fx_graph",
|
||||||
|
help="debug")
|
||||||
|
parser.add_argument("--cmp", action='store_true',
|
||||||
|
help="enable cmp")
|
||||||
|
parser.add_argument("--fast_test", action='store_true',
|
||||||
|
help="fast_test")
|
||||||
|
parser.add_argument("--skip_module_num", default=0, type=int,
|
||||||
|
help='skip_module_num')
|
||||||
|
parser.add_argument("--exit_at", default=-1, type=int,
|
||||||
|
help='exit_at')
|
||||||
|
parser.add_argument("--num_core", default=1, type=int,
|
||||||
|
help='The numer of TPU cores used for parallel computation')
|
||||||
|
parser.add_argument("--opt", default=2, type=int,
|
||||||
|
help='layer group opt')
|
||||||
|
parser.add_argument("--fp", default="",help="fp")
|
||||||
|
import tpu_mlir.python.tools.train.tpu_mlir_jit as tpu_mlir_jit
|
||||||
|
tpu_mlir_jit.args = parser.parse_known_args()[0]
|
||||||
|
|
||||||
main(opt)
|
main(opt)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user