"""Entry point for running the model from the command line or via stdin JSON.

Usage (CLI):
    python script.py PROMPT [NEGATIVE_PROMPT] [NUM_PATHS] [GUIDANCE_SCALE] [SEED]

Usage (API): pipe a JSON object on stdin with keys
    "prompt", "negative_prompt", "num_paths", "guidance_scale", "seed".

The result returned by the model is printed to stdout as JSON.
"""
import os  # NOTE(review): unused in visible code — retained in case of external reliance; TODO confirm
import sys
import json

import torch  # NOTE(review): unused directly here — presumably required by `model`; TODO confirm

from model import pipeline

# Initialize the model once at module import so every call reuses the same
# loaded pipeline instead of re-loading weights per request.
model = pipeline()


def run(prompt, negative_prompt="", num_paths=96, guidance_scale=7.5, seed=42):
    """Run the model with the given parameters.

    Args:
        prompt: Text prompt to generate from.
        negative_prompt: Text describing what to avoid (default: "").
        num_paths: Number of paths; coerced to int (default: 96).
        guidance_scale: Guidance strength; coerced to float (default: 7.5).
        seed: RNG seed; coerced to int (default: 42).

    Returns:
        Whatever the pipeline call returns (opaque here; serialized as JSON
        by the caller).
    """
    return model(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_paths=int(num_paths),
        guidance_scale=float(guidance_scale),
        seed=int(seed),
    )


def parse_args():
    """Parse parameters from argv (CLI mode) or stdin JSON (API mode).

    CLI mode is used when at least one argument is present; positional order
    is prompt, negative_prompt, num_paths, guidance_scale, seed. Otherwise a
    single JSON object is read from stdin.

    Returns:
        Tuple of (prompt, negative_prompt, num_paths, guidance_scale, seed).

    Raises:
        ValueError: if a numeric argument cannot be converted.
        json.JSONDecodeError: if stdin does not contain valid JSON.
    """
    if len(sys.argv) > 1:
        # Command line arguments
        prompt = sys.argv[1]
        negative_prompt = sys.argv[2] if len(sys.argv) > 2 else ""
        num_paths = int(sys.argv[3]) if len(sys.argv) > 3 else 96
        guidance_scale = float(sys.argv[4]) if len(sys.argv) > 4 else 7.5
        seed = int(sys.argv[5]) if len(sys.argv) > 5 else 42
    else:
        # Read from stdin (for API)
        data = json.loads(sys.stdin.read())
        prompt = data.get("prompt", "")
        negative_prompt = data.get("negative_prompt", "")
        num_paths = int(data.get("num_paths", 96))
        guidance_scale = float(data.get("guidance_scale", 7.5))
        seed = int(data.get("seed", 42))
    return prompt, negative_prompt, num_paths, guidance_scale, seed


if __name__ == "__main__":
    # Parse arguments
    prompt, negative_prompt, num_paths, guidance_scale, seed = parse_args()
    # Run the model
    result = run(prompt, negative_prompt, num_paths, guidance_scale, seed)
    # Print the result as JSON
    print(json.dumps(result))