relative patch location

pull/13/head
Jiahao000 2020-07-09 22:32:11 +08:00
parent 7a4de682a8
commit 972e5107ac
2 changed files with 21 additions and 14 deletions

View File

@@ -39,6 +39,10 @@ test_pipeline = [
dict(type='Resize', size=292),
dict(type='CenterCrop', size=255),
]
format_pipeline = [
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
]
data = dict(
imgs_per_gpu=64, # 64 x 8 = 512
workers_per_gpu=2,
@@ -47,17 +51,21 @@ data = dict(
data_source=dict(
list_file=data_train_list, root=data_train_root,
**data_source_cfg),
pipeline=train_pipeline),
pipeline=train_pipeline,
format_pipeline=format_pipeline),
val=dict(
type=dataset_type,
data_source=dict(
list_file=data_test_list, root=data_test_root, **data_source_cfg),
pipeline=test_pipeline))
pipeline=test_pipeline,
format_pipeline=format_pipeline))
# optimizer
optimizer = dict(
type='SGD', lr=0.2, momentum=0.9, weight_decay=0.0001,
nesterov=False,
paramwise_options={'\Ahead.': dict(weight_decay=0.0005)})
paramwise_options={
'\Aneck.': dict(weight_decay=0.0005),
'\Ahead.': dict(weight_decay=0.0005)})
# learning policy
lr_config = dict(
policy='step',

View File

@@ -1,9 +1,11 @@
from openselfsup.utils import build_from_cfg
import torch
from PIL import Image
import torchvision.transforms as transforms
from torchvision.transforms import Compose, RandomCrop
import torchvision.transforms.functional as TF
from .registry import DATASETS
from .registry import DATASETS, PIPELINES
from .base import BaseDataset
@@ -15,19 +17,13 @@ def image_to_patches(img):
w_grid = w // split_per_side
h_patch = h_grid - patch_jitter
w_patch = w_grid - patch_jitter
assert h_patch > 0 and w_patch > 0
patches = []
for i in range(split_per_side):
for j in range(split_per_side):
p = TF.crop(img, i * h_grid, j * w_grid, h_grid, w_grid)
if h_patch < h_grid or w_patch < w_grid:
p = transforms.RandomCrop((h_patch, w_patch))(p)
p = TF.to_tensor(p)
p = TF.normalize(p,
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
p = RandomCrop((h_patch, w_patch))(p)
patches.append(p)
return patches
@@ -36,8 +32,10 @@ class RelativeLocDataset(BaseDataset):
"""Dataset for relative patch location
"""
def __init__(self, data_source, pipeline):
def __init__(self, data_source, pipeline, format_pipeline):
super(RelativeLocDataset, self).__init__(data_source, pipeline)
format_pipeline = [build_from_cfg(p, PIPELINES) for p in format_pipeline]
self.format_pipeline = Compose(format_pipeline)
def __getitem__(self, idx):
img = self.data_source.get_sample(idx)
@@ -47,6 +45,7 @@ class RelativeLocDataset(BaseDataset):
type(img))
img = self.pipeline(img)
patches = image_to_patches(img)
patches = [self.format_pipeline(p) for p in patches]
perms = []
[perms.append(torch.cat((patches[i], patches[4]), dim=0)) for i in range(9) if i != 4]
patch_labels = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7])