import torch as t import torchvision.models as model import torch.nn.functional as F from util.largestConnectComponent import largestConnectComponent """ feat_map = t.ones(5, 2, 4) print(feat_map) A = feat_map.sum(dim=0) print(A) a = A.mean(dim=[0, 1]) print(float(a)) """ def select_aggregate(feat_map): A = t.sum(feat_map, dim=[0]) a = t.mean(A, dim=[0, 1]).float() tmp = t.ones(A.shape) tmp[A