commit d2f3156bb3 (parent 9bebc3fe8e): linear weights
README.md (27 changed lines)
@@ -287,6 +287,12 @@ We release the logs and weights from evaluating the different models:
      <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth">linear weights</a></td>
      <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_eval_linear_log.txt">logs</a></td>
    </tr>
    <tr>
      <td>ViT-B/8</td>
      <td>80.1%</td>
      <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth">linear weights</a></td>
      <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain_eval_linear_log.txt">logs</a></td>
    </tr>
    <tr>
      <td>xcit_small_12_p16</td>
      <td>77.8%</td>
@@ -319,6 +325,27 @@ We release the logs and weights from evaluating the different models:
    </tr>
</table>

You can check the performance of the pretrained weights on the ImageNet validation set by running the following commands:
```
python eval_linear.py --evaluate --arch vit_small --patch_size 16 --data_path /path/to/imagenet
```

```
python eval_linear.py --evaluate --arch vit_small --patch_size 8 --data_path /path/to/imagenet
```

```
python eval_linear.py --evaluate --arch vit_base --patch_size 16 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet
```

```
python eval_linear.py --evaluate --arch vit_base --patch_size 8 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet
```

```
python eval_linear.py --evaluate --arch resnet50 --data_path /path/to/imagenet
```
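Note that eval_linear.py joins the `train`/`val` split folders onto `--data_path` itself (see the data-loading code below), so the flag should point at the dataset root. A sketch of the torchvision `ImageFolder` layout this assumes (class folder names are illustrative):
```
/path/to/imagenet/
    train/
        n01440764/
            xxx.JPEG
    val/
        n01440764/
            yyy.JPEG
```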
## Evaluation: DAVIS 2017 Video object segmentation
Please verify that you're using PyTorch version 1.7.1, since we are currently unable to reproduce the results with the more recent PyTorch 1.8.1.
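If you need to pin the environment for this step, torchvision 0.8.2 is the release paired with PyTorch 1.7.1; a minimal install command (your CUDA setup may require a different wheel index):
```
pip install torch==1.7.1 torchvision==0.8.2
```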
eval_linear.py
@@ -34,37 +34,6 @@ def eval_linear(args):
     print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
     cudnn.benchmark = True
-
-    # ============ preparing data ... ============
-    train_transform = pth_transforms.Compose([
-        pth_transforms.RandomResizedCrop(224),
-        pth_transforms.RandomHorizontalFlip(),
-        pth_transforms.ToTensor(),
-        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
-    ])
-    val_transform = pth_transforms.Compose([
-        pth_transforms.Resize(256, interpolation=3),
-        pth_transforms.CenterCrop(224),
-        pth_transforms.ToTensor(),
-        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
-    ])
-    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
-    dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
-    sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
-    train_loader = torch.utils.data.DataLoader(
-        dataset_train,
-        sampler=sampler,
-        batch_size=args.batch_size_per_gpu,
-        num_workers=args.num_workers,
-        pin_memory=True,
-    )
-    val_loader = torch.utils.data.DataLoader(
-        dataset_val,
-        batch_size=args.batch_size_per_gpu,
-        num_workers=args.num_workers,
-        pin_memory=True,
-    )
-    print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
 
     # ============ building network ... ============
     # if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
     if args.arch in vits.__dict__.keys():
@@ -92,6 +61,44 @@ def eval_linear(args):
     linear_classifier = linear_classifier.cuda()
     linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])
 
+    # ============ preparing data ... ============
+    val_transform = pth_transforms.Compose([
+        pth_transforms.Resize(256, interpolation=3),
+        pth_transforms.CenterCrop(224),
+        pth_transforms.ToTensor(),
+        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+    ])
+    dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
+    val_loader = torch.utils.data.DataLoader(
+        dataset_val,
+        batch_size=args.batch_size_per_gpu,
+        num_workers=args.num_workers,
+        pin_memory=True,
+    )
+
+    if args.evaluate:
+        utils.load_pretrained_linear_weights(linear_classifier, args.arch, args.patch_size)
+        test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
+        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
+        return
+
+    train_transform = pth_transforms.Compose([
+        pth_transforms.RandomResizedCrop(224),
+        pth_transforms.RandomHorizontalFlip(),
+        pth_transforms.ToTensor(),
+        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+    ])
+    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
+    sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
+    train_loader = torch.utils.data.DataLoader(
+        dataset_train,
+        sampler=sampler,
+        batch_size=args.batch_size_per_gpu,
+        num_workers=args.num_workers,
+        pin_memory=True,
+    )
+    print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
+
     # set optimizer
     optimizer = torch.optim.SGD(
         linear_classifier.parameters(),
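A note on ordering in the hunk above: utils.load_pretrained_linear_weights is called only after the classifier has been wrapped in DistributedDataParallel, and the released state dicts load with strict=True at that point, which suggests their keys carry the DDP "module." prefix. A hedged sketch of the adjustment that would be needed to load them into an unwrapped classifier (assuming the prefix is indeed "module."):
```
# Hypothetical: strip the DDP prefix before loading into a bare nn.Module.
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
linear_classifier.load_state_dict(state_dict, strict=True)
```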
@@ -157,10 +164,10 @@ def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):
         with torch.no_grad():
             if "vit" in args.arch:
                 intermediate_output = model.get_intermediate_layers(inp, n)
-                output = [x[:, 0] for x in intermediate_output]
+                output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
                 if avgpool:
-                    output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
-                output = torch.cat(output, dim=-1)
+                    output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
+                    output = output.reshape(output.shape[0], -1)
             else:
                 output = model(inp)
         output = linear_classifier(output)
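The rewrite above is not just cosmetic. With avgpool enabled, the old code appended the patch-token mean as a separate block ([cls | mean]), while the new code interleaves the two vectors channel by channel; the torch.cat along the last dim also only type-checks when the [CLS] features and the patch-token mean have the same width, i.e. with --n_last_blocks 1, which is how the ViT-B commands in the README are invoked. A minimal shape check with dummy tensors (dimensions assumed for illustration):
```
import torch

B, D = 4, 768                    # batch size, ViT-B embedding dim (assumed)
cls_token = torch.randn(B, D)    # [CLS] token of the last block
patch_mean = torch.randn(B, D)   # mean over the patch tokens

# Old layout: blockwise concatenation -> [cls_0..cls_D, mean_0..mean_D]
old = torch.cat([cls_token, patch_mean], dim=-1)                              # (B, 2*D)

# New layout: channelwise interleave -> [cls_0, mean_0, cls_1, mean_1, ...]
new = torch.cat((cls_token.unsqueeze(-1), patch_mean.unsqueeze(-1)), dim=-1)  # (B, D, 2)
new = new.reshape(new.shape[0], -1)                                           # (B, 2*D)

print(old.shape, new.shape)  # torch.Size([4, 1536]) torch.Size([4, 1536])
```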
@@ -199,10 +206,10 @@ def validate_network(val_loader, model, linear_classifier, n, avgpool):
         with torch.no_grad():
             if "vit" in args.arch:
                 intermediate_output = model.get_intermediate_layers(inp, n)
-                output = [x[:, 0] for x in intermediate_output]
+                output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
                 if avgpool:
-                    output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
-                output = torch.cat(output, dim=-1)
+                    output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
+                    output = output.reshape(output.shape[0], -1)
             else:
                 output = model(inp)
         output = linear_classifier(output)
@@ -269,5 +276,6 @@ if __name__ == '__main__':
     parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
     parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints')
     parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
+    parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
     args = parser.parse_args()
     eval_linear(args)
utils.py (20 changed lines)
@@ -109,6 +109,26 @@ def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
         print("There is no reference weights available for this model => We use random weights.")
 
 
+def load_pretrained_linear_weights(linear_classifier, model_name, patch_size):
+    url = None
+    if model_name == "vit_small" and patch_size == 16:
+        url = "dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth"
+    elif model_name == "vit_small" and patch_size == 8:
+        url = "dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth"
+    elif model_name == "vit_base" and patch_size == 16:
+        url = "dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth"
+    elif model_name == "vit_base" and patch_size == 8:
+        url = "dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth"
+    elif model_name == "resnet50":
+        url = "dino_resnet50_pretrain/dino_resnet50_linearweights.pth"
+    if url is not None:
+        print("We load the reference pretrained linear weights.")
+        state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)["state_dict"]
+        linear_classifier.load_state_dict(state_dict, strict=True)
+    else:
+        print("We use random linear weights.")
+
+
 def clip_gradients(model, clip):
     norms = []
     for name, p in model.named_parameters():
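For reference, a hypothetical usage mirroring eval_linear.py (assuming the LinearClassifier module defined there and an already-initialized distributed context): the classifier's input width must match the released head; for vit_small with patch size 16, the head was trained on the [CLS] tokens of the default last 4 blocks, i.e. 4 * 384 = 1536 features and 1000 classes.
```
linear_classifier = LinearClassifier(1536, num_labels=1000).cuda()
linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])
load_pretrained_linear_weights(linear_classifier, "vit_small", 16)
```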