From 6a2ba402291005b003a70e99f7c87d1a2c376b0d Mon Sep 17 00:00:00 2001
From: Kaiming He
Date: Tue, 11 Jan 2022 12:31:55 -0800
Subject: [PATCH] add linear probing

---
 CONTRIBUTING.md      |   2 +-
 FINETUNE.md          |  46 ++++++-
 PRETRAIN.md          |   2 +-
 main_linprobe.py     | 316 +++++++++++++++++++++++++++++++++++++++++++
 models_mae.py        |   2 +-
 submitit_linprobe.py | 131 ++++++++++++++++++
 util/crop.py         |  42 ++++++
 util/lars.py         |  47 +++++++
 8 files changed, 584 insertions(+), 4 deletions(-)
 create mode 100644 main_linprobe.py
 create mode 100644 submitit_linprobe.py
 create mode 100644 util/crop.py
 create mode 100644 util/lars.py

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 273db47..40e570a 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -5,7 +5,7 @@ possible.
 
 ## Pull Requests
 We actively welcome your pull requests.
-1. Fork the repo and create your branch from `master`.
+1. Fork the repo and create your branch from `main`.
 2. If you've added code that should be tested, add tests.
 3. If you've changed APIs, update the documentation.
 4. Ensure the test suite passes.
diff --git a/FINETUNE.md b/FINETUNE.md
index b289410..e396538 100644
--- a/FINETUNE.md
+++ b/FINETUNE.md
@@ -124,10 +124,54 @@ OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 main_fin
 ```
 - Here the effective batch size is 32 (`batch_size` per gpu) * 4 (`accum_iter`) * 8 (gpus) = 1024. `--accum_iter 4` simulates 4 nodes.
 
-### Notes
+#### Notes
 
 - The [pre-trained models we provide](https://github.com/fairinternal/mae/#pre-trained-checkpoints) are trained with *normalized* pixels `--norm_pix_loss` (1600 epochs, Table 3 in paper). The fine-tuning hyper-parameters are slightly different from the default baseline using *unnormalized* pixels.
 
 - The original MAE implementation was in TensorFlow+TPU with no explicit mixed precision. This re-implementation is in PyTorch+GPU with automatic mixed precision (`torch.cuda.amp`). We have observed different numerical behavior between the two platforms. In this repo, we use `--global_pool` for fine-tuning; using `--cls_token` performs similarly, but there is a chance of producing NaN when fine-tuning ViT-Huge in GPUs. We did not observe this issue in TPUs. Turning off amp could solve this issue, but is slower.
 
 - Here we use RandErase following DeiT: `--reprob 0.25`. Its effect is smaller than random variance.
+
+### Linear Probing
+
+Run the following on 4 nodes with 8 GPUs each:
+```
+python submitit_linprobe.py \
+    --job_dir ${JOB_DIR} \
+    --nodes 4 \
+    --batch_size 512 \
+    --model vit_base_patch16 --cls_token \
+    --finetune ${PRETRAIN_CHKPT} \
+    --epochs 90 \
+    --blr 0.1 \
+    --weight_decay 0.0 \
+    --dist_eval --data_path ${IMAGENET_DIR}
+```
+- Here the effective batch size is 512 (`batch_size` per gpu) * 4 (`nodes`) * 8 (gpus per node) = 16384.
+- `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256. A worked example follows the results table below.
+- Training time is ~2h20m for 90 epochs on 32 V100 GPUs.
+- To run single-node training, follow the instructions in the fine-tuning section above.
+
+To train ViT-Large or ViT-Huge, set `--model vit_large_patch16` or `--model vit_huge_patch14`. Training for 50 epochs (`--epochs 50`) is sufficient.
+
+This PT/GPU code produces *better* results for ViT-L/H (see the table below). This is likely caused by the system difference between TF and PT.
+
+|                    | ViT-Base | ViT-Large | ViT-Huge |
+|:-------------------|:--------:|:---------:|:--------:|
+| paper (TF/TPU)     | 68.0     | 75.8      | 76.6     |
+| this repo (PT/GPU) | 67.8     | 76.0      | 77.2     |
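+
+As a quick sanity check of the linear scaling rule above, here is an illustrative snippet with the numbers from the linear-probing command (`accum_iter` is the script's default of 1):
+
+```
+# linear scaling rule: lr = blr * effective_batch_size / 256
+batch_size_per_gpu = 512
+nodes, gpus_per_node, accum_iter = 4, 8, 1
+blr = 0.1
+
+eff_batch_size = batch_size_per_gpu * accum_iter * nodes * gpus_per_node
+lr = blr * eff_batch_size / 256
+print(eff_batch_size, lr)  # 16384 6.4
+```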
diff --git a/PRETRAIN.md b/PRETRAIN.md
index bbdecb9..174d5b1 100644
--- a/PRETRAIN.md
+++ b/PRETRAIN.md
@@ -20,4 +20,4 @@ python submitit_pretrain.py \
 - The exact same hyper-parameters and configs (initialization, augmentation, etc.) are used as our TF/TPU implementation. In our sanity checks, this PT/GPU re-implementation can reproduce the TF/TPU results within reasonable random variation. We get 85.5% [fine-tuning](FINETUNE.md) accuracy by pre-training ViT-Large for 800 epochs (85.4% in paper Table 1d with TF/TPU).
 - Training time is ~42h in 64 V100 GPUs (800 epochs).
 
-To train ViT-Base or ViT-Huge, set `--model vit_base_patch16` or `--model vit_huge_patch14`.
+To train ViT-Base or ViT-Huge, set `--model mae_vit_base_patch16` or `--model mae_vit_huge_patch14`.
diff --git a/main_linprobe.py b/main_linprobe.py
new file mode 100644
index 0000000..2d3f241
--- /dev/null
+++ b/main_linprobe.py
@@ -0,0 +1,316 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+
+import argparse
+import datetime
+import json
+import numpy as np
+import os
+import time
+from pathlib import Path
+
+import torch
+import torch.backends.cudnn as cudnn
+from torch.utils.tensorboard import SummaryWriter
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+
+import timm
+
+assert timm.__version__ == "0.3.2"  # version check
+from timm.models.layers import trunc_normal_
+
+import util.misc as misc
+from util.pos_embed import interpolate_pos_embed
+from util.misc import NativeScalerWithGradNormCount as NativeScaler
+from util.lars import LARS
+from util.crop import RandomResizedCrop
+
+import models_vit
+
+from engine_finetune import train_one_epoch, evaluate
+
+
+def get_args_parser():
+    parser = argparse.ArgumentParser('MAE linear probing for image classification', add_help=False)
+    parser.add_argument('--batch_size', default=512, type=int,
+                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
+    parser.add_argument('--epochs', default=90, type=int)
+    parser.add_argument('--accum_iter', default=1, type=int,
+                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
+
+    # Model parameters
+    parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
+                        help='Name of model to train')
+
+    # Optimizer parameters
+    parser.add_argument('--weight_decay', type=float, default=0,
+                        help='weight decay (default: 0 for linear probe following MoCo v1)')
+
+    parser.add_argument('--lr', type=float, default=None, metavar='LR',
+                        help='learning rate (absolute lr)')
+    parser.add_argument('--blr', type=float, default=0.1, metavar='LR',
+                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
+
+    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
+                        help='lower lr bound for cyclic schedulers that hit 0')
+
+    parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
+                        help='epochs to warmup LR')
+
+    # * Finetuning params
+    parser.add_argument('--finetune', default='',
+                        help='finetune from checkpoint')
+    parser.add_argument('--global_pool', action='store_true')
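+    # NOTE: --cls_token below stores False into the same 'global_pool'
+    # destination, so the two flags are opposite settings of one boolean;
+    # the linear-probing recipe in FINETUNE.md passes --cls_token.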
+    parser.set_defaults(global_pool=False)
+    parser.add_argument('--cls_token', action='store_false', dest='global_pool',
+                        help='Use class token instead of global pool for classification')
+
+    # Dataset parameters
+    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
+                        help='dataset path')
+    parser.add_argument('--nb_classes', default=1000, type=int,
+                        help='number of classification types')
+
+    parser.add_argument('--output_dir', default='./output_dir',
+                        help='path where to save, empty for no saving')
+    parser.add_argument('--log_dir', default='./output_dir',
+                        help='path where to write tensorboard logs')
+    parser.add_argument('--device', default='cuda',
+                        help='device to use for training / testing')
+    parser.add_argument('--seed', default=0, type=int)
+    parser.add_argument('--resume', default='',
+                        help='resume from checkpoint')
+
+    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+                        help='start epoch')
+    parser.add_argument('--eval', action='store_true',
+                        help='Perform evaluation only')
+    parser.add_argument('--dist_eval', action='store_true', default=False,
+                        help='Enable distributed evaluation (recommended during training for faster monitoring)')
+    parser.add_argument('--num_workers', default=10, type=int)
+    parser.add_argument('--pin_mem', action='store_true',
+                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
+    parser.set_defaults(pin_mem=True)
+
+    # distributed training parameters
+    parser.add_argument('--world_size', default=1, type=int,
+                        help='number of distributed processes')
+    parser.add_argument('--local_rank', default=-1, type=int)
+    parser.add_argument('--dist_on_itp', action='store_true')
+    parser.add_argument('--dist_url', default='env://',
+                        help='url used to set up distributed training')
+
+    return parser
+
+
+def main(args):
+    misc.init_distributed_mode(args)
+
+    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
+    print("{}".format(args).replace(', ', ',\n'))
+
+    device = torch.device(args.device)
+
+    # fix the seed for reproducibility
+    seed = args.seed + misc.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+
+    cudnn.benchmark = True
+
+    # linear probe: weak augmentation
+    transform_train = transforms.Compose([
+            RandomResizedCrop(224, interpolation=3),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+    transform_val = transforms.Compose([
+            transforms.Resize(256, interpolation=3),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
+    dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val)
+    print(dataset_train)
+    print(dataset_val)
+
+    if True:  # args.distributed:
+        num_tasks = misc.get_world_size()
+        global_rank = misc.get_rank()
+        sampler_train = torch.utils.data.DistributedSampler(
+            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+        )
+        print("Sampler_train = %s" % str(sampler_train))
+        if args.dist_eval:
+            if len(dataset_val) % num_tasks != 0:
+                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
+                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
+                      'equal num of samples per-process.')
+            sampler_val = torch.utils.data.DistributedSampler(
+                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)  # shuffle=True to reduce monitor bias
+        else:
+            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+    else:
+        sampler_train = torch.utils.data.RandomSampler(dataset_train)
+        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+
+    if global_rank == 0 and args.log_dir is not None and not args.eval:
+        os.makedirs(args.log_dir, exist_ok=True)
+        log_writer = SummaryWriter(log_dir=args.log_dir)
+    else:
+        log_writer = None
+
+    data_loader_train = torch.utils.data.DataLoader(
+        dataset_train, sampler=sampler_train,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_mem,
+        drop_last=True,
+    )
+
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset_val, sampler=sampler_val,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_mem,
+        drop_last=False
+    )
+
+    model = models_vit.__dict__[args.model](
+        num_classes=args.nb_classes,
+        global_pool=args.global_pool,
+    )
+
+    if args.finetune and not args.eval:
+        checkpoint = torch.load(args.finetune, map_location='cpu')
+
+        print("Load pre-trained checkpoint from: %s" % args.finetune)
+        checkpoint_model = checkpoint['model']
+        state_dict = model.state_dict()
+        for k in ['head.weight', 'head.bias']:
+            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
+                print(f"Removing key {k} from pretrained checkpoint")
+                del checkpoint_model[k]
+
+        # interpolate position embedding
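+        # (the grid of position embeddings is resized to the current input
+        # resolution here, so checkpoints pre-trained with a different number
+        # of patches can still be loaded)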
+        interpolate_pos_embed(model, checkpoint_model)
+
+        # load pre-trained model
+        msg = model.load_state_dict(checkpoint_model, strict=False)
+        print(msg)
+
+        if args.global_pool:
+            assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
+        else:
+            assert set(msg.missing_keys) == {'head.weight', 'head.bias'}
+
+        # manually initialize fc layer: following MoCo v3
+        trunc_normal_(model.head.weight, std=0.01)
+
+    # for linear probing only
+    # hack: revise model's head with BN (affine=False gives parameter-free
+    # feature normalization in front of the linear classifier)
+    model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head)
+    # freeze all but the head
+    for _, p in model.named_parameters():
+        p.requires_grad = False
+    for _, p in model.head.named_parameters():
+        p.requires_grad = True
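+    # Illustrative check (the names are an assumption for illustration, not
+    # part of the original script): after the loops above, the only trainable
+    # parameters should be those of the head, e.g.
+    #   trainable = [n for n, p in model.named_parameters() if p.requires_grad]
+    #   assert all(n.startswith('head.') for n in trainable)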
+    model.to(device)
+
+    model_without_ddp = model
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    print("Model = %s" % str(model_without_ddp))
+    print('number of params (M): %.2f' % (n_parameters / 1.e6))
+
+    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
+
+    if args.lr is None:  # only base_lr is specified
+        args.lr = args.blr * eff_batch_size / 256
+
+    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
+    print("actual lr: %.2e" % args.lr)
+
+    print("accumulate grad iterations: %d" % args.accum_iter)
+    print("effective batch size: %d" % eff_batch_size)
+
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+
+    optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+    print(optimizer)
+    loss_scaler = NativeScaler()
+
+    criterion = torch.nn.CrossEntropyLoss()
+
+    print("criterion = %s" % str(criterion))
+
+    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
+
+    if args.eval:
+        test_stats = evaluate(data_loader_val, model, device)
+        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
+        exit(0)
+
+    print(f"Start training for {args.epochs} epochs")
+    start_time = time.time()
+    max_accuracy = 0.0
+    for epoch in range(args.start_epoch, args.epochs):
+        if args.distributed:
+            data_loader_train.sampler.set_epoch(epoch)
+        train_stats = train_one_epoch(
+            model, criterion, data_loader_train,
+            optimizer, device, epoch, loss_scaler,
+            max_norm=None,
+            log_writer=log_writer,
+            args=args
+        )
+        if args.output_dir:
+            misc.save_model(
+                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+                loss_scaler=loss_scaler, epoch=epoch)
+
+        test_stats = evaluate(data_loader_val, model, device)
+        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
+        max_accuracy = max(max_accuracy, test_stats["acc1"])
+        print(f'Max accuracy: {max_accuracy:.2f}%')
+
+        if log_writer is not None:
+            log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
+            log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
+            log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)
+
+        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+                     **{f'test_{k}': v for k, v in test_stats.items()},
+                     'epoch': epoch,
+                     'n_parameters': n_parameters}
+
+        if args.output_dir and misc.is_main_process():
+            if log_writer is not None:
+                log_writer.flush()
+            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
+                f.write(json.dumps(log_stats) + "\n")
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+    args = get_args_parser()
+    args = args.parse_args()
+    if args.output_dir:
+        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+    main(args)
diff --git a/models_mae.py b/models_mae.py
index e1a3595..880e28f 100644
--- a/models_mae.py
+++ b/models_mae.py
@@ -55,7 +55,7 @@ class MaskedAutoencoderViT(nn.Module):
             for i in range(decoder_depth)])
 
         self.decoder_norm = norm_layer(decoder_embed_dim)
-        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # encoder to decoder
+        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
         # --------------------------------------------------------------------------
 
         self.norm_pix_loss = norm_pix_loss
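For intuition about the corrected comment: `decoder_pred` projects each decoder token to one flattened pixel patch. A standalone shape check, assuming the MAE defaults `decoder_embed_dim=512`, `patch_size=16`, `in_chans=3`:

```
import torch
import torch.nn as nn

# each decoder token is projected to one flattened patch
decoder_pred = nn.Linear(512, 16 ** 2 * 3, bias=True)

tokens = torch.randn(2, 196, 512)   # (N, L, decoder_embed_dim)
patches = decoder_pred(tokens)      # (N, L, patch_size**2 * in_chans)
print(patches.shape)                # torch.Size([2, 196, 768])
```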
diff --git a/submitit_linprobe.py b/submitit_linprobe.py
new file mode 100644
index 0000000..571186d
--- /dev/null
+++ b/submitit_linprobe.py
@@ -0,0 +1,131 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# A script to run multinode training with submitit.
+# --------------------------------------------------------
+
+import argparse
+import os
+import uuid
+from pathlib import Path
+
+import main_linprobe as classification
+import submitit
+
+
+def parse_args():
+    classification_parser = classification.get_args_parser()
+    parser = argparse.ArgumentParser("Submitit for MAE linear probe", parents=[classification_parser])
+    parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
+    parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
+    parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job, in minutes")
+    parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
+
+    parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit")
+    parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs")
+    parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")
+    return parser.parse_args()
+
+
+def get_shared_folder() -> Path:
+    user = os.getenv("USER")
+    if Path("/checkpoint/").is_dir():
+        p = Path(f"/checkpoint/{user}/experiments")
+        p.mkdir(exist_ok=True)
+        return p
+    raise RuntimeError("No shared folder available")
+
+
+def get_init_file():
+    # Init file must not exist, but its parent dir must exist.
+    os.makedirs(str(get_shared_folder()), exist_ok=True)
+    init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
+    if init_file.exists():
+        os.remove(str(init_file))
+    return init_file
+
+
+class Trainer(object):
+    def __init__(self, args):
+        self.args = args
+
+    def __call__(self):
+        import main_linprobe as classification
+
+        self._setup_gpu_args()
+        classification.main(self.args)
+
+    def checkpoint(self):
+        import os
+        import submitit
+
+        self.args.dist_url = get_init_file().as_uri()
+        checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth")
+        if os.path.exists(checkpoint_file):
+            self.args.resume = checkpoint_file
+        print("Requeuing ", self.args)
+        empty_trainer = type(self)(self.args)
+        return submitit.helpers.DelayedSubmission(empty_trainer)
+
+    def _setup_gpu_args(self):
+        import submitit
+        from pathlib import Path
+
+        job_env = submitit.JobEnvironment()
+        self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
+        self.args.log_dir = self.args.output_dir
+        self.args.gpu = job_env.local_rank
+        self.args.rank = job_env.global_rank
+        self.args.world_size = job_env.num_tasks
+        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
+
+
+def main():
+    args = parse_args()
+    if args.job_dir == "":
+        args.job_dir = get_shared_folder() / "%j"
+
+    # Note that the folder will depend on the job_id, to easily track experiments
+    executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
+
+    num_gpus_per_node = args.ngpus
+    nodes = args.nodes
+    timeout_min = args.timeout
+
+    partition = args.partition
+    kwargs = {}
+    if args.use_volta32:
+        kwargs['slurm_constraint'] = 'volta32gb'
+    if args.comment:
+        kwargs['slurm_comment'] = args.comment
+
+    executor.update_parameters(
+        mem_gb=40 * num_gpus_per_node,
+        gpus_per_node=num_gpus_per_node,
+        tasks_per_node=num_gpus_per_node,  # one task per GPU
+        cpus_per_task=10,
+        nodes=nodes,
+        timeout_min=timeout_min,
+        # Below are cluster dependent parameters
+        slurm_partition=partition,
+        slurm_signal_delay_s=120,
+        **kwargs
+    )
+
+    executor.update_parameters(name="mae")
+
+    args.dist_url = get_init_file().as_uri()
+    args.output_dir = args.job_dir
+
+    trainer = Trainer(args)
+    job = executor.submit(trainer)
+
+    # print("Submitted job_id:", job.job_id)
+    print(job.job_id)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/util/crop.py b/util/crop.py
new file mode 100644
index 0000000..fcb2612
--- /dev/null
+++ b/util/crop.py
@@ -0,0 +1,42 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+
+from torchvision import transforms
+from torchvision.transforms import functional as F
+
+
+class RandomResizedCrop(transforms.RandomResizedCrop):
+    """
+    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
+    This may lead to results different from torchvision's version.
+    Following BYOL's TF code:
+    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
+    """
+    @staticmethod
+    def get_params(img, scale, ratio):
+        width, height = F._get_image_size(img)
+        area = height * width
+
+        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
+        log_ratio = torch.log(torch.tensor(ratio))
+        aspect_ratio = torch.exp(
+            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
+        ).item()
+
+        w = int(round(math.sqrt(target_area * aspect_ratio)))
+        h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+        # clamp the sampled crop to the image instead of re-sampling
+        w = min(w, width)
+        h = min(h, height)
+
+        i = torch.randint(0, height - h + 1, size=(1,)).item()
+        j = torch.randint(0, width - w + 1, size=(1,)).item()
+
+        return i, j, h, w
\ No newline at end of file
diff --git a/util/lars.py b/util/lars.py
new file mode 100644
index 0000000..509c5f6
--- /dev/null
+++ b/util/lars.py
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# LARS optimizer, implementation from MoCo v3:
+# https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+
+import torch
+
+
+class LARS(torch.optim.Optimizer):
+    """
+    LARS optimizer: no rate scaling or weight decay for parameters <= 1D.
+    """
+    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self):
+        for g in self.param_groups:
+            for p in g['params']:
+                dp = p.grad
+
+                if dp is None:
+                    continue
+
+                if p.ndim > 1:  # if not normalization gamma/beta or bias
+                    dp = dp.add(p, alpha=g['weight_decay'])
+                    param_norm = torch.norm(p)
+                    update_norm = torch.norm(dp)
+                    one = torch.ones_like(param_norm)
+                    # trust ratio: scale the update by ||w|| / ||grad|| when both are non-zero
+                    q = torch.where(param_norm > 0.,
+                                    torch.where(update_norm > 0,
+                                                (g['trust_coefficient'] * param_norm / update_norm), one),
+                                    one)
+                    dp = dp.mul(q)
+
+                param_state = self.state[p]
+                if 'mu' not in param_state:
+                    param_state['mu'] = torch.zeros_like(p)
+                mu = param_state['mu']
+                mu.mul_(g['momentum']).add_(dp)
+                p.add_(mu, alpha=-g['lr'])
\ No newline at end of file
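
A minimal self-contained sketch of the linear-probing recipe the new files implement, with a toy feature extractor standing in for a pre-trained ViT (only `util/lars.py` from this patch is imported; all other names are illustrative):

```
import torch
import torch.nn as nn

from util.lars import LARS

# toy stand-in for a pre-trained encoder; any feature extractor with a
# linear classification head follows the same pattern
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 192))
head = nn.Linear(192, 10)
nn.init.trunc_normal_(head.weight, std=0.01)

# as in main_linprobe.py: parameter-free BatchNorm in front of the head
head = nn.Sequential(nn.BatchNorm1d(192, affine=False, eps=1e-6), head)
model = nn.Sequential(backbone, head)

# freeze everything, then re-enable only the head
for p in model.parameters():
    p.requires_grad = False
for p in head.parameters():
    p.requires_grad = True

# LARS over the head parameters only, weight decay 0
optimizer = LARS(head.parameters(), lr=6.4, weight_decay=0.0)

x = torch.randn(8, 3, 32, 32)
loss = nn.CrossEntropyLoss()(model(x), torch.randint(0, 10, (8,)))
loss.backward()
optimizer.step()
```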