From f013110f1213a5fe1acf984ef34d990fb9aecadf Mon Sep 17 00:00:00 2001
From: Chenxi
Date: Wed, 20 Apr 2022 11:03:59 +0100
Subject: [PATCH] Add Replicate demo

---
 cog.yaml   |  24 ++++++++++
 predict.py | 137 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 readme.md  |   2 +
 3 files changed, 163 insertions(+)
 create mode 100644 cog.yaml
 create mode 100644 predict.py

diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..ad36396
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,24 @@
+build:
+  cuda: "11.3"
+  gpu: true
+  python_version: "3.9"
+  system_packages:
+    - "libgl1-mesa-glx"
+    - "libglib2.0-0"
+  python_packages:
+    - "numpy==1.21.1"
+    - "ipython==7.21.0"
+    - "addict==2.4.0"
+    - "future==0.18.2"
+    - "lmdb==1.3.0"
+    - "opencv-python==4.5.5.64"
+    - "Pillow==9.1.0"
+    - "pyyaml==6.0"
+    - "torch==1.11.0"
+    - "torchvision==0.12.0"
+    - "tqdm==4.64.0"
+    - "scipy==1.8.0"
+    - "scikit-image==0.19.2"
+    - "matplotlib==3.5.1"
+
+predict: "predict.py:Predictor"
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..3a37c7b
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,137 @@
+import torch
+import numpy as np
+import cv2
+import tempfile
+import matplotlib.pyplot as plt
+from cog import BasePredictor, Path, Input
+
+from basicsr.models import create_model
+from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite
+from basicsr.utils.options import parse
+
+
+class Predictor(BasePredictor):
+    def setup(self):
+        opt_path_denoise = "options/test/SIDD/NAFNet-width64.yml"
+        opt_denoise = parse(opt_path_denoise, is_train=False)
+        opt_denoise["dist"] = False
+
+        opt_path_deblur = "options/test/GoPro/NAFNet-width64.yml"
+        opt_deblur = parse(opt_path_deblur, is_train=False)
+        opt_deblur["dist"] = False
+
+        opt_path_stereo = "options/test/NAFSSR/NAFSSR-L_4x.yml"
+        opt_stereo = parse(opt_path_stereo, is_train=False)
+        opt_stereo["dist"] = False
+
+        self.models = {
+            "Image Denoising": create_model(opt_denoise),
+            "Image Deblurring": create_model(opt_deblur),
+            "Stereo Image Super-Resolution": create_model(opt_stereo),
+        }
+
+    def predict(
+        self,
+        task_type: str = Input(
+            choices=[
+                "Image Denoising",
+                "Image Deblurring",
+                "Stereo Image Super-Resolution",
+            ],
+            default="Image Deblurring",
+            description="Choose the task type.",
+        ),
+        image: Path = Input(
+            description="Input image. For Stereo Image Super-Resolution, upload the left image here.",
+        ),
+        image_r: Path = Input(
+            default=None,
+            description="Right input image for Stereo Image Super-Resolution. Optional; only used for the"
+            " Stereo Image Super-Resolution task.",
+        ),
+    ) -> Path:
+
+        out_path = Path(tempfile.mkdtemp()) / "output.png"
+
+        model = self.models[task_type]
+        if task_type == "Stereo Image Super-Resolution":
+            assert image_r is not None, (
+                "Please provide both left and right input images for the "
+                "Stereo Image Super-Resolution task."
+            )
+
+            img_l = imread(str(image))
+            inp_l = img2tensor(img_l)
+            img_r = imread(str(image_r))
+            inp_r = img2tensor(img_r)
+            stereo_image_inference(model, inp_l, inp_r, str(out_path))
+
+        else:
+
+            img_input = imread(str(image))
+            inp = img2tensor(img_input)
+            out_path = Path(tempfile.mkdtemp()) / "output.png"
+            single_image_inference(model, inp, str(out_path))
+
+        return out_path
+
+
+def imread(img_path):
+    img = cv2.imread(img_path)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    return img
+
+
+def img2tensor(img, bgr2rgb=False, float32=True):
+    img = img.astype(np.float32) / 255.0
+    return _img2tensor(img, bgr2rgb=bgr2rgb, float32=float32)
+
+
+def single_image_inference(model, img, save_path):
+    model.feed_data(data={"lq": img.unsqueeze(dim=0)})
+
+    if model.opt["val"].get("grids", False):
+        model.grids()
+
+    model.test()
+
+    if model.opt["val"].get("grids", False):
+        model.grids_inverse()
+
+    visuals = model.get_current_visuals()
+    sr_img = tensor2img([visuals["result"]])
+    imwrite(sr_img, save_path)
+
+
+def stereo_image_inference(model, img_l, img_r, out_path):
+    img = torch.cat([img_l, img_r], dim=0)
+    model.feed_data(data={"lq": img.unsqueeze(dim=0)})
+
+    if model.opt["val"].get("grids", False):
+        model.grids()
+
+    model.test()
+
+    if model.opt["val"].get("grids", False):
+        model.grids_inverse()
+
+    visuals = model.get_current_visuals()
+    img_L = visuals["result"][:, :3]
+    img_R = visuals["result"][:, 3:]
+    img_L, img_R = tensor2img([img_L, img_R], rgb2bgr=False)
+
+    # Save the left/right outputs stacked into a single figure.
+    h, w = img_L.shape[:2]
+    fig = plt.figure(figsize=(w // 40, h // 40))
+    ax1 = fig.add_subplot(2, 1, 1)
+    plt.title("NAFSSR output (Left)", fontsize=14)
+    ax1.axis("off")
+    ax1.imshow(img_L)
+
+    ax2 = fig.add_subplot(2, 1, 2)
+    plt.title("NAFSSR output (Right)", fontsize=14)
+    ax2.axis("off")
+    ax2.imshow(img_R)
+
+    plt.subplots_adjust(hspace=0.08)
+    plt.savefig(str(out_path), bbox_inches="tight", dpi=600)
diff --git a/readme.md b/readme.md
index 5b462b6..1da1848 100644
--- a/readme.md
+++ b/readme.md
@@ -49,6 +49,8 @@ python setup.py develop --no_cuda_ext
 * Image Deblur Colab Demo: [google colab logo](https://colab.research.google.com/drive/1yR2ClVuMefisH12d_srXMhHnHwwA1YmU?usp=sharing)
 * Stereo Image Super-Resolution Colab Demo: [google colab logo](https://colab.research.google.com/drive/1PkLog2imf7jCOPKq1G32SOISz0eLLJaO?usp=sharing)
 
+Try the web demo for all three tasks here: [![Replicate](https://replicate.com/megvii-research/nafnet/badge)](https://replicate.com/megvii-research/nafnet)
+
 * Single Image Inference Demo:
   * Image Denoise:
 ```
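
For reviewers who want to exercise the new predictor locally before trying the hosted demo, a minimal smoke test could look like the sketch below. It is only an illustration, not part of the patch: it assumes the repo root as the working directory, a CUDA-capable machine (cog.yaml sets `gpu: true`), the pretrained weights referenced by the three option YAMLs downloaded as the readme describes, and a hypothetical test image at `demo/blurry.png`.

```python
# Sketch of a local smoke test for predict.py (illustrative; not part of this patch).
# Assumes: repo root as working directory, a CUDA GPU, pretrained NAFNet weights
# downloaded per the readme, and a hypothetical input image at demo/blurry.png.
from pathlib import Path

from predict import Predictor

predictor = Predictor()
predictor.setup()  # builds the denoising, deblurring, and NAFSSR models from the option YAMLs

# Single-image task: only `image` is required; `image_r` is used by the stereo task.
out = predictor.predict(
    task_type="Image Deblurring",
    image=Path("demo/blurry.png"),
    image_r=None,
)
print(f"Deblurred output written to {out}")
```

The same path can presumably be exercised through the container with the Cog CLI, e.g. `cog predict -i task_type="Image Deblurring" -i image=@demo/blurry.png`.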