# codegen-demo / app.py
# Last commit: 400b4ef "Update app.py" by guychuk (verified)
import gradio as gr
import torch
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
# ======================================
# Device setup
# ======================================
# Prefer a CUDA GPU when PyTorch can see one; otherwise run everything on CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# ======================================
# Model config
# ======================================
MODEL_NAME = "Salesforce/codegen-350M-mono"

# Load the tokenizer and causal-LM weights for MODEL_NAME once, at import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# GPT-style tokenizers typically define no dedicated pad token, so reuse EOS
# as the pad token and record its id on the model config for generate().
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id

# Move the weights to DEVICE and switch to evaluation mode (e.g. disables
# dropout). Both calls return the module itself, so `model` is unchanged.
model = model.to(DEVICE).eval()
# ======================================
# Generation function
# ======================================
# Token budgets. Prompt tokens + generated tokens must fit inside the model's
# context window (2048 positions for CodeGen checkpoints -- re-check these
# values if MODEL_NAME changes).
_MAX_PROMPT_TOKENS = 1024
_MAX_NEW_TOKENS = 1024


def generate_code(prompt: str) -> str:
    """Generate website code from a natural-language description.

    The description is wrapped in a short instruction header, tokenized,
    sampled from the module-level ``model`` on ``DEVICE``, and only the newly
    generated tokens are decoded (the prompt itself is never echoed back).

    Args:
        prompt: Free-text description of the website. ``None`` and
            whitespace-only strings are treated as an empty prompt.

    Returns:
        ``""`` for an empty prompt. Otherwise the generated text starting at
        its first ``<`` (a best-effort way to skip any non-markup preamble),
        or the whole generated text if it contains no ``<``.
    """
    if not prompt or not prompt.strip():
        return ""

    # Easter egg: a deliberately slow canned reply. Note that the sleep ties up
    # the worker serving this request for 10 seconds.
    if "Oz Labs Were Here" in prompt:
        time.sleep(10)
        return "Hi Mazmina Hafukh LeShnenu"

    full_prompt = (
        "# Generate a complete HTML/CSS/JS website.\n"
        "# Return ONLY valid code.\n\n"
        f"{prompt}\n"
    )

    # Cap the prompt length so that prompt + _MAX_NEW_TOKENS still fits in the
    # context window. A single string needs no padding.
    inputs = tokenizer(
        full_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=_MAX_PROMPT_TOKENS,
    )
    # Move inputs to the same device as the model.
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            # max_new_tokens (unlike max_length) does not count the prompt, so
            # long descriptions no longer leave zero room for generation.
            max_new_tokens=_MAX_NEW_TOKENS,
            temperature=0.7,
            do_sample=True,
            top_p=0.95,
        )

    # For a causal LM, generate() returns prompt + continuation. Decode only
    # the continuation so the instruction header and the user's own text
    # (which may contain "<") are not mistaken for generated markup.
    decoded = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)

    # Extract HTML if possible.
    html_start = decoded.find("<")
    return decoded[html_start:] if html_start != -1 else decoded
# ======================================
# Gradio UI
# ======================================
app = gr.Interface(
    fn=generate_code,
    inputs=gr.Textbox(
        lines=6,
        placeholder="Describe the website you want to generate...",
        label="Website description",
    ),
    # language="html" turns on syntax highlighting for the returned markup.
    outputs=gr.Code(label="Generated HTML / CSS / JS", language="html"),
    title="Oz AI Website Generator",
    description="Describe a website idea. The model returns only HTML/CSS/JS.",
)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. from tests or another app) does not launch a web server.
if __name__ == "__main__":
    app.launch()