mirror of https://github.com/JDAI-CV/fast-reid.git
style: fix some typro
parent
2ac55a7601
commit
e990cf3e34
|
@ -292,8 +292,9 @@ class DefaultTrainer(SimpleTrainer):
|
|||
cfg.TEST.PRECISE_BN.NUM_ITER,
|
||||
))
|
||||
|
||||
if cfg.MODEL.OPEN_LAYERS[0] != '' and cfg.SOLVER.FREEZE_ITERS > 0:
|
||||
logger.info(f"Freeze backbone training for {cfg.SOLVER.FREEZE_ITERS:d} iters")
|
||||
if cfg.MODEL.OPEN_LAYERS != [''] and cfg.SOLVER.FREEZE_ITERS > 0:
|
||||
open_layers = ",".join(cfg.MODEL.OPEN_LAYERS)
|
||||
logger.info(f'Open "{open_layers}" training for {cfg.SOLVER.FREEZE_ITERS:d} iters')
|
||||
ret.append(hooks.FreezeLayer(
|
||||
self.model,
|
||||
cfg.MODEL.OPEN_LAYERS,
|
||||
|
|
|
@ -441,18 +441,18 @@ class FreezeLayer(HookBase):
|
|||
self.param_grad = param_grad
|
||||
|
||||
def before_step(self):
|
||||
# freeze specific layers
|
||||
# Freeze specific layers
|
||||
if self.trainer.iter < self.freeze_iters:
|
||||
self.freeze_specific_layer()
|
||||
|
||||
# recover original layers status
|
||||
# Recover original layers status
|
||||
elif self.trainer.iter == self.freeze_iters:
|
||||
self.open_all_layer()
|
||||
|
||||
def freeze_specific_layer(self):
|
||||
for layer in self.open_layer_names:
|
||||
if not hasattr(self.model, layer):
|
||||
self._logger.info('"{}" is not an attribute of the model, will skip this layer'.format(layer))
|
||||
self._logger.info(f'"{layer}" is not an attribute of the model, will skip this layer')
|
||||
|
||||
for name, module in self.model.named_children():
|
||||
if name in self.open_layer_names:
|
||||
|
|
|
@ -5,8 +5,9 @@
|
|||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pickle
|
||||
import random
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
@ -14,7 +15,6 @@ from scipy.stats import norm
|
|||
from sklearn import metrics
|
||||
|
||||
from .file_io import PathManager
|
||||
import random
|
||||
|
||||
|
||||
class Visualizer:
|
||||
|
@ -236,6 +236,7 @@ class Visualizer:
|
|||
def load_roc_info(path):
|
||||
with open(path, 'rb') as handle: res = pickle.load(handle)
|
||||
return res
|
||||
|
||||
# def plot_camera_dist(self):
|
||||
# same_cam, diff_cam = [], []
|
||||
# for i, q in enumerate(self.q_pids):
|
||||
|
|
Loading…
Reference in New Issue