commit f951d36029

@@ -68,6 +68,7 @@ class RandomIdentitySampler(Sampler):
                 if len(batch_idxs_dict[pid]) == 0:
                     avai_pids.remove(pid)
 
+        self.length = len(final_idxs)
         return iter(final_idxs)
 
     def __len__(self):
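Note on the sampler hunk: recording `self.length = len(final_idxs)` at the end of `__iter__` makes `__len__` report the exact number of indices yielded for the epoch, rather than whatever `self.length` held beforehand (presumably an estimate computed in `__init__`, which the diff does not show). A minimal sketch of the pattern, with a hypothetical stand-in for the real identity-based selection:

```python
# Hypothetical sketch of the pattern, not the project's actual sampler:
# __iter__ records the exact count of indices it will yield, so __len__
# stays consistent with what the DataLoader actually iterates over.
from torch.utils.data.sampler import Sampler

class ExactLengthSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
        self.length = len(data_source)  # estimate until the first epoch runs

    def __iter__(self):
        final_idxs = list(range(len(self.data_source)))  # stand-in selection
        self.length = len(final_idxs)  # exact count, mirroring the added line
        return iter(final_idxs)

    def __len__(self):
        return self.length
```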
@@ -14,6 +14,8 @@ from ignite.metrics import RunningAverage
 
 from utils.reid_metric import R1_mAP
 
+global ITER
+ITER = 0
 
 def create_supervised_trainer(model, optimizer, loss_fn,
                               device=None):
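Note on this hunk: at module scope the `global ITER` statement is a no-op in Python (a `global` declaration only changes name resolution inside a function body), so the counter is effectively created by `ITER = 0` alone. The `global ITER` declarations inside the `log_training_loss` handlers below are what actually let them rebind the module-level counter.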
@@ -173,13 +175,16 @@ def do_train(
 
     @trainer.on(Events.ITERATION_COMPLETED)
     def log_training_loss(engine):
-        iter = (engine.state.iteration - 1) % len(train_loader) + 1
+        global ITER
+        ITER += 1
 
-        if iter % log_period == 0:
+        if ITER % log_period == 0:
             logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
-                        .format(engine.state.epoch, iter, len(train_loader),
+                        .format(engine.state.epoch, ITER, len(train_loader),
                                 engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
                                 scheduler.get_lr()[0]))
+        if len(train_loader) == ITER:
+            ITER = 0
 
     # adding handlers using `trainer.on` decorator API
     @trainer.on(Events.EPOCH_COMPLETED)
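Note on this hunk: the locally computed `iter` (which also shadowed Python's built-in `iter`) is replaced by the module-level `ITER` counter, incremented once per `ITERATION_COMPLETED` event and reset when it reaches `len(train_loader)`. Assuming the handler fires exactly once per batch, both schemes produce the same 1-based in-epoch index; a small self-contained check of that equivalence:

```python
# Sketch comparing the old and new index computations over two epochs.
# Assumes one handler call per batch, as ITERATION_COMPLETED provides.
ITER = 0

def old_index(global_iteration, loader_len):
    # old: iter = (engine.state.iteration - 1) % len(train_loader) + 1
    return (global_iteration - 1) % loader_len + 1

def new_index(loader_len):
    # new: increment the module-level counter, reset it at epoch end
    global ITER
    ITER += 1
    idx = ITER
    if loader_len == ITER:
        ITER = 0
    return idx

loader_len = 4
for step in range(1, 2 * loader_len + 1):  # two epochs of four batches
    assert old_index(step, loader_len) == new_index(loader_len)
```

One practical difference: the counter lives in module state, so it only stays aligned with the engine's iteration count while the handler runs once per batch and the loader length stays fixed.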
@@ -251,13 +256,16 @@ def do_train_with_center(
 
     @trainer.on(Events.ITERATION_COMPLETED)
     def log_training_loss(engine):
-        iter = (engine.state.iteration - 1) % len(train_loader) + 1
+        global ITER
+        ITER += 1
 
-        if iter % log_period == 0:
+        if ITER % log_period == 0:
             logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
-                        .format(engine.state.epoch, iter, len(train_loader),
+                        .format(engine.state.epoch, ITER, len(train_loader),
                                 engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
                                 scheduler.get_lr()[0]))
+        if len(train_loader) == ITER:
+            ITER = 0
 
     # adding handlers using `trainer.on` decorator API
     @trainer.on(Events.EPOCH_COMPLETED)
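Note: this last hunk applies the identical counter change to `do_train_with_center`, keeping the logging in the two training paths in sync. Since both handlers share the single module-level `ITER`, this presumably assumes only one trainer runs per process.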