import os
import sys
import json
import torch
from pathlib import Path


# Determine which model we're running based on the repository name
def get_model_type(repo_root="/repository"):
    """Detect which model variant to serve from the repository's git config.

    The git config file (``<repo_root>/.git/config``) is read and searched,
    case-insensitively, for the model name. ``"svgdreamer"`` is checked
    first, then ``"diffsketcher_edit"`` / ``"diffsketcher-edit"``.

    Args:
        repo_root: Root of the checked-out repository. Defaults to
            ``"/repository"``, which is the path the original code probed
            (presumably the Hugging Face endpoint mount point — verify).

    Returns:
        One of ``"svgdreamer"``, ``"diffsketcher_edit"`` or
        ``"diffsketcher"``. The last is the fallback when the config is
        missing, unreadable, or mentions neither other variant.
    """
    # Default to diffsketcher if we can't determine
    model_type = "diffsketcher"

    git_config = Path(repo_root) / ".git" / "config"
    # is_file() is False when repo_root or .git is absent, or when .git is a
    # file (e.g. a worktree/submodule pointer) rather than a directory.
    if git_config.is_file():
        try:
            # errors="replace" keeps a stray non-UTF-8 byte from aborting
            # detection; only ASCII substrings are matched below anyway.
            config = git_config.read_text(encoding="utf-8", errors="replace").lower()
        except OSError as err:
            # Best-effort detection: report and fall back to the default
            # rather than crashing the endpoint at import time.
            print(f"Could not read {git_config}: {err}")
        else:
            if "svgdreamer" in config:
                model_type = "svgdreamer"
            elif "diffsketcher_edit" in config or "diffsketcher-edit" in config:
                model_type = "diffsketcher_edit"

    print(f"Detected model type: {model_type}")
    return model_type


# Import the appropriate handler based on model type
def import_handler():
    """Instantiate the handler class matching ``get_model_type()``.

    Imports are deferred into each branch, so only the selected handler
    module (and its dependencies) is ever loaded.
    """
    model_type = get_model_type()

    if model_type == "svgdreamer":
        from svgdreamer_handler import SVGDreamerHandler
        return SVGDreamerHandler()
    if model_type == "diffsketcher_edit":
        from diffsketcher_edit_handler import DiffSketcherEditHandler
        return DiffSketcherEditHandler()

    from diffsketcher_handler import DiffSketcherHandler
    return DiffSketcherHandler()


# Initialize the handler once at import time so every request reuses it.
handler = import_handler()
handler.initialize(None)


# Define the inference function for the API
def inference(model_inputs):
    """Run one request through the module-level handler.

    Args:
        model_inputs: Request payload forwarded unchanged to
            ``handler.handle``. The ``__main__`` example below uses a dict
            with ``"inputs"`` and ``"parameters"`` keys; the exact schema is
            defined by the handler class (not visible here).

    Returns:
        Whatever ``handler.handle`` returns. The ``__main__`` example reads
        a ``"svg"`` key from it.
    """
    # Reading a module global needs no `global` declaration.
    return handler.handle(model_inputs, None)


# This is used when running locally
if __name__ == "__main__":
    # Test the handler with a sample input
    sample_input = {
        "inputs": "a beautiful mountain landscape",
        "parameters": {}
    }

    result = inference(sample_input)
    print(f"Generated SVG with {len(result['svg'])} characters")

    # Save the SVG to a file; pin the encoding so non-ASCII text in the SVG
    # does not depend on the platform default.
    with open("output.svg", "w", encoding="utf-8") as f:
        f.write(result["svg"])
    print("SVG saved to output.svg")