File size: 2,423 Bytes
a8083c8
84201ba
a8083c8
84201ba
a8083c8
 
 
 
 
1687d11
 
 
a8083c8
 
 
 
 
84201ba
a8083c8
 
 
1687d11
 
 
84201ba
 
 
 
 
 
 
 
 
 
 
 
 
a8083c8
84201ba
1687d11
 
 
84201ba
a8083c8
 
 
 
 
 
 
 
 
84201ba
1687d11
 
 
84201ba
a8083c8
1687d11
84201ba
 
a8083c8
 
 
84201ba
a8083c8
 
 
84201ba
a8083c8
 
84201ba
a8083c8
 
 
 
84201ba
 
1687d11
 
 
a8083c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
import pandas as pd
import cloudpickle
from huggingface_hub import hf_hub_download
import autogluon.tabular as ag
import zipfile
import pathlib
import shutil

# -----------------------------
# 0) Hub artifact helpers
# -----------------------------
REPO_ID = "cassieli226/hw2-airline-automl"


def _download_artifact(filename):
    """Download *filename* from the model repo and return its local path."""
    return hf_hub_download(repo_id=REPO_ID, filename=filename, repo_type="model")


def _find_predictor_root(root):
    """Return the directory under *root* that holds AutoGluon's ``predictor.pkl``.

    Archives created with e.g. ``zip -r out.zip predictor/`` wrap everything in
    an extra top-level folder, so the predictor is not always at *root* itself.
    The shallowest match wins. If no marker is found, *root* is returned
    unchanged so that ``TabularPredictor.load`` reports its usual error.
    """
    markers = sorted(root.rglob("predictor.pkl"), key=lambda p: len(p.parts))
    return markers[0].parent if markers else root


# -----------------------------
# 1) Load pickled AutoGluon predictor
# -----------------------------
pkl_path = _download_artifact("autogluon_predictor.pkl")

# SECURITY NOTE: unpickling executes arbitrary code embedded in the file. That
# is only acceptable because the artifact comes from the app author's own repo;
# never point REPO_ID at a repository you do not control.
with open(pkl_path, "rb") as f:
    predictor = cloudpickle.load(f)

# -----------------------------
# 2) load predictor directory for leaderboard
# -----------------------------
# NOTE(review): `predictor_dir` is not used by the UI in this file; confirm the
# leaderboard feature is still wanted before paying for this download.
zip_path = _download_artifact("autogluon_predictor_dir.zip")

extract_dir = pathlib.Path("predictor_dir")
# Start from an empty directory so stale files from a previous run cannot mix
# with the freshly extracted archive.
if extract_dir.exists():
    shutil.rmtree(extract_dir)

with zipfile.ZipFile(zip_path, "r") as zf:
    zf.extractall(extract_dir)

predictor_dir = ag.TabularPredictor.load(str(_find_predictor_root(extract_dir)))

# -----------------------------
# 3) Gradio interface function
# -----------------------------
def predict_flight(stops, days_from_departure, flight_time, price, day_of_week, destination):
    """Run the loaded predictor on a single flight described by the UI inputs.

    Args:
        stops: Value of the "Stops" slider.
        days_from_departure: Value of the "Days from Departure" slider.
        flight_time: Value of the flight-time slider (labelled in minutes).
        price: Value of the "Price" slider.
        day_of_week: Selected entry of the day-of-week dropdown.
        destination: Free-text destination.

    Returns:
        The prediction for the single input row.
    """
    # Column names presumably must match the features the model was trained
    # on — TODO confirm against the training data.
    features = pd.DataFrame(
        {
            "Stops": [stops],
            "Days from Departure": [days_from_departure],
            "Flight_Time_Minutes": [flight_time],
            "Price": [price],
            "Day of the Week": [day_of_week],
            "Destination": [destination],
        }
    )
    predictions = predictor.predict(features)
    # Take the first (only) row positionally. `predictions[0]` would be a
    # label lookup on a Series and only worked because the index contains 0.
    if hasattr(predictions, "iloc"):
        return predictions.iloc[0]
    return predictions[0]

# -----------------------------
# 4) Gradio UI
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Flight Duration Predictor")

    with gr.Row():
        with gr.Column():
            # gr.Slider's positional parameters are (minimum, maximum, value) and
            # the rest are keyword-only, so everything is passed by name here.
            # Passing the label first would make it the slider's minimum.
            stops_in = gr.Slider(
                minimum=0, maximum=3, value=1, step=1, label="Stops"
            )
            days_in = gr.Slider(
                minimum=0, maximum=365, value=30, step=1, label="Days from Departure"
            )
            flight_time_in = gr.Slider(
                minimum=30, maximum=600, value=120, step=1, label="Flight Time (Minutes)"
            )
        with gr.Column():
            price_in = gr.Slider(minimum=50, maximum=1000, value=150, label="Price")
            # Preselect the first option so the model never receives None.
            day_in = gr.Dropdown(
                [1, 2, 3, 4, 5, 6, 7], value=1, label="Day of the Week"
            )
            dest_in = gr.Textbox(label="Destination")

    predict_btn = gr.Button("Predict")
    output = gr.Textbox(label="Predicted Flight Duration")

    # Input order must match predict_flight's positional parameters.
    predict_btn.click(
        predict_flight,
        inputs=[stops_in, days_in, flight_time_in, price_in, day_in, dest_in],
        outputs=[output],
    )

# -----------------------------
# 5) Launch
# -----------------------------
# Guarded so that importing this module builds `demo` without starting a
# server; running the file as a script still launches the app.
if __name__ == "__main__":
    demo.launch()