decoupled image processing from the main flow ()

pull/191/head
jishnujp-vp 2023-07-23 10:38:59 +05:30 committed by GitHub
parent 5bb6543346
commit 60d796825e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 25 additions and 7 deletions
groundingdino/models/GroundingDINO

View File

@ -206,6 +206,21 @@ class GroundingDINO(nn.Module):
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
def set_image_tensor(self, samples: NestedTensor):
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
self.features, self.poss = self.backbone(samples)
def unset_image_tensor(self):
if hasattr(self, 'features'):
del self.features
if hasattr(self,'poss'):
del self.poss
def set_image_features(self, features , poss):
self.features = features
self.poss = poss
def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
@ -282,14 +297,14 @@ class GroundingDINO(nn.Module):
}
# import ipdb; ipdb.set_trace()
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
features, poss = self.backbone(samples)
if not hasattr(self, 'features') or not hasattr(self, 'poss'):
self.set_image_tensor(samples)
srcs = []
masks = []
for l, feat in enumerate(features):
for l, feat in enumerate(self.features):
src, mask = feat.decompose()
srcs.append(self.input_proj[l](src))
masks.append(mask)
@ -298,7 +313,7 @@ class GroundingDINO(nn.Module):
_len_srcs = len(srcs)
for l in range(_len_srcs, self.num_feature_levels):
if l == _len_srcs:
src = self.input_proj[l](features[-1].tensors)
src = self.input_proj[l](self.features[-1].tensors)
else:
src = self.input_proj[l](srcs[-1])
m = samples.mask
@ -306,11 +321,11 @@ class GroundingDINO(nn.Module):
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src)
masks.append(mask)
poss.append(pos_l)
self.poss.append(pos_l)
input_query_bbox = input_query_label = attn_mask = dn_meta = None
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
srcs, masks, input_query_bbox, self.poss, input_query_label, attn_mask, text_dict
)
# deformable-detr-like anchor update
@ -344,7 +359,9 @@ class GroundingDINO(nn.Module):
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
unset_image_tensor = kw.get('unset_image_tensor', True)
if unset_image_tensor:
self.unset_image_tensor() ## If necessary
return out
@torch.jit.unused
@ -392,3 +409,4 @@ def build_groundingdino(args):
)
return model