From 37e13f88467d29bfb5461ae041fa88e211fbc7ca Mon Sep 17 00:00:00 2001
From: Glenn Jocher <glenn.jocher@ultralytics.com>
Date: Sat, 27 Jun 2020 13:50:15 -0700
Subject: [PATCH] update mosaic border

---
 train.py          | 4 ++++
 utils/datasets.py | 7 +++----
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/train.py b/train.py
index 4238713fb..a572a3762 100644
--- a/train.py
+++ b/train.py
@@ -207,6 +207,10 @@ def train(hyp):
             image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
             dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n)  # rand weighted idx
 
+        # Update mosaic border
+        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
+        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders
+
         mloss = torch.zeros(4, device=device)  # mean losses
         print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
         pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
diff --git a/utils/datasets.py b/utils/datasets.py
index 69bed4710..1d423e4ca 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -307,7 +307,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         self.image_weights = image_weights
         self.rect = False if image_weights else rect
         self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
-        self.mosaic_border = None
+        self.mosaic_border = [-img_size // 2, -img_size // 2]
         self.stride = stride
 
 
@@ -588,8 +588,7 @@ def load_mosaic(self, index):
 
     labels4 = []
     s = self.img_size
-    border = [-s // 2, -s // 2]  # self.mosaic_border
-    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in border]  # mosaic center x, y
+    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
     indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)]  # 3 additional image indices
     for i, index in enumerate(indices):
         # Load image
@@ -637,7 +636,7 @@ def load_mosaic(self, index):
                                   translate=self.hyp['translate'],
                                   scale=self.hyp['scale'],
                                   shear=self.hyp['shear'],
-                                  border=border)  # border to remove
+                                  border=self.mosaic_border)  # border to remove
 
     return img4, labels4