Spaces:
Build error
Build error
| import os | |
| import torch | |
| from gfpgan import GFPGANer | |
| from tqdm import tqdm | |
| from src.utils.videoio import load_video_to_cv2 | |
# Face-restoration presets: method -> (arch, channel_multiplier, model_name, url).
# NOTE(review): gfpgan's GFPGANer (as of gfpgan 1.3.8) only builds the 'clean',
# 'bilinear', 'original' and 'RestoreFormer' architectures; the 'CodeFormer'
# entry is kept for interface compatibility but likely fails inside GFPGANer —
# confirm against the installed gfpgan version.
_FACE_RESTORERS = {
    'gfpgan': (
        'clean', 2, 'GFPGANv1.4',
        'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'),
    'RestoreFormer': (
        'RestoreFormer', 2, 'RestoreFormer',
        'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'),
    'codeformer': (
        'CodeFormer', 2, 'CodeFormer',
        'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'),
}

# Official 2x Real-ESRGAN weights, used when no local copy is found.
_REALESRGAN_X2_URL = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'


def _resolve_model_path(model_name, url):
    """Return a local ``<model_name>.pth`` checkpoint if one exists, else ``url``.

    Looks under ``gfpgan/weights`` and then ``checkpoints`` (relative to the
    current working directory). When neither file exists the download URL is
    returned and the caller's loader is expected to fetch it.
    """
    for folder in ('gfpgan/weights', 'checkpoints'):
        candidate = os.path.join(folder, model_name + '.pth')
        if os.path.isfile(candidate):
            return candidate
    return url


def _build_bg_upsampler(bg_upsampler):
    """Build the optional 2x Real-ESRGAN background upsampler, or return None.

    Only the string ``'realesrgan'`` enables it, and only when CUDA is
    available; on CPU a warning is emitted and None is returned.
    """
    if bg_upsampler != 'realesrgan':
        return None
    if not torch.cuda.is_available():  # CPU
        import warnings
        warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
                      'If you really want to use it, please modify the corresponding codes.')
        return None

    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
    return RealESRGANer(
        scale=2,
        # Resolved like the face checkpoint (local weights first, then the
        # official release URL) instead of a developer-only absolute path.
        # RealESRGANer downloads https model paths itself (Real-ESRGAN package
        # behaviour; confirm for the installed version).
        model_path=_resolve_model_path('RealESRGAN_x2plus', _REALESRGAN_X2_URL),
        model=model,
        tile=400,
        tile_pad=10,
        pre_pad=0,
        half=True)  # fp16 is fine here because this branch requires CUDA


def enhancer(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Restore faces in a video or a sequence of frames with a GFPGANer.

    Args:
        images: Either a path to a video file, which is decoded with
            ``load_video_to_cv2``, or a sequence of frames accepted by
            ``GFPGANer.enhance`` (presumably OpenCV BGR ndarrays — confirm).
        method: Face restorer to use: ``'gfpgan'``, ``'RestoreFormer'`` or
            ``'codeformer'``.
        bg_upsampler: ``'realesrgan'`` to also upscale the background with a
            2x Real-ESRGAN model (CUDA only; skipped with a warning on CPU).
            Any other value disables background upsampling.

    Returns:
        list: For each input frame, in order, the pasted-back restored image
        returned by ``GFPGANer.enhance``.

    Raises:
        ValueError: If ``method`` is not a supported restorer.
    """
    print('face enhancer....')

    # Validate the method before doing any (possibly slow) video decoding.
    try:
        arch, channel_multiplier, model_name, url = _FACE_RESTORERS[method]
    except KeyError:
        raise ValueError(f'Wrong model version {method}.') from None

    # Only path-like inputs can name a video file. os.path.isfile() raises
    # TypeError on a list/array of frames, so guard before calling it.
    if isinstance(images, (str, bytes, os.PathLike)) and os.path.isfile(images):
        images = load_video_to_cv2(images)

    background = _build_bg_upsampler(bg_upsampler)
    restorer = GFPGANer(
        model_path=_resolve_model_path(model_name, url),
        upscale=2,
        arch=arch,
        channel_multiplier=channel_multiplier,
        bg_upsampler=background)

    restored_img = []
    for frame in tqdm(images, 'Face Enhancer:'):
        # enhance() returns (cropped_faces, restored_faces, restored_img);
        # only the full pasted-back frame is kept.
        _, _, r_img = restorer.enhance(
            frame,
            has_aligned=False,
            only_center_face=False,
            paste_back=True)
        restored_img.append(r_img)
    return restored_img