diff --git a/README.md b/README.md
index 94529d0..72c2809 100644
--- a/README.md
+++ b/README.md
@@ -287,6 +287,12 @@ We release the logs and weights from evaluating the different models:
linear weights |
logs |
+
+ ViT-B/8 |
+ 80.1% |
+ linear weights |
+ logs |
+
xcit_small_12_p16 |
77.8% |
@@ -319,6 +325,27 @@ We release the logs and weights from evaluating the different models:
+You can check the performance of the pretrained weights on the ImageNet validation set by running the following commands:
+```
+python eval_linear.py --evaluate --arch vit_small --patch_size 16 --data_path /path/to/imagenet
+```
+
+```
+python eval_linear.py --evaluate --arch vit_small --patch_size 8 --data_path /path/to/imagenet
+```
+
+```
+python eval_linear.py --evaluate --arch vit_base --patch_size 16 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet
+```
+
+```
+python eval_linear.py --evaluate --arch vit_base --patch_size 8 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet
+```
+
+```
+python eval_linear.py --evaluate --arch resnet50 --data_path /path/to/imagenet
+```
+
## Evaluation: DAVIS 2017 Video object segmentation
Please verify that you're using pytorch version 1.7.1 since we are not able to reproduce the results with most recent pytorch 1.8.1 at the moment.
diff --git a/eval_linear.py b/eval_linear.py
index e95315b..81eb94f 100644
--- a/eval_linear.py
+++ b/eval_linear.py
@@ -34,37 +34,6 @@ def eval_linear(args):
print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
cudnn.benchmark = True
- # ============ preparing data ... ============
- train_transform = pth_transforms.Compose([
- pth_transforms.RandomResizedCrop(224),
- pth_transforms.RandomHorizontalFlip(),
- pth_transforms.ToTensor(),
- pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
- ])
- val_transform = pth_transforms.Compose([
- pth_transforms.Resize(256, interpolation=3),
- pth_transforms.CenterCrop(224),
- pth_transforms.ToTensor(),
- pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
- ])
- dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
- dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
- sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
- train_loader = torch.utils.data.DataLoader(
- dataset_train,
- sampler=sampler,
- batch_size=args.batch_size_per_gpu,
- num_workers=args.num_workers,
- pin_memory=True,
- )
- val_loader = torch.utils.data.DataLoader(
- dataset_val,
- batch_size=args.batch_size_per_gpu,
- num_workers=args.num_workers,
- pin_memory=True,
- )
- print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
-
# ============ building network ... ============
# if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
if args.arch in vits.__dict__.keys():
@@ -92,6 +61,44 @@ def eval_linear(args):
linear_classifier = linear_classifier.cuda()
linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])
+ # ============ preparing data ... ============
+ val_transform = pth_transforms.Compose([
+ pth_transforms.Resize(256, interpolation=3),
+ pth_transforms.CenterCrop(224),
+ pth_transforms.ToTensor(),
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ ])
+ dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
+ val_loader = torch.utils.data.DataLoader(
+ dataset_val,
+ batch_size=args.batch_size_per_gpu,
+ num_workers=args.num_workers,
+ pin_memory=True,
+ )
+
+ if args.evaluate:
+ utils.load_pretrained_linear_weights(linear_classifier, args.arch, args.patch_size)
+ test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
+ print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
+ return
+
+ train_transform = pth_transforms.Compose([
+ pth_transforms.RandomResizedCrop(224),
+ pth_transforms.RandomHorizontalFlip(),
+ pth_transforms.ToTensor(),
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ ])
+ dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
+ train_loader = torch.utils.data.DataLoader(
+ dataset_train,
+ sampler=sampler,
+ batch_size=args.batch_size_per_gpu,
+ num_workers=args.num_workers,
+ pin_memory=True,
+ )
+ print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
+
# set optimizer
optimizer = torch.optim.SGD(
linear_classifier.parameters(),
@@ -157,10 +164,10 @@ def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):
with torch.no_grad():
if "vit" in args.arch:
intermediate_output = model.get_intermediate_layers(inp, n)
- output = [x[:, 0] for x in intermediate_output]
+ output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
if avgpool:
- output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
- output = torch.cat(output, dim=-1)
+ output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
+ output = output.reshape(output.shape[0], -1)
else:
output = model(inp)
output = linear_classifier(output)
@@ -199,10 +206,10 @@ def validate_network(val_loader, model, linear_classifier, n, avgpool):
with torch.no_grad():
if "vit" in args.arch:
intermediate_output = model.get_intermediate_layers(inp, n)
- output = [x[:, 0] for x in intermediate_output]
+ output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
if avgpool:
- output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
- output = torch.cat(output, dim=-1)
+ output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
+ output = output.reshape(output.shape[0], -1)
else:
output = model(inp)
output = linear_classifier(output)
@@ -269,5 +276,6 @@ if __name__ == '__main__':
parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints')
parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
+ parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
args = parser.parse_args()
eval_linear(args)
diff --git a/utils.py b/utils.py
index 978d79d..9586250 100644
--- a/utils.py
+++ b/utils.py
@@ -109,6 +109,26 @@ def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_nam
print("There is no reference weights available for this model => We use random weights.")
+def load_pretrained_linear_weights(linear_classifier, model_name, patch_size):
+ url = None
+ if model_name == "vit_small" and patch_size == 16:
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth"
+ elif model_name == "vit_small" and patch_size == 8:
+ url = "dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth"
+ elif model_name == "vit_base" and patch_size == 16:
+ url = "dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth"
+ elif model_name == "vit_base" and patch_size == 8:
+ url = "dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth"
+ elif model_name == "resnet50":
+ url = "dino_resnet50_pretrain/dino_resnet50_linearweights.pth"
+ if url is not None:
+ print("We load the reference pretrained linear weights.")
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)["state_dict"]
+ linear_classifier.load_state_dict(state_dict, strict=True)
+ else:
+ print("We use random linear weights.")
+
+
def clip_gradients(model, clip):
norms = []
for name, p in model.named_parameters():