mirror of https://github.com/RE-OWOD/RE-OWOD
Add files via upload
parent
3b865248f1
commit
9623f793bd
|
@ -149,6 +149,8 @@ class GeneralizedRCNN(nn.Module):
|
||||||
return self.inference(batched_inputs)
|
return self.inference(batched_inputs)
|
||||||
|
|
||||||
images = self.preprocess_image(batched_inputs)
|
images = self.preprocess_image(batched_inputs)
|
||||||
|
image_id = [x["image_id"] for x in batched_inputs]
|
||||||
|
ori_image = [x['image'] for x in batched_inputs]
|
||||||
if "instances" in batched_inputs[0]:
|
if "instances" in batched_inputs[0]:
|
||||||
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
|
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
|
||||||
else:
|
else:
|
||||||
|
@ -157,17 +159,20 @@ class GeneralizedRCNN(nn.Module):
|
||||||
features = self.backbone(images.tensor)
|
features = self.backbone(images.tensor)
|
||||||
|
|
||||||
if self.proposal_generator:
|
if self.proposal_generator:
|
||||||
proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
|
proposals, _ = self.proposal_generator(images, features, gt_instances, False)
|
||||||
else:
|
else:
|
||||||
assert "proposals" in batched_inputs[0]
|
assert "proposals" in batched_inputs[0]
|
||||||
proposals = [x["proposals"].to(self.device) for x in batched_inputs]
|
proposals = [x["proposals"].to(self.device) for x in batched_inputs]
|
||||||
proposal_losses = {}
|
proposal_losses = {}
|
||||||
|
|
||||||
_, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
|
_, detector_losses, unk_sel_gt = self.roi_heads(images, features, proposals, gt_instances, image_id, ori_image)
|
||||||
if self.vis_period > 0:
|
|
||||||
storage = get_event_storage()
|
_, proposal_losses = self.proposal_generator(images, features, gt_instances, True, unk_sel_gt)
|
||||||
if storage.iter % self.vis_period == 0:
|
|
||||||
self.visualize_training(batched_inputs, proposals)
|
# if self.vis_period > 0:
|
||||||
|
# storage = get_event_storage()
|
||||||
|
# if storage.iter % self.vis_period == 0:
|
||||||
|
# self.visualize_training(batched_inputs, proposals)
|
||||||
|
|
||||||
losses = {}
|
losses = {}
|
||||||
losses.update(detector_losses)
|
losses.update(detector_losses)
|
||||||
|
@ -203,7 +208,7 @@ class GeneralizedRCNN(nn.Module):
|
||||||
assert "proposals" in batched_inputs[0]
|
assert "proposals" in batched_inputs[0]
|
||||||
proposals = [x["proposals"].to(self.device) for x in batched_inputs]
|
proposals = [x["proposals"].to(self.device) for x in batched_inputs]
|
||||||
|
|
||||||
results, _ = self.roi_heads(images, features, proposals, None)
|
results, _, _ = self.roi_heads(images, features, proposals, None)
|
||||||
else:
|
else:
|
||||||
detected_instances = [x.to(self.device) for x in detected_instances]
|
detected_instances = [x.to(self.device) for x in detected_instances]
|
||||||
results = self.roi_heads.forward_with_given_boxes(features, detected_instances)
|
results = self.roi_heads.forward_with_given_boxes(features, detected_instances)
|
||||||
|
|
Loading…
Reference in New Issue