Add files via upload

main
RE-OWOD 2022-01-11 11:05:19 +08:00 committed by GitHub
parent 3b865248f1
commit 9623f793bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 7 deletions

View File

@ -149,6 +149,8 @@ class GeneralizedRCNN(nn.Module):
return self.inference(batched_inputs)
images = self.preprocess_image(batched_inputs)
image_id = [x["image_id"] for x in batched_inputs]
ori_image = [x['image'] for x in batched_inputs]
if "instances" in batched_inputs[0]:
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
else:
@ -157,17 +159,20 @@ class GeneralizedRCNN(nn.Module):
features = self.backbone(images.tensor)
if self.proposal_generator:
proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
proposals, _ = self.proposal_generator(images, features, gt_instances, False)
else:
assert "proposals" in batched_inputs[0]
proposals = [x["proposals"].to(self.device) for x in batched_inputs]
proposal_losses = {}
_, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
if self.vis_period > 0:
storage = get_event_storage()
if storage.iter % self.vis_period == 0:
self.visualize_training(batched_inputs, proposals)
_, detector_losses, unk_sel_gt = self.roi_heads(images, features, proposals, gt_instances, image_id, ori_image)
_, proposal_losses = self.proposal_generator(images, features, gt_instances, True, unk_sel_gt)
# if self.vis_period > 0:
# storage = get_event_storage()
# if storage.iter % self.vis_period == 0:
# self.visualize_training(batched_inputs, proposals)
losses = {}
losses.update(detector_losses)
@ -203,7 +208,7 @@ class GeneralizedRCNN(nn.Module):
assert "proposals" in batched_inputs[0]
proposals = [x["proposals"].to(self.device) for x in batched_inputs]
results, _ = self.roi_heads(images, features, proposals, None)
results, _, _ = self.roi_heads(images, features, proposals, None)
else:
detected_instances = [x.to(self.device) for x in detected_instances]
results = self.roi_heads.forward_with_given_boxes(features, detected_instances)