relative patch location
parent
7a4de682a8
commit
972e5107ac
|
@ -39,6 +39,10 @@ test_pipeline = [
|
|||
dict(type='Resize', size=292),
|
||||
dict(type='CenterCrop', size=255),
|
||||
]
|
||||
format_pipeline = [
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
]
|
||||
data = dict(
|
||||
imgs_per_gpu=64, # 64 x 8 = 512
|
||||
workers_per_gpu=2,
|
||||
|
@ -47,17 +51,21 @@ data = dict(
|
|||
data_source=dict(
|
||||
list_file=data_train_list, root=data_train_root,
|
||||
**data_source_cfg),
|
||||
pipeline=train_pipeline),
|
||||
pipeline=train_pipeline,
|
||||
format_pipeline=format_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_source=dict(
|
||||
list_file=data_test_list, root=data_test_root, **data_source_cfg),
|
||||
pipeline=test_pipeline))
|
||||
pipeline=test_pipeline,
|
||||
format_pipeline=format_pipeline))
|
||||
# optimizer
|
||||
optimizer = dict(
|
||||
type='SGD', lr=0.2, momentum=0.9, weight_decay=0.0001,
|
||||
nesterov=False,
|
||||
paramwise_options={'\Ahead.': dict(weight_decay=0.0005)})
|
||||
paramwise_options={
|
||||
'\Aneck.': dict(weight_decay=0.0005),
|
||||
'\Ahead.': dict(weight_decay=0.0005)})
|
||||
# learning policy
|
||||
lr_config = dict(
|
||||
policy='step',
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
from openselfsup.utils import build_from_cfg
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
from torchvision.transforms import Compose, RandomCrop
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
from .registry import DATASETS
|
||||
from .registry import DATASETS, PIPELINES
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
|
@ -15,19 +17,13 @@ def image_to_patches(img):
|
|||
w_grid = w // split_per_side
|
||||
h_patch = h_grid - patch_jitter
|
||||
w_patch = w_grid - patch_jitter
|
||||
|
||||
assert h_patch > 0 and w_patch > 0
|
||||
patches = []
|
||||
|
||||
for i in range(split_per_side):
|
||||
for j in range(split_per_side):
|
||||
p = TF.crop(img, i * h_grid, j * w_grid, h_grid, w_grid)
|
||||
if h_patch < h_grid or w_patch < w_grid:
|
||||
p = transforms.RandomCrop((h_patch, w_patch))(p)
|
||||
p = TF.to_tensor(p)
|
||||
p = TF.normalize(p,
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
p = RandomCrop((h_patch, w_patch))(p)
|
||||
patches.append(p)
|
||||
|
||||
return patches
|
||||
|
||||
|
||||
|
@ -36,8 +32,10 @@ class RelativeLocDataset(BaseDataset):
|
|||
"""Dataset for relative patch location
|
||||
"""
|
||||
|
||||
def __init__(self, data_source, pipeline):
|
||||
def __init__(self, data_source, pipeline, format_pipeline):
|
||||
super(RelativeLocDataset, self).__init__(data_source, pipeline)
|
||||
format_pipeline = [build_from_cfg(p, PIPELINES) for p in format_pipeline]
|
||||
self.format_pipeline = Compose(format_pipeline)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img = self.data_source.get_sample(idx)
|
||||
|
@ -47,6 +45,7 @@ class RelativeLocDataset(BaseDataset):
|
|||
type(img))
|
||||
img = self.pipeline(img)
|
||||
patches = image_to_patches(img)
|
||||
patches = [self.format_pipeline(p) for p in patches]
|
||||
perms = []
|
||||
[perms.append(torch.cat((patches[i], patches[4]), dim=0)) for i in range(9) if i != 4]
|
||||
patch_labels = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
|
|
Loading…
Reference in New Issue