# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import time

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.cuda import amp

from fastreid.engine import DefaultTrainer

from .data_build import build_attr_train_loader, build_attr_test_loader
from .attr_evaluation import AttrEvaluator


class AttrTrainer(DefaultTrainer):
    def __init__(self, cfg):
        super().__init__(cfg)

        # Sample weights for attribute-imbalanced classification; taken from
        # the training dataset and passed to the model's losses in `run_step`.
        bce_weight_enabled = self.cfg.MODEL.LOSSES.BCE.WEIGHT_ENABLED
        # fmt: off
        if bce_weight_enabled: self.sample_weights = self.data_loader.dataset.sample_weights.to("cuda")
        else:                  self.sample_weights = None
        # fmt: on
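    # The flag read above implies a config entry along these lines (a sketch of
    # the assumed YAML layout; only the key path is taken from the code, the
    # value shown is illustrative):
    #
    #   MODEL:
    #     LOSSES:
    #       BCE:
    #         WEIGHT_ENABLED: True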
    @classmethod
    def build_train_loader(cls, cfg):
        return build_attr_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        return build_attr_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        data_loader = cls.build_test_loader(cfg, dataset_name)
        return data_loader, AttrEvaluator(cfg, output_folder)
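    # Sketch of how the hooks above are typically consumed by the default test
    # flow (exact call sites live in fastreid.engine / fastreid.evaluation and
    # may differ in detail; shown here for orientation only):
    #
    #   data_loader, evaluator = AttrTrainer.build_evaluator(cfg, dataset_name)
    #   results = inference_on_dataset(model, data_loader, evaluator)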
    def run_step(self):
        r"""
        Implement the attribute model training logic.
        """
        assert self.model.training, "[AttrTrainer] model was changed to eval mode!"
        start = time.perf_counter()
        """
        If you want to do something with the data, you can wrap the dataloader.
        """
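        # A minimal sketch of such a wrapper (hypothetical, not part of this
        # project): apply a per-batch transform before the data reaches the
        # model by mapping over the iterator, e.g.
        #
        #   self._data_loader_iter = map(my_batch_transform, self._data_loader_iter)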
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        """
        If you want to do something with the heads, you can wrap the model.
        """
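        # A minimal sketch of such a wrapper (hypothetical): an nn.Module that
        # delegates to the original model and post-processes its outputs, e.g.
        #
        #   class HeadsWrapper(torch.nn.Module):
        #       def __init__(self, model):
        #           super().__init__()
        #           self.model = model
        #       def forward(self, data):
        #           return postprocess(self.model(data))  # hypothetical hook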
        with amp.autocast(enabled=self.amp_enabled):
            outs = self.model(data)

            # Compute loss
            if isinstance(self.model, DistributedDataParallel):
                loss_dict = self.model.module.losses(outs, self.sample_weights)
            else:
                loss_dict = self.model.losses(outs, self.sample_weights)

            losses = sum(loss_dict.values())
        with torch.cuda.stream(torch.cuda.Stream()):
            metrics_dict = loss_dict
            metrics_dict["data_time"] = data_time
            self._write_metrics(metrics_dict)
            self._detect_anomaly(losses, loss_dict)
"""
|
|
If you need accumulate gradients or something similar, you can
|
|
wrap the optimizer with your custom `zero_grad()` method.
|
|
"""
|
|
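        # A minimal sketch of such a wrapper (hypothetical): only clear and
        # apply gradients every `accum_steps` iterations so several backward
        # passes accumulate before one update, e.g.
        #
        #   class AccumOptimizer:
        #       def __init__(self, optimizer, accum_steps):
        #           self.optimizer, self.accum_steps, self.iter = optimizer, accum_steps, 0
        #       def zero_grad(self):
        #           if self.iter % self.accum_steps == 0:
        #               self.optimizer.zero_grad()
        #       def step(self):
        #           self.iter += 1
        #           if self.iter % self.accum_steps == 0:
        #               self.optimizer.step()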
        self.optimizer.zero_grad()

        if self.amp_enabled:
            self.scaler.scale(losses).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            losses.backward()
            """
            If you need gradient clipping/scaling or other processing, you can
            wrap the optimizer with your custom `step()` method.
            """
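            # A minimal sketch of such processing (hypothetical): clip the
            # gradient norm before the parameter update, e.g.
            #
            #   torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)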
            self.optimizer.step()
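
# Example usage (a sketch, not part of this module): AttrTrainer plugs into the
# standard fastreid launch flow the same way DefaultTrainer does. The exact
# script layout follows the project's train_net.py convention and may differ:
#
#   cfg = get_cfg()                      # plus any attribute-specific config additions
#   cfg.merge_from_file(args.config_file)
#   cfg.freeze()
#   trainer = AttrTrainer(cfg)
#   trainer.resume_or_load(resume=args.resume)
#   trainer.train()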