decoupled image processing from the main flow (#160)
parent
5bb6543346
commit
60d796825e
groundingdino/models/GroundingDINO
|
@ -206,6 +206,21 @@ class GroundingDINO(nn.Module):
|
|||
nn.init.xavier_uniform_(proj[0].weight, gain=1)
|
||||
nn.init.constant_(proj[0].bias, 0)
|
||||
|
||||
def set_image_tensor(self, samples: NestedTensor):
    """Run the backbone on ``samples`` and cache its outputs on the instance.

    Args:
        samples: The image batch. A ``list`` or ``torch.Tensor`` is first
            converted with ``nested_tensor_from_tensor_list``; any other
            value is handed to the backbone unchanged.

    The two values returned by ``self.backbone`` are stored as
    ``self.features`` and ``self.poss``. ``forward`` checks for these two
    attributes and only calls this method when one of them is missing.
    """
    needs_wrapping = isinstance(samples, (list, torch.Tensor))
    batch = nested_tensor_from_tensor_list(samples) if needs_wrapping else samples
    backbone_out = self.backbone(batch)
    # Unpack in the same order the backbone returns them: features, then
    # positional encodings.
    self.features, self.poss = backbone_out
|
||||
|
||||
def unset_image_tensor(self):
    """Drop the cached backbone outputs, if any.

    Deletes the ``features`` and ``poss`` attributes when they exist.
    Attributes that are absent are skipped, so calling this on an instance
    with nothing cached (or calling it twice) does nothing.
    """
    for cached_name in ("features", "poss"):
        if hasattr(self, cached_name):
            delattr(self, cached_name)
|
||||
|
||||
def set_image_features(self, features , poss):
    """Install externally computed backbone outputs on this instance.

    Args:
        features: Value to store as ``self.features``.
        poss: Value to store as ``self.poss``.

    Both objects are stored by reference; nothing is copied or validated.

    NOTE(review): ``forward`` contains a ``self.poss.append(pos_l)`` call, so
    the ``poss`` object passed here may be mutated in place later. Confirm
    whether callers need to pass a copy.
    """
    self.features, self.poss = features, poss
|
||||
|
||||
def init_ref_points(self, use_num_queries):
|
||||
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
|
||||
|
||||
|
@ -282,14 +297,14 @@ class GroundingDINO(nn.Module):
|
|||
}
|
||||
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
if isinstance(samples, (list, torch.Tensor)):
|
||||
samples = nested_tensor_from_tensor_list(samples)
|
||||
features, poss = self.backbone(samples)
|
||||
if not hasattr(self, 'features') or not hasattr(self, 'poss'):
|
||||
self.set_image_tensor(samples)
|
||||
|
||||
srcs = []
|
||||
masks = []
|
||||
for l, feat in enumerate(features):
|
||||
for l, feat in enumerate(self.features):
|
||||
src, mask = feat.decompose()
|
||||
srcs.append(self.input_proj[l](src))
|
||||
masks.append(mask)
|
||||
|
@ -298,7 +313,7 @@ class GroundingDINO(nn.Module):
|
|||
_len_srcs = len(srcs)
|
||||
for l in range(_len_srcs, self.num_feature_levels):
|
||||
if l == _len_srcs:
|
||||
src = self.input_proj[l](features[-1].tensors)
|
||||
src = self.input_proj[l](self.features[-1].tensors)
|
||||
else:
|
||||
src = self.input_proj[l](srcs[-1])
|
||||
m = samples.mask
|
||||
|
@ -306,11 +321,11 @@ class GroundingDINO(nn.Module):
|
|||
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
|
||||
srcs.append(src)
|
||||
masks.append(mask)
|
||||
poss.append(pos_l)
|
||||
self.poss.append(pos_l)
|
||||
|
||||
input_query_bbox = input_query_label = attn_mask = dn_meta = None
|
||||
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
|
||||
srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
|
||||
srcs, masks, input_query_bbox, self.poss, input_query_label, attn_mask, text_dict
|
||||
)
|
||||
|
||||
# deformable-detr-like anchor update
|
||||
|
@ -344,7 +359,9 @@ class GroundingDINO(nn.Module):
|
|||
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
|
||||
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
|
||||
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
|
||||
|
||||
unset_image_tensor = kw.get('unset_image_tensor', True)
|
||||
if unset_image_tensor:
|
||||
self.unset_image_tensor() ## If necessary
|
||||
return out
|
||||
|
||||
@torch.jit.unused
|
||||
|
@ -392,3 +409,4 @@ def build_groundingdino(args):
|
|||
)
|
||||
|
||||
return model
|
||||
|
||||
|
|
Loading…
Reference in New Issue