add note, and temporarily tensorboard

Xinlei Chen 2021-07-09 15:34:23 -07:00
parent 519aa28fb8
commit 84c7fc0fed
2 changed files with 27 additions and 13 deletions

main_moco.py

@@ -14,9 +14,10 @@ import shutil
 import time
 import warnings
-# ===== SLURM, to delete =====
+# ===== to delete =====
 import signal
-# ===== SLURM, to delete =====
+from tensorboardX import SummaryWriter
+# ===== to delete =====
 
 import torch
 import torch.nn as nn
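A note on the new import: tensorboardX is the standalone TensorBoard writer package, and its SummaryWriter takes the destination directory as a logdir keyword, which is the spelling used later in this commit (PyTorch's built-in torch.utils.tensorboard.SummaryWriter spells the same argument log_dir).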
@@ -38,7 +39,7 @@ import moco.builder
 import moco.loader
 import moco.optimizer
 
-# ===== SLURM, to delete =====
+# ===== to delete =====
 def signalHandler(a, b):
     if a == signal.SIGUSR1:
         logger.info('Got SIGUSR1.')
@@ -47,7 +48,7 @@ def signalHandler(a, b):
 signal.signal(signal.SIGUSR1, signalHandler)
 signal.signal(signal.SIGTERM, signalHandler)
-# ===== SLURM, to delete =====
+# ===== to delete =====
 
 torchvision_model_names = sorted(name for name in torchvision_models.__dict__
     if name.islower() and not name.startswith("__")
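Note on the SLURM blocks above: SLURM is typically configured to send SIGUSR1 shortly before preempting a job (for example with sbatch --signal=USR1@60), and the body of signalHandler is truncated by this hunk's context window. A minimal sketch of the usual requeue-on-preemption pattern, assuming a SLURM environment; this is illustrative, not the exact body of this commit's handler:

import os
import signal
import subprocess

def signal_handler(signum, frame):
    if signum == signal.SIGUSR1:
        # SLURM warns of preemption with SIGUSR1; requeue the job so it
        # restarts from the most recent checkpoint after being killed.
        job_id = os.environ.get('SLURM_JOB_ID')
        if job_id is not None:
            subprocess.call(['scontrol', 'requeue', job_id])

signal.signal(signal.SIGUSR1, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)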
@@ -126,10 +127,10 @@ parser.add_argument('--optimizer', default='lars', type=str,
 parser.add_argument('--warmup-epochs', default=10, type=int, metavar='N',
                     help='number of warmup epochs')
-# ===== OTHERS, to delete =====
+# ===== to delete =====
 parser.add_argument('--checkpoint-folder', default='.', type=str, metavar='PATH',
                     help='path to save the checkpoints (default: .)')
-# ===== OTHERS, to delete =====
+# ===== to delete =====
 
 def main():
     args = parser.parse_args()
@@ -153,10 +154,10 @@ def main():
     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
 
-    # ===== PATH, to delete =====
+    # ===== to delete =====
     if not os.path.exists(args.checkpoint_folder):
        os.makedirs(args.checkpoint_folder)
-    # ===== PATH, to delete =====
+    # ===== to delete =====
 
     ngpus_per_node = torch.cuda.device_count()
     if args.multiprocessing_distributed:
@@ -250,6 +251,10 @@ def main_worker(gpu, ngpus_per_node, args):
                                 weight_decay=args.weight_decay)
     scaler = torch.cuda.amp.GradScaler()
 
+    # ===== to delete =====
+    if args.rank == 0:
+        summary_writer = SummaryWriter(logdir=args.checkpoint_folder)
+    # ===== to delete =====
     # optionally resume from a checkpoint
     if args.resume:
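The writer is created under args.rank == 0 so that exactly one process in the distributed job owns the TensorBoard event file; letting every rank construct a SummaryWriter would write duplicate, interleaved events. A minimal sketch of the same guard in isolation, with rank and checkpoint_folder standing in for args.rank and args.checkpoint_folder:

from tensorboardX import SummaryWriter

def make_writer(rank, checkpoint_folder):
    # Only rank 0 logs; all other ranks get None and must skip logging calls.
    return SummaryWriter(logdir=checkpoint_folder) if rank == 0 else None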
@@ -324,7 +329,7 @@ def main_worker(gpu, ngpus_per_node, args):
         adjust_learning_rate(optimizer, init_lr, epoch, args)
 
         # train for one epoch
-        train(train_loader, model, optimizer, scaler, epoch, args)
+        train(train_loader, model, optimizer, scaler, summary_writer, epoch, args)
 
         if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                 and args.rank == 0): # only the first GPU saves checkpoint
@@ -336,14 +341,18 @@ def main_worker(gpu, ngpus_per_node, args):
                 'scaler': scaler.state_dict(),
             }, is_best=False, filename='%s/checkpoint_%04d.pth.tar' % (args.checkpoint_folder, epoch))
 
+    # ===== to delete =====
+    if args.rank == 0:
+        summary_writer.close()
+    # ===== to delete =====
 
-def train(train_loader, model, optimizer, scaler, epoch, args):
+def train(train_loader, model, optimizer, scaler, summary_writer, epoch, args):
     batch_time = AverageMeter('Time', ':6.3f')
     data_time = AverageMeter('Data', ':6.3f')
     losses = AverageMeter('Loss', ':.4e')
     progress = ProgressMeter(
         len(train_loader),
-        [batch_time, data_time, losses, top1, top5],
+        [batch_time, data_time, losses],
         prefix="Epoch: [{}]".format(epoch))
 
     # switch to train mode
@@ -364,6 +373,10 @@ def train(train_loader, model, optimizer, scaler, epoch, args):
             loss = model(images[0], images[1], moco_m)
 
         losses.update(loss.item(), images[0].size(0))
+        # ===== to delete =====
+        if args.rank == 0:
+            summary_writer.add_scalar("loss", loss.item(), epoch * len(train_loader) + i)
+        # ===== to delete =====
 
         # compute gradient and do SGD step
         optimizer.zero_grad()
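The global step passed to add_scalar is epoch * len(train_loader) + i, i.e. the cumulative iteration count, so the loss curve in TensorBoard advances continuously across epoch boundaries instead of restarting at step 0 each epoch. The call is again guarded by args.rank == 0 to match the rank-0-only writer created in main_worker.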

moco/builder.py

@@ -23,7 +23,6 @@ class MoCo(nn.Module):
         super(MoCo, self).__init__()
 
         self.T = T
-        self.criterion =
 
         if with_vit:
             self._init_encoders_with_vit(base_encoder, dim, mlp_dim)
@@ -46,6 +45,8 @@ class MoCo(nn.Module):
                 mlp.append(nn.BatchNorm1d(dim2))
                 mlp.append(nn.ReLU(inplace=True))
             else:
+                # similar to SimCLR: https://github.com/google-research/simclr/blob/master/model_util.py#L157
+                # remove this last BN also works
                 mlp.append(nn.BatchNorm1d(dim2, affine=False))
 
         return nn.Sequential(*mlp)
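The hunk above touches a projection-MLP builder of which only fragments are visible. A sketch reconstructed from those fragments, with the function name, signature, and loop structure inferred rather than confirmed by this diff; the point is the final BatchNorm1d(dim2, affine=False), a SimCLR-style head whose last normalization has no learnable scale or shift:

import torch.nn as nn

def build_mlp(num_layers, input_dim, mlp_dim, output_dim):
    mlp = []
    for layer in range(num_layers):
        dim1 = input_dim if layer == 0 else mlp_dim
        dim2 = output_dim if layer == num_layers - 1 else mlp_dim
        mlp.append(nn.Linear(dim1, dim2, bias=False))
        if layer < num_layers - 1:
            # hidden layers: Linear -> BN -> ReLU
            mlp.append(nn.BatchNorm1d(dim2))
            mlp.append(nn.ReLU(inplace=True))
        else:
            # last layer: BN without affine parameters, as in SimCLR;
            # per the note added above, removing this last BN also works
            mlp.append(nn.BatchNorm1d(dim2, affine=False))
    return nn.Sequential(*mlp)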