decoupled image processing from the main flow (#160)
parent
5bb6543346
commit
60d796825e
|
@ -206,6 +206,21 @@ class GroundingDINO(nn.Module):
|
||||||
nn.init.xavier_uniform_(proj[0].weight, gain=1)
|
nn.init.xavier_uniform_(proj[0].weight, gain=1)
|
||||||
nn.init.constant_(proj[0].bias, 0)
|
nn.init.constant_(proj[0].bias, 0)
|
||||||
|
|
||||||
|
def set_image_tensor(self, samples: NestedTensor):
|
||||||
|
if isinstance(samples, (list, torch.Tensor)):
|
||||||
|
samples = nested_tensor_from_tensor_list(samples)
|
||||||
|
self.features, self.poss = self.backbone(samples)
|
||||||
|
|
||||||
|
def unset_image_tensor(self):
|
||||||
|
if hasattr(self, 'features'):
|
||||||
|
del self.features
|
||||||
|
if hasattr(self,'poss'):
|
||||||
|
del self.poss
|
||||||
|
|
||||||
|
def set_image_features(self, features , poss):
|
||||||
|
self.features = features
|
||||||
|
self.poss = poss
|
||||||
|
|
||||||
def init_ref_points(self, use_num_queries):
|
def init_ref_points(self, use_num_queries):
|
||||||
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
|
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
|
||||||
|
|
||||||
|
@ -282,14 +297,14 @@ class GroundingDINO(nn.Module):
|
||||||
}
|
}
|
||||||
|
|
||||||
# import ipdb; ipdb.set_trace()
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
if isinstance(samples, (list, torch.Tensor)):
|
if isinstance(samples, (list, torch.Tensor)):
|
||||||
samples = nested_tensor_from_tensor_list(samples)
|
samples = nested_tensor_from_tensor_list(samples)
|
||||||
features, poss = self.backbone(samples)
|
if not hasattr(self, 'features') or not hasattr(self, 'poss'):
|
||||||
|
self.set_image_tensor(samples)
|
||||||
|
|
||||||
srcs = []
|
srcs = []
|
||||||
masks = []
|
masks = []
|
||||||
for l, feat in enumerate(features):
|
for l, feat in enumerate(self.features):
|
||||||
src, mask = feat.decompose()
|
src, mask = feat.decompose()
|
||||||
srcs.append(self.input_proj[l](src))
|
srcs.append(self.input_proj[l](src))
|
||||||
masks.append(mask)
|
masks.append(mask)
|
||||||
|
@ -298,7 +313,7 @@ class GroundingDINO(nn.Module):
|
||||||
_len_srcs = len(srcs)
|
_len_srcs = len(srcs)
|
||||||
for l in range(_len_srcs, self.num_feature_levels):
|
for l in range(_len_srcs, self.num_feature_levels):
|
||||||
if l == _len_srcs:
|
if l == _len_srcs:
|
||||||
src = self.input_proj[l](features[-1].tensors)
|
src = self.input_proj[l](self.features[-1].tensors)
|
||||||
else:
|
else:
|
||||||
src = self.input_proj[l](srcs[-1])
|
src = self.input_proj[l](srcs[-1])
|
||||||
m = samples.mask
|
m = samples.mask
|
||||||
|
@ -306,11 +321,11 @@ class GroundingDINO(nn.Module):
|
||||||
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
|
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
|
||||||
srcs.append(src)
|
srcs.append(src)
|
||||||
masks.append(mask)
|
masks.append(mask)
|
||||||
poss.append(pos_l)
|
self.poss.append(pos_l)
|
||||||
|
|
||||||
input_query_bbox = input_query_label = attn_mask = dn_meta = None
|
input_query_bbox = input_query_label = attn_mask = dn_meta = None
|
||||||
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
|
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
|
||||||
srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
|
srcs, masks, input_query_bbox, self.poss, input_query_label, attn_mask, text_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
# deformable-detr-like anchor update
|
# deformable-detr-like anchor update
|
||||||
|
@ -344,7 +359,9 @@ class GroundingDINO(nn.Module):
|
||||||
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
|
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
|
||||||
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
|
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
|
||||||
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
|
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
|
||||||
|
unset_image_tensor = kw.get('unset_image_tensor', True)
|
||||||
|
if unset_image_tensor:
|
||||||
|
self.unset_image_tensor() ## If necessary
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@torch.jit.unused
|
@torch.jit.unused
|
||||||
|
@ -392,3 +409,4 @@ def build_groundingdino(args):
|
||||||
)
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue