import gradio as gr
import torch
import os
import json
import random
import sys
import logging
import warnings
import re

import spaces
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from transformers import AutoModel, AutoTokenizer
from dataclasses import dataclass

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from diffusers import ZImagePipeline
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
from pe import prompt_template

# ==================== Environment Variables ==================================
MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
# =============================================================================

os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

RESOLUTION_SET = [
    "1024x1024 ( 1:1 )",
    "1152x896 ( 9:7 )",
    "896x1152 ( 7:9 )",
    "1152x864 ( 4:3 )",
    "864x1152 ( 3:4 )",
    "1248x832 ( 3:2 )",
    "832x1248 ( 2:3 )",
    "1280x720 ( 16:9 )",
    "720x1280 ( 9:16 )",
    "1344x576 ( 21:9 )",
    "576x1344 ( 9:21 )",
]

RES_CHOICES = {
    "1024": [
        "1024x1024 ( 1:1 )",
        "1152x896 ( 9:7 )",
        "896x1152 ( 7:9 )",
        "1152x864 ( 4:3 )",
        "864x1152 ( 3:4 )",
        "1248x832 ( 3:2 )",
        "832x1248 ( 2:3 )",
        "1280x720 ( 16:9 )",
        "720x1280 ( 9:16 )",
        "1344x576 ( 21:9 )",
        "576x1344 ( 9:21 )",
    ],
    "1280": [
        "1280x1280 ( 1:1 )",
        "1440x1120 ( 9:7 )",
        "1120x1440 ( 7:9 )",
        "1472x1104 ( 4:3 )",
        "1104x1472 ( 3:4 )",
        "1536x1024 ( 3:2 )",
        "1024x1536 ( 2:3 )",
        "1600x900 ( 16:9 )",
        "900x1600 ( 9:16 )",
        "1680x720 ( 21:9 )",
        "720x1680 ( 9:21 )",
    ],
}

EXAMPLE_PROMPTS = [
    ['A vertical digital illustration depicting a serene and majestic Chinese landscape, rendered in a style reminiscent of traditional Shanshui painting but with a modern, clean aesthetic. The scene is dominated by towering, steep cliffs in various shades of blue and teal, which frame a central valley. In the distance, layers of mountains fade into a light blue and white mist, creating a strong sense of atmospheric perspective and depth. A calm, turquoise river flows through the center of the composition, with a small, traditional Chinese boat, possibly a sampan, navigating its waters. The boat has a bright yellow canopy and a red hull, and it leaves a gentle wake behind it. It carries several indistinct figures of people. Sparse vegetation, including green trees and some bare-branched trees, clings to the rocky ledges and peaks. The overall lighting is soft and diffused, casting a tranquil glow over the entire scene. Centered in the image is overlaid text. At the top of the text block is a small, red, circular seal-like logo containing stylized characters. Below it, in a smaller, black, sans-serif font, are the words \'Zao-Xiang * East Beauty & West Fashion * Z-Image\'. Directly beneath this, in a larger, elegant black serif font, is the word \'SHOW & SHARE CREATIVITY WITH THE WORLD\'. Among them, there are "SHOW & SHARE", "CREATIVITY", and "WITH THE WORLD"'],
    # Chinese prompt: a moody, low-key portrait of an elegant Chinese woman in a
    # dark room; a hard beam of light through a gobo casts a crisp, lightning-shaped
    # highlight across her face, lighting exactly one eye. High contrast, clean
    # light/shadow boundary, air of mystery, Leica-like color tone.
    ["极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。"],
]
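# Resolution dropdown values are "WIDTHxHEIGHT ( W:H )" strings; get_resolution()
# below extracts the two integers with a regex and falls back to 1024x1024 when
# nothing matches. Illustrative behavior:
#   get_resolution("1280x720 ( 16:9 )")  # -> (1280, 720)
#   get_resolution("not a resolution")   # -> (1024, 1024)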
def get_resolution(resolution):
    # Resolution strings are "WIDTHxHEIGHT", optionally followed by an aspect
    # ratio label, e.g. "1280x720 ( 16:9 )".
    match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
    if match:
        return int(match.group(1)), int(match.group(2))
    return 1024, 1024


def load_models(model_path, enable_compile=False, attention_backend="native"):
    print(f"Loading models from {model_path}...")
    token = HF_TOKEN if HF_TOKEN else True

    if not os.path.exists(model_path):
        # Remote repo id: load each component from the Hugging Face Hub.
        vae = AutoencoderKL.from_pretrained(
            model_path, subfolder="vae", torch_dtype=torch.bfloat16, device_map="cuda", token=token
        )
        text_encoder = AutoModel.from_pretrained(
            model_path, subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cuda", token=token
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer", token=token)
    else:
        # Local checkpoint directory.
        vae = AutoencoderKL.from_pretrained(
            os.path.join(model_path, "vae"), torch_dtype=torch.bfloat16, device_map="cuda"
        )
        text_encoder = AutoModel.from_pretrained(
            os.path.join(model_path, "text_encoder"),
            torch_dtype=torch.bfloat16,
            device_map="cuda",
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
    tokenizer.padding_side = "left"

    if enable_compile:
        print("Enabling torch.compile optimizations...")
        torch._inductor.config.conv_1x1_as_mm = True
        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.epilogue_fusion = False
        torch._inductor.config.coordinate_descent_check_all_directions = True
        torch._inductor.config.max_autotune_gemm = True
        torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
        torch._inductor.config.triton.cudagraphs = False

    pipe = ZImagePipeline(
        scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None
    )
    if enable_compile:
        pipe.vae.disable_tiling()

    if not os.path.exists(model_path):
        transformer = ZImageTransformer2DModel.from_pretrained(
            model_path, subfolder="transformer", token=token
        ).to("cuda", torch.bfloat16)
    else:
        transformer = ZImageTransformer2DModel.from_pretrained(
            os.path.join(model_path, "transformer")
        ).to("cuda", torch.bfloat16)

    pipe.transformer = transformer
    pipe.transformer.set_attention_backend(attention_backend)

    if enable_compile:
        print("Compiling transformer...")
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False
        )

    pipe.to("cuda", torch.bfloat16)
    return pipe


def generate_image(
    pipe,
    prompt,
    resolution="1024x1024",
    seed=-1,
    guidance_scale=5.0,
    num_inference_steps=50,
    shift=3.0,
    max_sequence_length=512,
):
    width, height = get_resolution(resolution)

    if seed == -1:
        seed = torch.randint(0, 1000000, (1,)).item()
    print(f"Using seed: {seed}")

    generator = torch.Generator("cuda").manual_seed(seed)
    # Rebuild the scheduler on every call so the requested shift takes effect.
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
    pipe.scheduler = scheduler

    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        max_sequence_length=max_sequence_length,
    ).images[0]
    return image


def warmup_model(pipe, resolutions):
    print("Starting warmup phase...")
    dummy_prompt = "warmup"
    for res_str in resolutions:
        print(f"Warming up for resolution: {res_str}")
        try:
            for i in range(3):
                generate_image(
                    pipe,
                    prompt=dummy_prompt,
                    resolution=res_str,
                    num_inference_steps=9,
                    guidance_scale=0.0,
                    seed=42 + i,
                )
        except Exception as e:
            print(f"Warmup failed for {res_str}: {e}")
    print("Warmup completed.")
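# Note: with ENABLE_COMPILE set, torch.compile in "max-autotune-no-cudagraphs"
# mode recompiles the transformer the first time it sees each new latent shape,
# so warmup_model() above runs a few short generations per resolution up front
# to absorb that one-time cost before user requests arrive.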
# ==================== Prompt Expander ====================
@dataclass
class PromptOutput:
    status: bool
    prompt: str
    seed: int
    system_prompt: str
    message: str


class PromptExpander:
    def __init__(self, backend="api", **kwargs):
        self.backend = backend

    def decide_system_prompt(self, template_name=None):
        return prompt_template


class APIPromptExpander(PromptExpander):
    def __init__(self, api_config=None, **kwargs):
        super().__init__(backend="api", **kwargs)
        self.api_config = api_config or {}
        self.client = self._init_api_client()

    def _init_api_client(self):
        try:
            from openai import OpenAI

            api_key = self.api_config.get("api_key") or DASHSCOPE_API_KEY
            base_url = self.api_config.get(
                "base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1"
            )
            if not api_key:
                print("Warning: DASHSCOPE_API_KEY not found.")
                return None
            return OpenAI(api_key=api_key, base_url=base_url)
        except ImportError:
            print("Please install openai: pip install openai")
            return None
        except Exception as e:
            print(f"Failed to initialize API client: {e}")
            return None

    def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
        return self.extend(prompt, system_prompt, seed, **kwargs)

    def extend(self, prompt, system_prompt=None, seed=-1, **kwargs):
        if self.client is None:
            return PromptOutput(False, "", seed, system_prompt, "API client not initialized")
        if system_prompt is None:
            system_prompt = self.decide_system_prompt()
        if "{prompt}" in system_prompt:
            # The user prompt is embedded in the system prompt template, so the
            # user message itself only needs a placeholder.
            system_prompt = system_prompt.format(prompt=prompt)
            prompt = " "
        try:
            model = self.api_config.get("model", "qwen3-max-preview")
            response = self.client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.7,
                top_p=0.8,
            )
            content = response.choices[0].message.content
            # Prefer a ```json fenced block with a "revised_prompt" field; fall
            # back to the raw completion if parsing fails.
            json_start = content.find("```json")
            if json_start != -1:
                json_end = content.find("```", json_start + 7)
                try:
                    json_str = content[json_start + 7 : json_end].strip()
                    data = json.loads(json_str)
                    expanded_prompt = data.get("revised_prompt", content)
                except Exception:
                    expanded_prompt = content
            else:
                expanded_prompt = content
            return PromptOutput(
                status=True,
                prompt=expanded_prompt,
                seed=seed,
                system_prompt=system_prompt,
                message=content,
            )
        except Exception as e:
            return PromptOutput(False, "", seed, system_prompt, str(e))


def create_prompt_expander(backend="api", **kwargs):
    if backend == "api":
        return APIPromptExpander(**kwargs)
    raise ValueError("Only 'api' backend is supported.")


pipe = None
prompt_expander = None


def init_app():
    global pipe, prompt_expander
    try:
        pipe = load_models(
            MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND
        )
        print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}")
        if ENABLE_WARMUP:
            all_resolutions = []
            for cat in RES_CHOICES.values():
                all_resolutions.extend(cat)
            warmup_model(pipe, all_resolutions)
    except Exception as e:
        print(f"Error loading model: {e}")
        pipe = None
    try:
        prompt_expander = create_prompt_expander(
            backend="api", api_config={"model": "qwen3-max-preview"}
        )
        print("Prompt expander initialized.")
    except Exception as e:
        print(f"Error initializing prompt expander: {e}")
        prompt_expander = None


def prompt_enhance(prompt, enable_enhance):
    if not enable_enhance or not prompt_expander:
        return prompt, "Enhancement disabled or not available."
    if not prompt.strip():
        return "", "Please enter a prompt."
    try:
        result = prompt_expander(prompt)
        if result.status:
            return result.prompt, result.message
        return prompt, f"Enhancement failed: {result.message}"
    except Exception as e:
        return prompt, f"Error: {str(e)}"
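# `@spaces.GPU` is the Hugging Face Spaces ZeroGPU decorator: a GPU is attached
# to the process only for the duration of each decorated call.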
@spaces.GPU
def generate(prompt, resolution, seed, steps, shift, enhance, gallery_images):
    if pipe is None:
        raise gr.Error("Model not loaded.")

    final_prompt = prompt
    if enhance:
        final_prompt, _ = prompt_enhance(prompt, True)
        print(f"Enhanced prompt: {final_prompt}")

    if seed == -1:
        seed = random.randint(0, 1000000)

    try:
        # Dropdown values look like "1024x1024 ( 1:1 )"; keep only "1024x1024".
        resolution_str = resolution.split(" ")[0]
    except Exception:
        resolution_str = "1024x1024"

    image = generate_image(
        pipe=pipe,
        prompt=final_prompt,
        resolution=resolution_str,
        seed=seed,
        guidance_scale=0.0,
        num_inference_steps=steps,
        shift=shift,
    )

    if gallery_images is None:
        gallery_images = []
    gallery_images.append(image)
    return gallery_images, str(seed)


init_app()

with gr.Blocks(title="Z-Image Demo") as demo:
    gr.Markdown("# Z-Image Generation Demo")
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Prompt", lines=3, placeholder="Enter your prompt here..."
            )

            # PE components (temporarily disabled)
            # with gr.Row():
            #     enable_enhance = gr.Checkbox(label="Enhance Prompt (DashScope)", value=False)
            #     enhance_btn = gr.Button("Enhance Only")

            with gr.Row():
                choices = [int(k) for k in RES_CHOICES.keys()]
                res_cat = gr.Dropdown(value=1024, choices=choices, label="Resolution Category")
                initial_res_choices = RES_CHOICES["1024"]
                resolution = gr.Dropdown(
                    value=initial_res_choices[0],
                    choices=initial_res_choices,
                    label="Resolution",
                )
            seed = gr.Number(label="Seed", value=-1, precision=0, info="-1 for random")
            with gr.Row():
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=9, step=1)
                shift = gr.Slider(label="Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
            generate_btn = gr.Button("Generate", variant="primary")

            # Example prompts
            gr.Markdown("### 📝 Example Prompts")
            gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input, label=None)

        with gr.Column(scale=1):
            output_gallery = gr.Gallery(
                label="Generated Images", columns=2, rows=2, height=600, object_fit="contain"
            )
            used_seed = gr.Textbox(label="Seed Used", interactive=False)

    def update_res_choices(_res_cat):
        if str(_res_cat) in RES_CHOICES:
            res_choices = RES_CHOICES[str(_res_cat)]
        else:
            res_choices = RES_CHOICES["1024"]
        return gr.update(value=res_choices[0], choices=res_choices)

    res_cat.change(update_res_choices, inputs=res_cat, outputs=resolution)

    # PE enhancement button (temporarily disabled)
    # enhance_btn.click(
    #     prompt_enhance,
    #     inputs=[prompt_input, enable_enhance],
    #     outputs=[prompt_input, final_prompt_output],
    # )

    # Dummy enable_enhance state, fixed to False while prompt enhancement is disabled.
    enable_enhance = gr.State(value=False)

    generate_btn.click(
        generate,
        inputs=[prompt_input, resolution, seed, steps, shift, enable_enhance, output_gallery],
        outputs=[output_gallery, used_seed],
    )

if __name__ == "__main__":
    demo.launch()
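# Example invocation (assumed layout: this file saved as app.py next to pe.py,
# which must define `prompt_template`; names are illustrative):
#   ENABLE_COMPILE=false ENABLE_WARMUP=false python app.py   # fast local start
#   HF_TOKEN=... DASHSCOPE_API_KEY=... python app.py         # Hub auth + prompt expander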