Spaces:
Runtime error
Runtime error
import asyncio
import os
import warnings
from enum import Enum

import cv2
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer
from realesrgan import RealESRGANer
from tqdm import tqdm
class EnhancementMethod(str, Enum):
    """Names of the supported enhancement back-ends.

    Members are also ``str`` instances (``str`` mixin), so each compares equal
    to its raw string value, e.g. ``EnhancementMethod.gfpgan == "gfpgan"``.
    """

    gfpgan = 'gfpgan'
    RestoreFormer = 'RestoreFormer'
    codeformer = 'codeformer'
    realesrgan = 'realesrgan'
class Enhancer:
    """Restore faces and/or upscale images.

    For the face methods (gfpgan, RestoreFormer, codeformer) a ``GFPGANer`` is
    built, optionally backed by a RealESRGAN background upsampler (CUDA only).
    For ``EnhancementMethod.realesrgan`` only a RealESRGAN upscaler is built,
    and CUDA is required.
    """

    # The RealESRGAN checkpoint is the x2 model, so the network must be built
    # with scale=2 to match its weights; other output factors are reached via
    # RealESRGANer's ``outscale`` argument at enhance time.
    _REALESRGAN_NETSCALE = 2
    _REALESRGAN_URL = 'https://huggingface.co/dtarnow/UPscaler/resolve/main/RealESRGAN_x2plus.pth'

    def __init__(self, method: "EnhancementMethod", background_enhancement=True, upscale=2):
        """Build the models required by ``method``.

        Args:
            method: which enhancer to use.
            background_enhancement: for face methods, also upscale the
                non-face background with RealESRGAN (skipped with a warning
                when CUDA is unavailable). Ignored for the realesrgan method.
            upscale: output scale factor.

        Raises:
            ValueError: if ``method`` is realesrgan and CUDA is unavailable, or
                if ``method`` is not a known face-enhancement method.
        """
        self.method = method
        self.background_enhancement = background_enhancement
        self.upscale = upscale
        self.bg_upsampler = None
        self.realesrgan_enhancer = None
        self.face_enhancer = None
        if self.method == EnhancementMethod.realesrgan:
            self.setup_realesrgan_enhancer()
        else:
            # GFPGANer receives the background upsampler as a constructor
            # argument, so it has to exist before the face enhancer is built;
            # otherwise GFPGANer is handed None and never uses it.
            if self.background_enhancement:
                self.setup_background_enhancer()
            self.setup_face_enhancer()

    def _build_realesrgan(self):
        """Create a RealESRGANer around the x2 RRDBNet checkpoint (fp16, tiled)."""
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23,
                        num_grow_ch=32, scale=self._REALESRGAN_NETSCALE)
        return RealESRGANer(
            scale=self._REALESRGAN_NETSCALE,
            model_path=self._REALESRGAN_URL,
            model=model,
            tile=400,  # process the image in 400px tiles (presumably to bound GPU memory)
            tile_pad=10,
            pre_pad=0,
            half=True)  # fp16; callers only reach here after a CUDA check

    def setup_background_enhancer(self):
        """Set ``self.bg_upsampler`` to a RealESRGAN upsampler.

        When CUDA is unavailable, emits a warning and leaves
        ``self.bg_upsampler`` as None.
        """
        if not torch.cuda.is_available():
            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it.')
            return
        self.bg_upsampler = self._build_realesrgan()

    def setup_realesrgan_enhancer(self):
        """Set ``self.realesrgan_enhancer`` for whole-image upscaling.

        Raises:
            ValueError: if CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            raise ValueError('CUDA is not available for RealESRGAN')
        self.realesrgan_enhancer = self._build_realesrgan()

    @staticmethod
    def _resolve_face_model_path(config):
        """Return a local checkpoint path if one exists, otherwise the download URL.

        Looks for ``<model_name>.pth`` in ``gfpgan/weights`` then ``checkpoints``.
        """
        filename = config['model_name'] + '.pth'
        for folder in ('gfpgan/weights', 'checkpoints'):
            candidate = os.path.join(folder, filename)
            if os.path.isfile(candidate):
                return candidate
        return config['url']

    def setup_face_enhancer(self):
        """Set ``self.face_enhancer`` to a GFPGANer for ``self.method``.

        Uses ``self.bg_upsampler`` (which may be None) as the background
        upsampler.

        Raises:
            ValueError: if ``self.method`` has no face-model configuration.
        """
        model_configs = {
            EnhancementMethod.gfpgan: {
                'arch': 'clean',
                'channel_multiplier': 2,
                'model_name': 'GFPGANv1.4',
                'url': 'https://huggingface.co/gmk123/GFPGAN/resolve/main/GFPGANv1.4.pth',
            },
            EnhancementMethod.RestoreFormer: {
                'arch': 'RestoreFormer',
                'channel_multiplier': 2,
                'model_name': 'RestoreFormer',
                'url': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
            },
            # NOTE(review): confirm the installed GFPGANer accepts
            # arch='CodeFormer'; upstream gfpgan documents clean / bilinear /
            # original / RestoreFormer.
            EnhancementMethod.codeformer: {
                'arch': 'CodeFormer',
                'channel_multiplier': 2,
                'model_name': 'CodeFormer',
                'url': 'https://huggingface.co/sinadi/aar/resolve/main/codeformer.pth',
            },
        }
        config = model_configs.get(self.method)
        if not config:
            raise ValueError(f'Wrong model version {self.method}')
        self.face_enhancer = GFPGANer(
            model_path=self._resolve_face_model_path(config),
            upscale=self.upscale,
            arch=config['arch'],
            channel_multiplier=config['channel_multiplier'],
            bg_upsampler=self.bg_upsampler)

    def check_image_resolution(self, image):
        """Return ``(width, height)`` of an array shaped ``(height, width[, channels])``."""
        height, width = image.shape[:2]
        return width, height

    async def enhance(self, image):
        """Enhance an RGB image, running the blocking model call in a worker thread.

        Args:
            image: RGB image array; converted to BGR for the underlying models.

        Returns:
            Tuple ``(enhanced_rgb_image, (width, height), (enhanced_width, enhanced_height))``.
        """
        img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        width, height = self.check_image_resolution(img)
        if self.method == EnhancementMethod.realesrgan:
            enhanced_img, _ = await asyncio.to_thread(
                self.realesrgan_enhancer.enhance, img, outscale=self.upscale)
        else:
            _, _, enhanced_img = await asyncio.to_thread(
                self.face_enhancer.enhance,
                img,
                has_aligned=False,
                only_center_face=False,
                paste_back=True)
        enhanced_img = cv2.cvtColor(enhanced_img, cv2.COLOR_BGR2RGB)
        enhanced_width, enhanced_height = self.check_image_resolution(enhanced_img)
        return enhanced_img, (width, height), (enhanced_width, enhanced_height)