Spaces:
Runtime error
Runtime error
update
Browse files- infer_api.py +27 -31
infer_api.py
CHANGED
|
@@ -108,34 +108,32 @@ def set_seed(seed):
|
|
| 108 |
torch.manual_seed(seed)
|
| 109 |
torch.cuda.manual_seed_all(seed)
|
| 110 |
|
| 111 |
-
class BkgRemover:
|
| 112 |
-
def __init__(self, force_cpu: Optional[bool] = True):
|
| 113 |
-
session_infer_path = hf_hub_download(
|
| 114 |
-
repo_id="skytnt/anime-seg", filename="isnetis.onnx",
|
| 115 |
-
)
|
| 116 |
-
providers: list[str] = ["CPUExecutionProvider"]
|
| 117 |
-
if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers():
|
| 118 |
-
providers = ["CUDAExecutionProvider"]
|
| 119 |
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
|
| 141 |
def process_image(image, totensor, width, height):
|
|
@@ -168,7 +166,7 @@ def process_image(image, totensor, width, height):
|
|
| 168 |
|
| 169 |
@spaces.GPU
|
| 170 |
@torch.no_grad()
|
| 171 |
-
def inference(validation_pipeline,
|
| 172 |
text_encoder, pretrained_model_path, validation, val_width, val_height, unet_condition_type,
|
| 173 |
use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
|
| 174 |
set_seed(seed)
|
|
@@ -186,7 +184,7 @@ def inference(validation_pipeline, bkg_remover, input_image, vae, feature_extrac
|
|
| 186 |
B = 1
|
| 187 |
if input_image.mode != "RGBA":
|
| 188 |
# remove background
|
| 189 |
-
input_image =
|
| 190 |
imgs_in = process_image(input_image, totensor, val_width, val_height)
|
| 191 |
imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")
|
| 192 |
|
|
@@ -869,11 +867,9 @@ class InferCanonicalAPI:
|
|
| 869 |
)
|
| 870 |
self.validation_pipeline.set_progress_bar_config(disable=True)
|
| 871 |
|
| 872 |
-
self.bkg_remover = BkgRemover()
|
| 873 |
-
|
| 874 |
def canonicalize(self, image, seed):
|
| 875 |
return inference(
|
| 876 |
-
self.validation_pipeline,
|
| 877 |
self.pretrained_model_path, self.validation, self.width_input, self.height_input, self.unet_condition_type,
|
| 878 |
use_noise=self.use_noise, noise_d=self.noise_d, crop=True, seed=seed, timestep=self.timestep
|
| 879 |
)
|
|
|
|
| 108 |
torch.manual_seed(seed)
|
| 109 |
torch.cuda.manual_seed_all(seed)
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
+
session_infer_path = hf_hub_download(
|
| 113 |
+
repo_id="skytnt/anime-seg", filename="isnetis.onnx",
|
| 114 |
+
)
|
| 115 |
+
providers: list[str] = ["CPUExecutionProvider"]
|
| 116 |
+
if "CUDAExecutionProvider" in rt.get_available_providers():
|
| 117 |
+
providers = ["CUDAExecutionProvider"]
|
| 118 |
|
| 119 |
+
bkg_remover_session_infer = rt.InferenceSession(
|
| 120 |
+
session_infer_path, providers=providers,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
@spaces.GPU
|
| 124 |
+
def remove_background(
|
| 125 |
+
img: np.ndarray,
|
| 126 |
+
alpha_min: float,
|
| 127 |
+
alpha_max: float,
|
| 128 |
+
) -> list:
|
| 129 |
+
img = np.array(img)
|
| 130 |
+
mask = get_mask(bkg_remover_session_infer, img)
|
| 131 |
+
mask[mask < alpha_min] = 0.0
|
| 132 |
+
mask[mask > alpha_max] = 1.0
|
| 133 |
+
img_after = (mask * img).astype(np.uint8)
|
| 134 |
+
mask = (mask * SCALE).astype(np.uint8)
|
| 135 |
+
img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8)
|
| 136 |
+
return Image.fromarray(img_after)
|
| 137 |
|
| 138 |
|
| 139 |
def process_image(image, totensor, width, height):
|
|
|
|
| 166 |
|
| 167 |
@spaces.GPU
|
| 168 |
@torch.no_grad()
|
| 169 |
+
def inference(validation_pipeline, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
|
| 170 |
text_encoder, pretrained_model_path, validation, val_width, val_height, unet_condition_type,
|
| 171 |
use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
|
| 172 |
set_seed(seed)
|
|
|
|
| 184 |
B = 1
|
| 185 |
if input_image.mode != "RGBA":
|
| 186 |
# remove background
|
| 187 |
+
input_image = remove_background(input_image, 0.1, 0.9)
|
| 188 |
imgs_in = process_image(input_image, totensor, val_width, val_height)
|
| 189 |
imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")
|
| 190 |
|
|
|
|
| 867 |
)
|
| 868 |
self.validation_pipeline.set_progress_bar_config(disable=True)
|
| 869 |
|
|
|
|
|
|
|
| 870 |
def canonicalize(self, image, seed):
|
| 871 |
return inference(
|
| 872 |
+
self.validation_pipeline, image, self.vae, self.feature_extractor, self.image_encoder, self.unet, self.ref_unet, self.tokenizer, self.text_encoder,
|
| 873 |
self.pretrained_model_path, self.validation, self.width_input, self.height_input, self.unet_condition_type,
|
| 874 |
use_noise=self.use_noise, noise_d=self.noise_d, crop=True, seed=seed, timestep=self.timestep
|
| 875 |
)
|