Auto-format by https://ultralytics.com
parent 3b4d7777ce
commit 3ae26e695a
@@ -333,7 +333,7 @@ class SPPF(nn.Module):
    def forward(self, x):
        """Processes input through a series of convolutions and max pooling operations for feature extraction."""
        x = self.cv1(x)
        # wangxuec: We need to comment this out, otherwise we'll end up with a very fragmented portion of the captured graph
        # with warnings.catch_warnings():
        #     warnings.simplefilter("ignore")  # suppress torch 1.9.0 max_pool2d() warning
        y1 = self.m(x)
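For context on the wangxuec comment above: TorchDynamo generally cannot trace through a plain Python context manager such as warnings.catch_warnings(), so leaving it active splits the captured graph around the pooling calls, which is the fragmentation the comment refers to. A minimal sketch, not part of this diff and with made-up shapes, of how the effect can be checked; fullgraph=True makes torch.compile raise if forward cannot be captured as a single graph:

import torch
import torch.nn as nn


class Pool(nn.Module):
    """Stand-in for SPPF: a single max-pool is enough to show the graph-break effect."""

    def __init__(self):
        super().__init__()
        self.m = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)

    def forward(self, x):
        # Re-adding the context manager below re-introduces the graph break the SPPF change avoids:
        # with warnings.catch_warnings():
        #     warnings.simplefilter("ignore")
        #     return self.m(x)
        return self.m(x)


# Raises on any graph break, so it is a quick way to confirm the edit keeps the graph whole.
compiled = torch.compile(Pool(), fullgraph=True)
_ = compiled(torch.randn(1, 3, 32, 32))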
train.py
@@ -13,7 +13,7 @@ Models: https://github.com/ultralytics/yolov5/tree/master/models
Datasets: https://github.com/ultralytics/yolov5/tree/master/data
Tutorial: https://docs.ultralytics.com/yolov5/tutorials/train_custom_data
"""
-import torch_tpu

import argparse
import math
import os
@@ -25,6 +25,8 @@ from copy import deepcopy
from datetime import datetime, timedelta
from pathlib import Path

+import torch_tpu
+
try:
    import comet_ml  # must be imported before torch (if installed)
except ImportError:
@@ -100,10 +102,10 @@ WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
GIT_INFO = check_git_info()

# from tpu_mlir import aot_backend
-from tpu_mlir.python.tools.train.tpu_mlir_jit import aot_backend
-from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
+import torch._dynamo.config
+from torch._functorch.aot_autograd import aot_export_module
+from tpu_mlir.python.tools.train.tpu_mlir_jit import aot_backend


class JitNet(nn.Module):
    def __init__(self, net, loss_fn):
@@ -113,12 +115,13 @@ class JitNet(nn.Module):

    def forward(self, x, y):
        predict = self.net(x)
-        loss,loss_item = self.loss_fn(predict, y)
+        loss, loss_item = self.loss_fn(predict, y)
        return loss, loss_item.detach()
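For context: JitNet bundles the network and the loss function into one nn.Module so that a single captured graph covers both the forward pass and the loss, which is what the joint compile path and aot_export_module(..., output_loss_index=0) further down rely on. A hedged, self-contained usage sketch; the tiny net and the MSE stand-in for YOLOv5's ComputeLoss (which returns (loss, loss_items)) are made up, and in the diff the default backend is replaced by aot_backend:

import torch
import torch.nn as nn


class TinyJitNet(nn.Module):
    """Same shape as JitNet above: bundle net + loss so one compiled graph covers both."""

    def __init__(self, net, loss_fn):
        super().__init__()
        self.net = net
        self.loss_fn = loss_fn

    def forward(self, x, y):
        predict = self.net(x)
        loss, loss_item = self.loss_fn(predict, y)
        return loss, loss_item.detach()


def mse_with_items(pred, target):
    # Stand-in for YOLOv5's ComputeLoss, which returns (loss, loss_items).
    loss = torch.nn.functional.mse_loss(pred, target)
    return loss, loss.detach().unsqueeze(0)


net = nn.Linear(8, 4)
wrapped = torch.compile(TinyJitNet(net, mse_with_items), fullgraph=False)
loss, loss_items = wrapped(torch.randn(2, 8), torch.randn(2, 4))
loss.backward()  # gradients flow back to net parameters through the compiled graph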


def _get_disc_decomp():
    from torch._decomp import get_decompositions

    aten = torch.ops.aten
    decompositions_dict = get_decompositions(
        [
@@ -142,33 +145,37 @@ def _get_disc_decomp():
    )
    return decompositions_dict
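_get_disc_decomp() builds the decomposition table handed to AOTAutograd so that composite ATen ops are lowered into simpler primitives the TPU backend can match; the concrete op list of this repo sits between the two hunks above and is not shown. A hedged sketch of the mechanism with placeholder ops (the real list may differ):

import torch
from torch._decomp import get_decompositions

aten = torch.ops.aten

# Each listed op is replaced by its registered decomposition when the graph is traced,
# e.g. gelu/silu become elementwise primitives, layer_norm becomes mean/var arithmetic.
decompositions = get_decompositions(
    [
        aten.gelu,
        aten.silu,
        aten.native_layer_norm,
    ]
)
print(list(decompositions.keys())[:3])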


tensor_idx = 0
features_out_hook = {}


def hook(module, fea_in, fea_out):
    global features_out_hook, tensor_idx

    if isinstance(fea_out, torch.Tensor):
-        features_out_hook[f'f_{tensor_idx}'] = fea_out.detach().numpy()
+        features_out_hook[f"f_{tensor_idx}"] = fea_out.detach().numpy()
        tensor_idx += 1
    return None
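hook only defines the callback; to produce the reference dump it still has to be attached to every module and detached afterwards, which is what the hook_handles list and the hd.remove() loop further down in train() do. A hedged sketch of that wiring, reusing hook and features_out_hook from above with a toy model in place of the YOLOv5 network (the exact registration site is not shown in this diff):

import numpy as np
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.SiLU(), nn.Conv2d(8, 8, 3, padding=1))

hook_handles = []
for m in model.modules():
    # register_forward_hook calls hook(module, inputs, output) after each forward
    hook_handles.append(m.register_forward_hook(hook))

with torch.no_grad():
    model(torch.randn(1, 3, 32, 32))  # one pass fills features_out_hook via hook

np.savez("layer_outputs.npz", **features_out_hook)  # same kind of dump the training loop writes

for hd in hook_handles:
    hd.remove()  # detach so normal training is not slowed down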


def convert_module_fx(
    submodule_name: str,
    module: torch.fx.GraphModule,
    args={},
-    bwd_graph:bool=False,
-    para_shape: list=[],
-) :
+    bwd_graph: bool = False,
+    para_shape: list = [],
+):
    c = fx2mlir(submodule_name, args, bwd_graph, para_shape)
    return c.convert(module)


class SophonJointCompile:
    def __init__(self, model, example_inputs, trace_joint=True, output_loss_index=0, args=None):
        fx_g, signature = aot_export_module(
            model, example_inputs, trace_joint=trace_joint, output_loss_index=0, decompositions=_get_disc_decomp()
        )
-        fx_g.to_folder("yolov5sc","joint")
+        fx_g.to_folder("yolov5sc", "joint")
        breakpoint()
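SophonJointCompile uses aot_export_module with trace_joint=True to obtain one FX graph that contains both forward and backward, selecting the loss to differentiate via output_loss_index=0, and then dumps it with to_folder for inspection. A minimal, self-contained sketch of that call with a toy network in place of the YOLOv5 model; names like TinyNet and the folder name are made up:

import torch
import torch.nn as nn
from torch._functorch.aot_autograd import aot_export_module


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def forward(self, x, y):
        pred = self.fc(x)
        loss = torch.nn.functional.mse_loss(pred, y)
        return loss, pred.detach()  # non-loss outputs must not require grad


x, y = torch.randn(2, 8), torch.randn(2, 4)
# trace_joint=True exports a single joint forward+backward graph;
# output_loss_index=0 tells AOTAutograd which output is the loss to differentiate.
fx_g, signature = aot_export_module(TinyNet(), (x, y), trace_joint=True, output_loss_index=0)
print(fx_g.graph)                      # params/buffers are lifted to explicit graph inputs
fx_g.to_folder("tiny_joint", "joint")  # same kind of dump as the to_folder call above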

    def fx_convert_bmodel(self):
@@ -176,7 +183,6 @@ class SophonJointCompile:
        convert_module_fx(name, self.fx_g, self.args, False)


def train(hyp, opt, device, callbacks):
    """
    Train a YOLOv5 model on a custom dataset using specified hyperparameters, options, and device, managing datasets,
@@ -429,9 +435,9 @@ def train(hyp, opt, device, callbacks):
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
-    if opt.device == 'tpu':
+    if opt.device == "tpu":
        scaler = torch_tpu.tpu.amp.GradScaler(enabled=True, allow_fp16=True)
-    elif opt.device == 'cpu':
+    elif opt.device == "cpu":
        scaler = None
    else:
        scaler = torch.cuda.amp.GradScaler(enabled=amp)
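The switch above picks a gradient scaler per backend: the vendor torch_tpu.tpu.amp.GradScaler for "tpu", no scaler at all on CPU, and the stock CUDA scaler otherwise. A hedged, self-contained sketch of how a scaler chosen this way is consumed afterwards (the real loop appears later in this diff; the tiny model and optimizer are made up, and the "tpu" branch would use the torch_tpu scaler instead):

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Mirrors the selection above: None on CPU, a GradScaler elsewhere.
scaler = None if device == "cpu" else torch.cuda.amp.GradScaler(enabled=True)

x, y = torch.randn(4, 8, device=device), torch.randn(4, 1, device=device)
loss = nn.functional.mse_loss(model(x), y)

if scaler is None:   # CPU path: plain FP32 backward, as in the backward hunk below
    loss.backward()
    optimizer.step()
else:                # scaled path: scale the loss, let the scaler handle step/update
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
optimizer.zero_grad()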
@@ -444,7 +450,7 @@ def train(hyp, opt, device, callbacks):
        f"Logging results to {colorstr('bold', save_dir)}\n"
        f"Starting training for {epochs} epochs..."
    )

    hook_handles = []
    # dump_cuda_ref = True
    dump_cuda_ref = False
@@ -480,9 +486,9 @@ def train(hyp, opt, device, callbacks):
        if compiled:
            if joint:
                zwyjit = JitNet(model, compute_loss)
-                model_opt = torch.compile(zwyjit, backend=aot_backend, dynamic = None, fullgraph = False)
+                model_opt = torch.compile(zwyjit, backend=aot_backend, dynamic=None, fullgraph=False)
            else:
-                model_opt = torch.compile(model, backend=aot_backend, dynamic = None, fullgraph = False)
+                model_opt = torch.compile(model, backend=aot_backend, dynamic=None, fullgraph=False)
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            callbacks.run("on_train_batch_start")
            ni = i + nb * epoch  # number integrated batches (since train start)
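In the hunk above, aot_backend from tpu_mlir is plugged in as a custom torch.compile backend, so every FX graph Dynamo captures is handed to the TPU compiler instead of Inductor. A hedged, self-contained sketch of the backend protocol itself, using a stand-in callback rather than the real tpu_mlir backend:

import torch
import torch.nn as nn


def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # A torch.compile backend is just a callable that receives the captured FX graph
    # plus example inputs and returns something callable to run it. aot_backend does
    # real work here (lowering toward MLIR/bmodel); this stand-in only prints the
    # graph and falls back to eager execution.
    gm.graph.print_tabular()
    return gm.forward


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
model_opt = torch.compile(model, backend=inspect_backend, dynamic=None, fullgraph=False)
_ = model_opt(torch.randn(2, 8))  # first call triggers capture and backend invocation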
@@ -523,10 +529,10 @@ def train(hyp, opt, device, callbacks):
            if dump_cuda_ref:
                pred = model(imgs)
                global features_out_hook
-                features_out_hook['data'] = imgs.detach().numpy()
+                features_out_hook["data"] = imgs.detach().numpy()
                for name, param in model.named_parameters():
                    features_out_hook[name] = param.detach().numpy()
-                np.savez('layer_outputs.npz', **features_out_hook)
+                np.savez("layer_outputs.npz", **features_out_hook)
                for hd in hook_handles:
                    hd.remove()
                exit(0)
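The dump_cuda_ref branch saves the input batch, every per-layer activation captured by hook, and all parameters into layer_outputs.npz, then exits; presumably this serves as a CPU/CUDA reference to check TPU results against. A hedged sketch of such a comparison (the comparison script is not part of this diff, and tpu_outputs.npz is a hypothetical file name):

import numpy as np

ref = np.load("layer_outputs.npz")   # reference dump written by the branch above
test = np.load("tpu_outputs.npz")    # hypothetical dump from the TPU run

for key in ref.files:
    if key not in test.files:
        print(f"missing on TPU side: {key}")
        continue
    ok = np.allclose(ref[key], test[key], rtol=1e-3, atol=1e-4)
    print(f"{key:40s} {'OK' if ok else 'MISMATCH'}")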
@@ -547,12 +553,12 @@ def train(hyp, opt, device, callbacks):
                    loss *= 4.0

            # Backward
-            if opt.device == 'tpu':
+            if opt.device == "tpu":
                loss = loss.to(device)
-                #print('old loss:', loss, loss.device)
-                #total_loss = 0.2666
-                #loss.data.copy_(total_loss)
-                print('loss:', loss, loss.device)
+                # print('old loss:', loss, loss.device)
+                # total_loss = 0.2666
+                # loss.data.copy_(total_loss)
+                print("loss:", loss, loss.device)
            if scaler is None:
                loss.backward()
            else:
@@ -1119,6 +1125,7 @@ def run(**kwargs):
    main(opt)
    return opt


# import torch._dynamo
# import logging
# logger = logging.getLogger("torch._dynamo")
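The commented-out imports above hint at how TorchDynamo debugging would be switched on when the TPU backend misbehaves. A hedged sketch completing that snippet (left disabled in the diff itself):

import logging

import torch._dynamo

logger = logging.getLogger("torch._dynamo")
logger.setLevel(logging.DEBUG)                   # verbose tracing / graph-break logs
# torch._dynamo.config.suppress_errors = True    # optional: fall back to eager on backend errors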
@@ -1131,8 +1138,9 @@ def run(**kwargs):

if __name__ == "__main__":
    opt = parse_opt()
-    #print_ori_fx_graph/dump_fx_graph/skip_tpu_compile/dump_bmodel_input
+    # print_ori_fx_graph/dump_fx_graph/skip_tpu_compile/dump_bmodel_input
    import tpu_mlir

    tpu_mlir.python.tools.train.config.debug_cmd = opt.debug_cmd
    tpu_mlir.python.tools.train.config.compile_opt = 2
    # tpu_mlir.python.tools.train.config.only_compile_graph_id = 1
@@ -1140,5 +1148,5 @@ if __name__ == "__main__":
    tpu_mlir.python.tools.train.config.run_on_cmodel = True
    tpu_mlir.python.tools.train.config.print_config_info()
    # torch._dynamo.config.suppress_errors = True

    main(opt)
@@ -463,7 +463,7 @@ class ModelEMA:
        """
        # self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        self.ema = de_parallel(model)  # FP32 EMA
        self.ema.load_state_dict(model.state_dict())
        self.ema = self.ema.eval()
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
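Note the changed construction: instead of a deepcopy, the EMA now wraps the de-parallelized model object directly and reloads its own state_dict, presumably to avoid deep-copying a TPU-resident or compiled model. For context, a hedged, self-contained sketch of roughly the update rule that ModelEMA.update applies with self.decay in upstream YOLOv5 (that method is unchanged by this diff; the tiny modules below are made up):

import math

import torch
import torch.nn as nn

decay, tau = 0.9999, 2000
ema_decay = lambda x: decay * (1 - math.exp(-x / tau))  # same ramp as above

model = nn.Linear(4, 2)
ema = nn.Linear(4, 2)
ema.load_state_dict(model.state_dict())

# One EMA step, applied after each optimizer step: blend current weights into the EMA copy.
updates = 1
d = ema_decay(updates)
msd = model.state_dict()
with torch.no_grad():
    for k, v in ema.state_dict().items():
        if v.dtype.is_floating_point:
            v *= d
            v += (1 - d) * msd[k].detach()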