diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..ad36396
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,24 @@
+build:
+ cuda: "11.3"
+ gpu: true
+ python_version: "3.9"
+ system_packages:
+ - "libgl1-mesa-glx"
+ - "libglib2.0-0"
+ python_packages:
+ - "numpy==1.21.1"
+ - "ipython==7.21.0"
+ - "addict==2.4.0"
+ - "future==0.18.2"
+ - "lmdb==1.3.0"
+ - "opencv-python==4.5.5.64"
+ - "Pillow==9.1.0"
+ - "pyyaml==6.0"
+ - "torch==1.11.0"
+ - "torchvision==0.12.0"
+ - "tqdm==4.64.0"
+ - "scipy==1.8.0"
+ - "scikit-image==0.19.2"
+ - "matplotlib==3.5.1"
+
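+# Entry point: the Predictor class in predict.py implements setup() and predict().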
+predict: "predict.py:Predictor"
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..3a37c7b
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,137 @@
+import torch
+import numpy as np
+import cv2
+import tempfile
+import matplotlib.pyplot as plt
+from cog import BasePredictor, Input, Path
+
+from basicsr.models import create_model
+from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite
+from basicsr.utils.options import parse
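+# Note: basicsr here is the copy bundled with this repository (installed via
+# `python setup.py develop --no_cuda_ext`, see the readme), not the PyPI package.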
+
+
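+# Cog calls setup() once when the container starts and predict() for each request.
+# The option files loaded below are assumed to point at the corresponding
+# pretrained NAFNet / NAFSSR weights already present in the image.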
+class Predictor(BasePredictor):
+    def setup(self):
+        # Load the test-time option files shipped with the repo, disable
+        # distributed inference, and build one model per supported task.
+        opt_path_denoise = "options/test/SIDD/NAFNet-width64.yml"
+        opt_denoise = parse(opt_path_denoise, is_train=False)
+        opt_denoise["dist"] = False
+
+        opt_path_deblur = "options/test/GoPro/NAFNet-width64.yml"
+        opt_deblur = parse(opt_path_deblur, is_train=False)
+        opt_deblur["dist"] = False
+
+        opt_path_stereo = "options/test/NAFSSR/NAFSSR-L_4x.yml"
+        opt_stereo = parse(opt_path_stereo, is_train=False)
+        opt_stereo["dist"] = False
+
+        self.models = {
+            "Image Denoising": create_model(opt_denoise),
+            "Image Deblurring": create_model(opt_deblur),
+            "Stereo Image Super-Resolution": create_model(opt_stereo),
+        }
+
+    def predict(
+        self,
+        task_type: str = Input(
+            choices=[
+                "Image Denoising",
+                "Image Deblurring",
+                "Stereo Image Super-Resolution",
+            ],
+            default="Image Deblurring",
+            description="Choose the task type.",
+        ),
+        image: Path = Input(
+            description="Input image. For Stereo Image Super-Resolution, upload the left image here.",
+        ),
+        image_r: Path = Input(
+            default=None,
+            description="Right input image for Stereo Image Super-Resolution. Optional; only used for the"
+            " Stereo Image Super-Resolution task.",
+        ),
+    ) -> Path:
+
+        out_path = Path(tempfile.mkdtemp()) / "output.png"
+
+        model = self.models[task_type]
+        if task_type == "Stereo Image Super-Resolution":
+            assert image_r is not None, (
+                "Please provide both left and right input images for the "
+                "Stereo Image Super-Resolution task."
+            )
+
+            img_l = imread(str(image))
+            inp_l = img2tensor(img_l)
+            img_r = imread(str(image_r))
+            inp_r = img2tensor(img_r)
+            stereo_image_inference(model, inp_l, inp_r, str(out_path))
+
+        else:
+            img_input = imread(str(image))
+            inp = img2tensor(img_input)
+            single_image_inference(model, inp, str(out_path))
+
+        return out_path
+
+
+def imread(img_path):
+    # Read an image from disk and convert it from BGR (OpenCV default) to RGB.
+    img = cv2.imread(img_path)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    return img
+
+
+def img2tensor(img, bgr2rgb=False, float32=True):
+    # Normalize to [0, 1] and convert to a CHW float tensor.
+    img = img.astype(np.float32) / 255.0
+    return _img2tensor(img, bgr2rgb=bgr2rgb, float32=float32)
+
+
+def single_image_inference(model, img, save_path):
+    # Run the restoration model on a single image tensor and write the result to disk.
+    model.feed_data(data={"lq": img.unsqueeze(dim=0)})
+
+    if model.opt["val"].get("grids", False):
+        model.grids()
+
+    model.test()
+
+    if model.opt["val"].get("grids", False):
+        model.grids_inverse()
+
+    visuals = model.get_current_visuals()
+    sr_img = tensor2img([visuals["result"]])
+    imwrite(sr_img, save_path)
+
+
+def stereo_image_inference(model, img_l, img_r, out_path):
+    # Concatenate the left/right views along the channel dimension, as expected by NAFSSR.
+    img = torch.cat([img_l, img_r], dim=0)
+    model.feed_data(data={"lq": img.unsqueeze(dim=0)})
+
+    if model.opt["val"].get("grids", False):
+        model.grids()
+
+    model.test()
+
+    if model.opt["val"].get("grids", False):
+        model.grids_inverse()
+
+    visuals = model.get_current_visuals()
+    img_L = visuals["result"][:, :3]
+    img_R = visuals["result"][:, 3:]
+    img_L, img_R = tensor2img([img_L, img_R], rgb2bgr=False)
+
+    # Save the two super-resolved views stacked vertically in a single figure.
+    h, w = img_L.shape[:2]
+    fig = plt.figure(figsize=(w // 40, h // 40))
+    ax1 = fig.add_subplot(2, 1, 1)
+    plt.title("NAFSSR output (Left)", fontsize=14)
+    ax1.axis("off")
+    ax1.imshow(img_L)
+
+    ax2 = fig.add_subplot(2, 1, 2)
+    plt.title("NAFSSR output (Right)", fontsize=14)
+    ax2.axis("off")
+    ax2.imshow(img_R)
+
+    plt.subplots_adjust(hspace=0.08)
+    plt.savefig(str(out_path), bbox_inches="tight", dpi=600)
diff --git a/readme.md b/readme.md
index 5b462b6..1da1848 100644
--- a/readme.md
+++ b/readme.md
@@ -49,6 +49,8 @@ python setup.py develop --no_cuda_ext
* Image Deblur Colab Demo: [Colab](https://colab.research.google.com/drive/1yR2ClVuMefisH12d_srXMhHnHwwA1YmU?usp=sharing)
* Stereo Image Super-Resolution Colab Demo: [Colab](https://colab.research.google.com/drive/1PkLog2imf7jCOPKq1G32SOISz0eLLJaO?usp=sharing)
+* Try the web demo with all three tasks here: [Replicate](https://replicate.com/megvii-research/nafnet)
+
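+  The same demo can also be run locally with [Cog](https://github.com/replicate/cog) and Docker; a minimal example (`demo.png` is a placeholder for your own image):
+
+  ```
+  cog predict -i task_type="Image Deblurring" -i image=@demo.png
+  ```
+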
* Single Image Inference Demo:
* Image Denoise:
```