Add EXIF rotation to YOLOv5 Hub inference (#3852)
* rotating an image according to its exif tag * Update common.py * Update datasets.py * Update datasets.py faster * delete extraneous gpg file * Update common.py Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/3875/head
parent
4717a3b038
commit
831773f5a2
|
@ -1,9 +1,9 @@
|
|||
# YOLOv5 common modules
|
||||
|
||||
import math
|
||||
from copy import copy
|
||||
from pathlib import Path
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import requests
|
||||
|
@ -12,7 +12,7 @@ import torch.nn as nn
|
|||
from PIL import Image
|
||||
from torch.cuda import amp
|
||||
|
||||
from utils.datasets import letterbox
|
||||
from utils.datasets import exif_transpose, letterbox
|
||||
from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box
|
||||
from utils.plots import colors, plot_one_box
|
||||
from utils.torch_utils import time_synchronized
|
||||
|
@ -252,9 +252,10 @@ class AutoShape(nn.Module):
|
|||
for i, im in enumerate(imgs):
|
||||
f = f'image{i}' # filename
|
||||
if isinstance(im, str): # filename or uri
|
||||
im, f = np.asarray(Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im)), im
|
||||
im, f = Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im), im
|
||||
im = np.asarray(exif_transpose(im))
|
||||
elif isinstance(im, Image.Image): # PIL Image
|
||||
im, f = np.asarray(im), getattr(im, 'filename', f) or f
|
||||
im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename') or f
|
||||
files.append(Path(f).with_suffix('.jpg').name)
|
||||
if im.shape[0] < 5: # image in CHW
|
||||
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
|
||||
|
|
|
@ -64,6 +64,32 @@ def exif_size(img):
|
|||
return s
|
||||
|
||||
|
||||
def exif_transpose(image):
|
||||
"""
|
||||
Transpose a PIL image accordingly if it has an EXIF Orientation tag.
|
||||
From https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py
|
||||
|
||||
:param image: The image to transpose.
|
||||
:return: An image.
|
||||
"""
|
||||
exif = image.getexif()
|
||||
orientation = exif.get(0x0112, 1) # default 1
|
||||
if orientation > 1:
|
||||
method = {2: Image.FLIP_LEFT_RIGHT,
|
||||
3: Image.ROTATE_180,
|
||||
4: Image.FLIP_TOP_BOTTOM,
|
||||
5: Image.TRANSPOSE,
|
||||
6: Image.ROTATE_270,
|
||||
7: Image.TRANSVERSE,
|
||||
8: Image.ROTATE_90,
|
||||
}.get(orientation)
|
||||
if method is not None:
|
||||
image = image.transpose(method)
|
||||
del exif[0x0112]
|
||||
image.info["exif"] = exif.tobytes()
|
||||
return image
|
||||
|
||||
|
||||
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
|
||||
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
|
||||
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
|
||||
|
|
Loading…
Reference in New Issue