mirror of https://github.com/RE-OWOD/RE-OWOD
Add files via upload
parent
3b865248f1
commit
9623f793bd
|
@ -149,6 +149,8 @@ class GeneralizedRCNN(nn.Module):
|
|||
return self.inference(batched_inputs)
|
||||
|
||||
images = self.preprocess_image(batched_inputs)
|
||||
image_id = [x["image_id"] for x in batched_inputs]
|
||||
ori_image = [x['image'] for x in batched_inputs]
|
||||
if "instances" in batched_inputs[0]:
|
||||
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
|
||||
else:
|
||||
|
@ -157,17 +159,20 @@ class GeneralizedRCNN(nn.Module):
|
|||
features = self.backbone(images.tensor)
|
||||
|
||||
if self.proposal_generator:
|
||||
proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
|
||||
proposals, _ = self.proposal_generator(images, features, gt_instances, False)
|
||||
else:
|
||||
assert "proposals" in batched_inputs[0]
|
||||
proposals = [x["proposals"].to(self.device) for x in batched_inputs]
|
||||
proposal_losses = {}
|
||||
|
||||
_, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
|
||||
if self.vis_period > 0:
|
||||
storage = get_event_storage()
|
||||
if storage.iter % self.vis_period == 0:
|
||||
self.visualize_training(batched_inputs, proposals)
|
||||
_, detector_losses, unk_sel_gt = self.roi_heads(images, features, proposals, gt_instances, image_id, ori_image)
|
||||
|
||||
_, proposal_losses = self.proposal_generator(images, features, gt_instances, True, unk_sel_gt)
|
||||
|
||||
# if self.vis_period > 0:
|
||||
# storage = get_event_storage()
|
||||
# if storage.iter % self.vis_period == 0:
|
||||
# self.visualize_training(batched_inputs, proposals)
|
||||
|
||||
losses = {}
|
||||
losses.update(detector_losses)
|
||||
|
@ -203,7 +208,7 @@ class GeneralizedRCNN(nn.Module):
|
|||
assert "proposals" in batched_inputs[0]
|
||||
proposals = [x["proposals"].to(self.device) for x in batched_inputs]
|
||||
|
||||
results, _ = self.roi_heads(images, features, proposals, None)
|
||||
results, _, _ = self.roi_heads(images, features, proposals, None)
|
||||
else:
|
||||
detected_instances = [x.to(self.device) for x in detected_instances]
|
||||
results = self.roi_heads.forward_with_given_boxes(features, detected_instances)
|
||||
|
|
Loading…
Reference in New Issue