# Author: madhavkarthi
# Last commit: 1687d11 ("Update app.py", verified)
import gradio as gr
import pandas as pd
import cloudpickle
from huggingface_hub import hf_hub_download
import autogluon.tabular as ag
import zipfile
import pathlib
import shutil
# -----------------------------
# 1) Load the pickled AutoGluon predictor from the Hugging Face Hub
# -----------------------------
# NOTE(security): cloudpickle deserialization can execute arbitrary code
# embedded in the file. Only load this artifact from a repository you trust.
# NOTE(review): a pickled TabularPredictor may still point at model files under
# the directory it was trained in — confirm predict() works on this machine.
pkl_path = hf_hub_download(
    repo_id="cassieli226/hw2-airline-automl",
    filename="autogluon_predictor.pkl",
    repo_type="model",
)
predictor = cloudpickle.loads(pathlib.Path(pkl_path).read_bytes())
# -----------------------------
# 2) Load the predictor directory (intended for leaderboard / model inspection)
# -----------------------------
zip_path = hf_hub_download(
    repo_id="cassieli226/hw2-airline-automl",
    filename="autogluon_predictor_dir.zip",
    repo_type="model",
)
extract_dir = pathlib.Path("predictor_dir")
# Remove any copy left over from a previous start so stale files cannot mix
# with the contents of the freshly downloaded archive.
if extract_dir.exists():
    shutil.rmtree(extract_dir)
with zipfile.ZipFile(zip_path) as archive:
    archive.extractall(path=extract_dir)
# NOTE(review): TabularPredictor.load is pointed at extract_dir itself, which
# assumes the archive stores the predictor files at its root rather than inside
# a nested folder — confirm against how the zip was built.
# NOTE(review): `predictor_dir` is not referenced by the UI code below.
predictor_dir = ag.TabularPredictor.load(str(extract_dir))
# -----------------------------
# 3) Gradio interface function
# -----------------------------
def predict_flight(
    stops,
    days_from_departure,
    flight_time,
    price,
    day_of_week,
    destination,
    model=None,
):
    """Predict the target for a single flight described by the UI inputs.

    Args:
        stops: Number of stops (value of the "Stops" slider).
        days_from_departure: Days between booking and departure.
        flight_time: Flight time in minutes.
        price: Ticket price.
        day_of_week: Day-of-week value from the dropdown (1-7 in the UI).
        destination: Free-text destination. Surrounding whitespace is
            stripped so that e.g. "Boston " and "Boston" are treated alike.
        model: Object exposing ``predict(DataFrame) -> Series``. Defaults to
            the module-level ``predictor`` loaded from the Hub.

    Returns:
        The prediction for the single input row.
    """
    if model is None:
        model = predictor
    if isinstance(destination, str):
        destination = destination.strip()
    # Column names are passed verbatim to the model; presumably they must match
    # the training data's columns exactly — confirm against the training set.
    X = pd.DataFrame({
        "Stops": [stops],
        "Days from Departure": [days_from_departure],
        "Flight_Time_Minutes": [flight_time],
        "Price": [price],
        "Day of the Week": [day_of_week],
        "Destination": [destination],
    })
    # Positional access: take the first (only) prediction regardless of what
    # index the returned Series carries.
    return model.predict(X).iloc[0]
# -----------------------------
# 4) Gradio UI
# -----------------------------
# NOTE(review): "Flight_Time_Minutes" is an input feature, yet the title and
# output are labelled as a predicted flight duration — confirm what the model's
# target column actually is before relying on these labels.
with gr.Blocks() as demo:
    gr.Markdown("# Flight Duration Predictor")
    with gr.Row():
        with gr.Column():
            # gr.Slider's positional parameters are (minimum, maximum, value);
            # the label must be passed by keyword. Integer counts use step=1.
            # Otherwise the auto-computed step for a 0-3 range is fractional.
            stops_in = gr.Slider(
                minimum=0, maximum=3, value=1, step=1, label="Stops"
            )
            days_in = gr.Slider(
                minimum=0, maximum=365, value=30, step=1,
                label="Days from Departure",
            )
            flight_time_in = gr.Slider(
                minimum=30, maximum=600, value=120, step=1,
                label="Flight Time (Minutes)",
            )
        with gr.Column():
            price_in = gr.Slider(
                minimum=50, maximum=1000, value=150, label="Price"
            )
            # Pre-select a day so the model never receives an empty value.
            # NOTE(review): assumes days were encoded as 1-7 in training.
            day_in = gr.Dropdown(
                choices=[1, 2, 3, 4, 5, 6, 7], value=1, label="Day of the Week"
            )
            dest_in = gr.Textbox(label="Destination")
    predict_btn = gr.Button("Predict")
    output = gr.Textbox(label="Predicted Flight Duration")
    # Event listeners must be registered inside the Blocks context.
    predict_btn.click(
        predict_flight,
        inputs=[stops_in, days_in, flight_time_in, price_in, day_in, dest_in],
        outputs=[output],
    )
# -----------------------------
# 5) Launch the Gradio server
# -----------------------------
# Called at module level, so executing (or importing) this file starts the app.
demo.launch()