mirror of
https://github.com/facebookresearch/moco-v3.git
synced 2025-06-03 14:59:22 +08:00
add note, and temporarily tensorboard
This commit is contained in:
parent
519aa28fb8
commit
84c7fc0fed
35
main_moco.py
35
main_moco.py
@ -14,9 +14,10 @@ import shutil
|
||||
import time
|
||||
import warnings
|
||||
|
||||
# ===== SLURM, to delete =====
|
||||
# ===== to delete =====
|
||||
import signal
|
||||
# ===== SLURM, to delete =====
|
||||
from tensorboardX import SummaryWriter
|
||||
# ===== to delete =====
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -38,7 +39,7 @@ import moco.builder
|
||||
import moco.loader
|
||||
import moco.optimizer
|
||||
|
||||
# ===== SLURM, to delete =====
|
||||
# ===== to delete =====
|
||||
def signalHandler(a, b):
|
||||
if a == signal.SIGUSR1:
|
||||
logger.info('Got SIGUSR1.')
|
||||
@ -47,7 +48,7 @@ def signalHandler(a, b):
|
||||
|
||||
signal.signal(signal.SIGUSR1, signalHandler)
|
||||
signal.signal(signal.SIGTERM, signalHandler)
|
||||
# ===== SLURM, to delete =====
|
||||
# ===== to delete =====
|
||||
|
||||
torchvision_model_names = sorted(name for name in torchvision_models.__dict__
|
||||
if name.islower() and not name.startswith("__")
|
||||
@ -126,10 +127,10 @@ parser.add_argument('--optimizer', default='lars', type=str,
|
||||
parser.add_argument('--warmup-epochs', default=10, type=int, metavar='N',
|
||||
help='number of warmup epochs')
|
||||
|
||||
# ===== OTHERS, to delete =====
|
||||
# ===== to delete =====
|
||||
parser.add_argument('--checkpoint-folder', default='.', type=str, metavar='PATH',
|
||||
help='path to save the checkpoints (default: .)')
|
||||
# ===== OTHERS, to delete =====
|
||||
# ===== to delete =====
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
@ -153,10 +154,10 @@ def main():
|
||||
|
||||
args.distributed = args.world_size > 1 or args.multiprocessing_distributed
|
||||
|
||||
# ===== PATH, to delete =====
|
||||
# ===== to delete =====
|
||||
if not os.path.exists(args.checkpoint_folder):
|
||||
os.makedirs(args.checkpoint_folder)
|
||||
# ===== PATH, to delete =====
|
||||
# ===== to delete =====
|
||||
|
||||
ngpus_per_node = torch.cuda.device_count()
|
||||
if args.multiprocessing_distributed:
|
||||
@ -250,6 +251,10 @@ def main_worker(gpu, ngpus_per_node, args):
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
# ===== to delete =====
|
||||
if args.rank == 0:
|
||||
summary_writer = SummaryWriter(logdir=args.checkpoint_folder)
|
||||
# ===== to delete =====
|
||||
|
||||
# optionally resume from a checkpoint
|
||||
if args.resume:
|
||||
@ -324,7 +329,7 @@ def main_worker(gpu, ngpus_per_node, args):
|
||||
adjust_learning_rate(optimizer, init_lr, epoch, args)
|
||||
|
||||
# train for one epoch
|
||||
train(train_loader, model, optimizer, scaler, epoch, args)
|
||||
train(train_loader, model, optimizer, scaler, summary_writer, epoch, args)
|
||||
|
||||
if not args.multiprocessing_distributed or (args.multiprocessing_distributed
|
||||
and args.rank == 0): # only the first GPU saves checkpoint
|
||||
@ -336,14 +341,18 @@ def main_worker(gpu, ngpus_per_node, args):
|
||||
'scaler': scaler.state_dict(),
|
||||
}, is_best=False, filename='%s/checkpoint_%04d.pth.tar' % (args.checkpoint_folder, epoch))
|
||||
|
||||
# ===== to delete =====
|
||||
if args.rank == 0:
|
||||
summary_writer.close()
|
||||
# ===== to delete =====
|
||||
|
||||
def train(train_loader, model, optimizer, scaler, epoch, args):
|
||||
def train(train_loader, model, optimizer, scaler, summary_writer, epoch, args):
|
||||
batch_time = AverageMeter('Time', ':6.3f')
|
||||
data_time = AverageMeter('Data', ':6.3f')
|
||||
losses = AverageMeter('Loss', ':.4e')
|
||||
progress = ProgressMeter(
|
||||
len(train_loader),
|
||||
[batch_time, data_time, losses, top1, top5],
|
||||
[batch_time, data_time, losses],
|
||||
prefix="Epoch: [{}]".format(epoch))
|
||||
|
||||
# switch to train mode
|
||||
@ -364,6 +373,10 @@ def train(train_loader, model, optimizer, scaler, epoch, args):
|
||||
loss = model(images[0], images[1], moco_m)
|
||||
|
||||
losses.update(loss.item(), images[0].size(0))
|
||||
# ===== to delete =====
|
||||
if args.rank == 0:
|
||||
summary_writer.add_scalar("loss", loss.item(), epoch * len(train_loader) + i)
|
||||
# ===== to delete =====
|
||||
|
||||
# compute gradient and do SGD step
|
||||
optimizer.zero_grad()
|
||||
|
@ -23,7 +23,6 @@ class MoCo(nn.Module):
|
||||
super(MoCo, self).__init__()
|
||||
|
||||
self.T = T
|
||||
self.criterion =
|
||||
|
||||
if with_vit:
|
||||
self._init_encoders_with_vit(base_encoder, dim, mlp_dim)
|
||||
@ -46,6 +45,8 @@ class MoCo(nn.Module):
|
||||
mlp.append(nn.BatchNorm1d(dim2))
|
||||
mlp.append(nn.ReLU(inplace=True))
|
||||
else:
|
||||
# similar to SimCLR: https://github.com/google-research/simclr/blob/master/model_util.py#L157
|
||||
# remove this last BN also works
|
||||
mlp.append(nn.BatchNorm1d(dim2, affine=False))
|
||||
|
||||
return nn.Sequential(*mlp)
|
||||
|
Loading…
x
Reference in New Issue
Block a user