# mirror of https://github.com/PyRetri/PyRetri.git
# synced 2025-06-03 14:49:50 +08:00
# -*- coding: utf-8 -*-
import queue
from typing import Dict, Optional

import torch

from ..aggregators_base import AggregatorBase
from ...registry import AGGREGATORS

@AGGREGATORS.register
class SCDA(AggregatorBase):
    """
    Selective Convolutional Descriptor Aggregation for Fine-Grained Image Retrieval.
    c.f. http://www.weixiushen.com/publication/tip17SCDA.pdf

    Aggregates a 4-D feature map by zeroing background activations
    (keeping only the largest connected component of the above-mean
    channel-sum mask), then concatenating global average pooling and
    global max pooling of the masked features.
    """
    default_hyper_params = dict()

    def __init__(self, hps: Optional[Dict] = None):
        """
        Args:
            hps (dict): default hyper parameters in a dict (keys, values).
        """
        super(SCDA, self).__init__(hps)
        # Print the "skip aggregation" notice only once across calls.
        self.first_show = True

    def bfs(self, x: int, y: int, mask: torch.Tensor, cc_map: torch.Tensor, cc_id: int) -> int:
        """
        Flood-fill the 4-connected component containing (x, y), labelling
        every visited cell of cc_map with cc_id.

        Args:
            x (int): row index of the seed cell.
            y (int): column index of the seed cell.
            mask (torch.Tensor): 2-D binary mask (1 = foreground).
            cc_map (torch.Tensor): 2-D component-id map (0 = unlabelled);
                updated in place.
            cc_id (int): id assigned to this component.

        Returns:
            int: number of cells in the component.
        """
        dirs = [[1, 0], [-1, 0], [0, 1], [0, -1]]
        # Despite the name, this is a stack-based flood fill (DFS order);
        # visit order does not affect the labelling, only coverage.
        stack = [(x, y)]

        ret = 1
        cc_map[x][y] = cc_id

        while stack:
            x, y = stack.pop()

            for (dx, dy) in dirs:
                new_x = x + dx
                new_y = y + dy
                if 0 <= new_x < mask.shape[0] and 0 <= new_y < mask.shape[1]:
                    if mask[new_x][new_y] == 1 and cc_map[new_x][new_y] == 0:
                        stack.append((new_x, new_y))
                        ret += 1
                        cc_map[new_x][new_y] = cc_id
        return ret

    def find_max_cc(self, mask: torch.Tensor) -> torch.Tensor:
        """
        Find the largest connected component of the mask.

        Args:
            mask (torch.Tensor): the original mask, shape (N, 1, H, W).

        Returns:
            mask (torch.Tensor): the mask only containing the maximum connected component.
        """
        assert mask.ndim == 4
        assert mask.shape[1] == 1
        mask = mask[:, 0, :, :]
        for i in range(mask.shape[0]):
            m = mask[i]
            # Keep the component map on the same device as the mask so the
            # boolean-mask assignment below does not mix devices.
            cc_map = torch.zeros(m.shape, device=m.device)
            cc_num = list()

            for x in range(m.shape[0]):
                for y in range(m.shape[1]):
                    if m[x][y] == 1 and cc_map[x][y] == 0:
                        cc_id = len(cc_num) + 1
                        cc_num.append(self.bfs(x, y, m, cc_map, cc_id))

            # An all-zero mask has no connected component; max() on an empty
            # list would raise ValueError, so leave the empty mask as-is.
            if cc_num:
                max_cc_id = cc_num.index(max(cc_num)) + 1
                m[cc_map != max_cc_id] = 0
        mask = mask[:, None, :, :]
        return mask

    def __call__(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Aggregate each 4-D feature map into a concatenated GAP+GMP descriptor
        restricted to the largest above-mean activation region. 2-D (fc)
        features are passed through unchanged.

        Args:
            features (dict): feature maps keyed by name.

        Returns:
            dict: aggregated features; 4-D entries are renamed with a
                "_SCDA" suffix.
        """
        ret = dict()
        for key in features:
            fea = features[key]
            if fea.ndimension() == 4:
                # Channel-sum activation map, binarized at its spatial mean.
                # NOTE: torch's canonical kwarg is keepdim (not numpy's keepdims).
                mask = fea.sum(dim=1, keepdim=True)
                thres = mask.mean(dim=(2, 3), keepdim=True)
                mask[mask <= thres] = 0
                mask[mask > thres] = 1
                mask = self.find_max_cc(mask)
                fea = fea * mask

                gap = fea.mean(dim=(2, 3))
                gmp, _ = fea.max(dim=3)
                gmp, _ = gmp.max(dim=2)

                ret[key + "_{}".format(self.__class__.__name__)] = torch.cat([gap, gmp], dim=1)
            else:
                # In case of fc feature.
                assert fea.ndimension() == 2
                if self.first_show:
                    print("[SCDA Aggregator]: find 2-dimension feature map, skip aggregation")
                    self.first_show = False
                ret[key] = fea
        return ret
|