# Mirror of https://github.com/FoundationVision/GLEE
"""
An example config file to train an ImageNet classifier with detectron2.
Model and dataloader both come from torchvision.
This shows how to use detectron2 as a general engine for any new models and tasks.

To run, use the following command:

python tools/lazyconfig_train_net.py --config-file configs/Misc/torchvision_imagenet_R_50.py \
    --num-gpus 8 dataloader.train.dataset.root=/path/to/imagenet/

"""
|
|
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
from omegaconf import OmegaConf
|
|
import torchvision
|
|
from torchvision.transforms import transforms as T
|
|
from torchvision.models.resnet import ResNet, Bottleneck
|
|
from fvcore.common.param_scheduler import MultiStepParamScheduler
|
|
|
|
from detectron2.solver import WarmupParamScheduler
|
|
from detectron2.solver.build import get_default_optimizer_params
|
|
from detectron2.config import LazyCall as L
|
|
from detectron2.model_zoo import get_config
|
|
from detectron2.data.samplers import TrainingSampler, InferenceSampler
|
|
from detectron2.evaluation import DatasetEvaluator
|
|
from detectron2.utils import comm
|
|
|
|
|
|
"""
Note: Here we put reusable code (models, evaluation, data) together with configs just as a
proof-of-concept, to easily demonstrate what's needed to train an ImageNet classifier in detectron2.
Writing code in configs offers extreme flexibility but is often not a good engineering practice.
In practice, you might want to put code in your project and import it instead.
"""
|
|
|
|
|
|
def build_data_loader(dataset, batch_size, num_workers, training=True):
    """Create a ``torch.utils.data.DataLoader`` over ``dataset``.

    Args:
        dataset: a sized dataset; ``len(dataset)`` is handed to the sampler.
        batch_size: batch size passed straight to the ``DataLoader``.
        num_workers: number of loader worker processes.
        training: selects ``TrainingSampler`` when true, ``InferenceSampler`` otherwise.

    Returns:
        The configured ``DataLoader`` (memory pinning is always enabled).
    """
    if training:
        sampler_cls = TrainingSampler
    else:
        sampler_cls = InferenceSampler
    sampler = sampler_cls(len(dataset))

    return torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )
|
|
|
|
|
|
class ClassificationNet(nn.Module):
    """Wrap a classification model so one module serves both training and evaluation.

    ``forward`` takes an ``(image, label)`` pair. In training mode it returns the
    cross-entropy loss between the wrapped model's predictions and ``label``; in
    eval mode it returns the predictions themselves and ignores ``label``.
    """

    def __init__(self, model: nn.Module):
        """Store ``model``, the network that maps an image batch to predictions."""
        super().__init__()
        self.model = model

    @property
    def device(self):
        """Device of the wrapped model's first parameter, or CPU if it has none."""
        # Pull only the first parameter instead of materialising the whole
        # parameter list on every forward pass.
        first_param = next(self.model.parameters(), None)
        if first_param is None:
            # A parameter-free model has no device of its own; default to CPU
            # rather than failing on an empty parameter list.
            return torch.device("cpu")
        return first_param.device

    def forward(self, inputs):
        """Run the wrapped model on ``inputs``.

        Args:
            inputs: an ``(image, label)`` pair; both are moved to ``self.device``
                as needed.

        Returns:
            The cross-entropy loss when ``self.training`` is true, otherwise the
            wrapped model's output.
        """
        image, label = inputs
        pred = self.model(image.to(self.device))
        if not self.training:
            return pred
        return F.cross_entropy(pred, label.to(self.device))
