replicate demo
parent
fea15dc39b
commit
f013110f12
|
@ -0,0 +1,24 @@
|
|||
build:
|
||||
cuda: "11.3"
|
||||
gpu: true
|
||||
python_version: "3.9"
|
||||
system_packages:
|
||||
- "libgl1-mesa-glx"
|
||||
- "libglib2.0-0"
|
||||
python_packages:
|
||||
- "numpy==1.21.1"
|
||||
- "ipython==7.21.0"
|
||||
- "addict==2.4.0"
|
||||
- "future==0.18.2"
|
||||
- "lmdb==1.3.0"
|
||||
- "opencv-python==4.5.5.64"
|
||||
- "Pillow==9.1.0"
|
||||
- "pyyaml==6.0"
|
||||
- "torch==1.11.0"
|
||||
- "torchvision==0.12.0"
|
||||
- "tqdm==4.64.0"
|
||||
- "scipy==1.8.0"
|
||||
- "scikit-image==0.19.2"
|
||||
- "matplotlib==3.5.1"
|
||||
|
||||
predict: "predict.py:Predictor"
|
|
@ -0,0 +1,137 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import cv2
|
||||
import tempfile
|
||||
import matplotlib.pyplot as plt
|
||||
from cog import BasePredictor, Path, Input, BaseModel
|
||||
|
||||
from basicsr.models import create_model
|
||||
from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite
|
||||
from basicsr.utils.options import parse
|
||||
|
||||
|
||||
class Predictor(BasePredictor):
|
||||
def setup(self):
|
||||
opt_path_denoise = "options/test/SIDD/NAFNet-width64.yml"
|
||||
opt_denoise = parse(opt_path_denoise, is_train=False)
|
||||
opt_denoise["dist"] = False
|
||||
|
||||
opt_path_deblur = "options/test/GoPro/NAFNet-width64.yml"
|
||||
opt_deblur = parse(opt_path_deblur, is_train=False)
|
||||
opt_deblur["dist"] = False
|
||||
|
||||
opt_path_stereo = "options/test/NAFSSR/NAFSSR-L_4x.yml"
|
||||
opt_stereo = parse(opt_path_stereo, is_train=False)
|
||||
opt_stereo["dist"] = False
|
||||
|
||||
self.models = {
|
||||
"Image Denoising": create_model(opt_denoise),
|
||||
"Image Debluring": create_model(opt_deblur),
|
||||
"Stereo Image Super-Resolution": create_model(opt_stereo),
|
||||
}
|
||||
|
||||
def predict(
|
||||
self,
|
||||
task_type: str = Input(
|
||||
choices=[
|
||||
"Image Denoising",
|
||||
"Image Debluring",
|
||||
"Stereo Image Super-Resolution",
|
||||
],
|
||||
default="Image Debluring",
|
||||
description="Choose task type.",
|
||||
),
|
||||
image: Path = Input(
|
||||
description="Input image. Stereo Image Super-Resolution, upload the left image here.",
|
||||
),
|
||||
image_r: Path = Input(
|
||||
default=None,
|
||||
description="Right Input image for Stereo Image Super-Resolution. Optional, only valid for Stereo"
|
||||
" Image Super-Resolution task.",
|
||||
),
|
||||
) -> Path:
|
||||
|
||||
out_path = Path(tempfile.mkdtemp()) / "output.png"
|
||||
|
||||
model = self.models[task_type]
|
||||
if task_type == "Stereo Image Super-Resolution":
|
||||
assert image_r is not None, (
|
||||
"Please provide both left and right input image for "
|
||||
"Stereo Image Super-Resolution task."
|
||||
)
|
||||
|
||||
img_l = imread(str(image))
|
||||
inp_l = img2tensor(img_l)
|
||||
img_r = imread(str(image_r))
|
||||
inp_r = img2tensor(img_r)
|
||||
stereo_image_inference(model, inp_l, inp_r, str(out_path))
|
||||
|
||||
else:
|
||||
|
||||
img_input = imread(str(image))
|
||||
inp = img2tensor(img_input)
|
||||
out_path = Path(tempfile.mkdtemp()) / "output.png"
|
||||
single_image_inference(model, inp, str(out_path))
|
||||
|
||||
return out_path
|
||||
|
||||
|
||||
def imread(img_path):
|
||||
img = cv2.imread(img_path)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
def img2tensor(img, bgr2rgb=False, float32=True):
|
||||
img = img.astype(np.float32) / 255.0
|
||||
return _img2tensor(img, bgr2rgb=bgr2rgb, float32=float32)
|
||||
|
||||
|
||||
def single_image_inference(model, img, save_path):
|
||||
model.feed_data(data={"lq": img.unsqueeze(dim=0)})
|
||||
|
||||
if model.opt["val"].get("grids", False):
|
||||
model.grids()
|
||||
|
||||
model.test()
|
||||
|
||||
if model.opt["val"].get("grids", False):
|
||||
model.grids_inverse()
|
||||
|
||||
visuals = model.get_current_visuals()
|
||||
sr_img = tensor2img([visuals["result"]])
|
||||
imwrite(sr_img, save_path)
|
||||
|
||||
|
||||
def stereo_image_inference(model, img_l, img_r, out_path):
|
||||
img = torch.cat([img_l, img_r], dim=0)
|
||||
model.feed_data(data={"lq": img.unsqueeze(dim=0)})
|
||||
|
||||
if model.opt["val"].get("grids", False):
|
||||
model.grids()
|
||||
|
||||
model.test()
|
||||
|
||||
if model.opt["val"].get("grids", False):
|
||||
model.grids_inverse()
|
||||
|
||||
visuals = model.get_current_visuals()
|
||||
img_L = visuals["result"][:, :3]
|
||||
img_R = visuals["result"][:, 3:]
|
||||
img_L, img_R = tensor2img([img_L, img_R], rgb2bgr=False)
|
||||
|
||||
# save_stereo_image
|
||||
h, w = img_L.shape[:2]
|
||||
fig = plt.figure(figsize=(w // 40, h // 40))
|
||||
ax1 = fig.add_subplot(2, 1, 1)
|
||||
plt.title("NAFSSR output (Left)", fontsize=14)
|
||||
ax1.axis("off")
|
||||
ax1.imshow(img_L)
|
||||
|
||||
ax2 = fig.add_subplot(2, 1, 2)
|
||||
plt.title("NAFSSR output (Right)", fontsize=14)
|
||||
ax2.axis("off")
|
||||
ax2.imshow(img_R)
|
||||
|
||||
plt.subplots_adjust(hspace=0.08)
|
||||
plt.savefig(str(out_path), bbox_inches="tight", dpi=600)
|
|
@ -49,6 +49,8 @@ python setup.py develop --no_cuda_ext
|
|||
* Image Deblur Colab Demo: [<a href="https://colab.research.google.com/drive/1yR2ClVuMefisH12d_srXMhHnHwwA1YmU?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>](https://colab.research.google.com/drive/1yR2ClVuMefisH12d_srXMhHnHwwA1YmU?usp=sharing)
|
||||
* Stereo Image Super-Resolution Colab Demo: [<a href="https://colab.research.google.com/drive/1PkLog2imf7jCOPKq1G32SOISz0eLLJaO?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>](https://colab.research.google.com/drive/1PkLog2imf7jCOPKq1G32SOISz0eLLJaO?usp=sharing)
|
||||
|
||||
Try the web demo with all three tasks here: [](https://replicate.com/megvii-research/nafnet)
|
||||
|
||||
* Single Image Inference Demo:
|
||||
* Image Denoise:
|
||||
```
|
||||
|
|
Loading…
Reference in New Issue