|
|
|
|
|
|
class ClassificationAcc(DatasetEvaluator):
    """Evaluator that reports top-1 classification accuracy.

    Counts are accumulated per process in ``process`` and combined across
    processes with ``comm.all_gather`` in ``evaluate``.
    """

    def reset(self):
        """Zero the running counts of correct and seen samples."""
        self.corr = 0
        self.total = 0

    def process(self, inputs, outputs):
        """Accumulate counts for one batch.

        Args:
            inputs: an ``(image, label)`` pair; only ``label`` is used.
            outputs: model predictions; the arg-max over dim 1 is taken as the
                predicted class.
        """
        _, label = inputs
        predicted = outputs.argmax(dim=1).cpu()
        self.corr += (predicted == label.cpu()).sum().item()
        self.total += len(label)

    def evaluate(self):
        """Combine counts from every process and return ``{"accuracy": value}``.

        Returns NaN accuracy when no sample was processed anywhere, instead of
        raising ``ZeroDivisionError`` at the end of an evaluation round.
        """
        gathered = comm.all_gather([self.corr, self.total])
        corr = sum(c for c, _ in gathered)
        total = sum(t for _, t in gathered)
        if total == 0:
            return {"accuracy": float("nan")}
        return {"accuracy": corr / total}
|
|
|
|
|
|
# --- End of code that could be in a project and be imported
|
|
|
|
|
|
# Data config: lazily-described train/test loaders plus the evaluator, grouped
# under a single ``dataloader`` node.
dataloader = OmegaConf.create()

dataloader.train = L(build_data_loader)(
    dataset=L(torchvision.datasets.ImageNet)(
        root="/path/to/imagenet",  # placeholder path; meant to be overridden
        split="train",
        transform=L(T.Compose)(
            transforms=[
                L(T.RandomResizedCrop)(size=224),
                L(T.RandomHorizontalFlip)(),
                # Declared lazily like its neighbours so every transform is a
                # plain config node rather than an already-built object.
                L(T.ToTensor)(),
                L(T.Normalize)(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        ),
    ),
    # 32 per loader; presumably a total batch of 256 split over 8 GPUs -- confirm
    # against the launch command before changing the GPU count.
    batch_size=256 // 8,
    num_workers=4,
    training=True,
)

dataloader.test = L(build_data_loader)(
    dataset=L(torchvision.datasets.ImageNet)(
        # Interpolation: reuse whatever root the train dataset is configured with.
        root="${...train.dataset.root}",
        split="val",
        transform=L(T.Compose)(
            transforms=[
                L(T.Resize)(size=256),
                L(T.CenterCrop)(size=224),
                L(T.ToTensor)(),
                L(T.Normalize)(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        ),
    ),
    batch_size=256 // 8,
    num_workers=4,
    training=False,
)

dataloader.evaluator = L(ClassificationAcc)()
|
|
|
|
# Model config: a torchvision ResNet with Bottleneck blocks and [3, 4, 6, 3]
# layers, wrapped in ClassificationNet. The inner ResNet is described lazily
# too, so merely importing this config does not construct the network and its
# arguments stay visible as config keys.
# NOTE(review): the original built the ResNet eagerly via ``(ResNet)(...)``;
# confirm nothing depended on receiving an already-constructed module here.
model = L(ClassificationNet)(
    model=L(ResNet)(block=Bottleneck, layers=[3, 4, 6, 3], zero_init_residual=True)
)
|
|
|
|
|
|
# Optimizer config: SGD with momentum and weight decay.
# ``params`` is itself a lazy call made with no arguments here; presumably the
# training script supplies the model before instantiating it -- verify in the caller.
optimizer = L(torch.optim.SGD)(
    params=L(get_default_optimizer_params)(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
)
|
|
|
|
# Learning-rate multiplier: a multi-step schedule wrapped in a warmup.
# NOTE(review): the last milestone (100) appears to act as the total schedule
# length, with drops at 30/60/90 -- confirm against MultiStepParamScheduler.
lr_multiplier = L(WarmupParamScheduler)(
    scheduler=L(MultiStepParamScheduler)(
        values=[1.0, 0.1, 0.01, 0.001],
        milestones=[30, 60, 90, 100],
    ),
    warmup_length=1 / 100,
    warmup_factor=0.1,
)
|
|
|
|
|
|
# Training-loop settings: start from detectron2's common defaults, then override.
train = get_config("common/train.py").train

# No initial checkpoint is loaded.
train.init_checkpoint = None

# Presumably 100 epochs over 1281167 training images at a total batch size of
# 256 -- keep in sync with the dataloader batch size and GPU count.
train.max_iter = 100 * 1281167 // 256
|