Spaces:
Runtime error
Runtime error
Upload W3-assignment-streamlit.py
Browse files- W3-assignment-streamlit.py +194 -0
W3-assignment-streamlit.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Assume our stakeholder is a brand advertiser who wants to understand short-video audience segments.
|
| 2 |
+
# The clustering analysis helps identify which audience groups are best suited for different types of
|
| 3 |
+
# campaigns (broad awareness vs. engagement-driven vs. knowledge-sharing).
|
| 4 |
+
import streamlit as st
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
|
| 9 |
+
from sklearn.impute import SimpleImputer
|
| 10 |
+
from sklearn.preprocessing import StandardScaler
|
| 11 |
+
from sklearn.decomposition import PCA
|
| 12 |
+
from sklearn.cluster import KMeans
|
| 13 |
+
from sklearn.metrics import silhouette_score
|
| 14 |
+
|
| 15 |
+
st.set_page_config(page_title="Short-Video Audience Segments", layout="wide")

# Sidebar — Controls
st.sidebar.title("Controls")

# Path of the trends CSV. The file is read further down through the cached
# load_data() helper; reading it here with a bare pd.read_csv would re-parse
# the whole CSV on every Streamlit rerun (i.e. on every widget interaction).
CSV_PATH = "data/youtube_shorts_tiktok_trends_2025.csv"

# Choose number of clusters (K)
k = st.sidebar.slider("Number of clusters (K)", min_value=2, max_value=8, value=4, step=1)

# Metric to compare across clusters
core_metrics = ["views", "likes", "comments", "shares", "saves", "engagement_rate"]
metric = st.sidebar.selectbox(
    "Metric to compare ",
    core_metrics,
    # Default to engagement_rate by name so reordering core_metrics cannot
    # silently change the default selection.
    index=core_metrics.index("engagement_rate"),
)

# Optional filters for display (do not refit model)
platform_filter_on = st.sidebar.checkbox("Filter by platform ", value=False)
region_filter_on = st.sidebar.checkbox("Filter by region ", value=False)
|
| 34 |
+
|
| 35 |
+
# Data loading & preprocessing
|
| 36 |
+
@st.cache_data(show_spinner=True)
def load_data(path):
    """Read the CSV file at ``path`` and return it as a DataFrame.

    The function is decorated with ``st.cache_data``, so Streamlit can reuse
    the result for the same ``path`` instead of re-reading the file.
    """
    return pd.read_csv(path, low_memory=False)
|
| 40 |
+
|
| 41 |
+
@st.cache_data(show_spinner=True)
def build_features(df_raw: pd.DataFrame):
    """Clean the raw table and build the scaled feature matrix for clustering.

    Args:
        df_raw: Raw table. It must contain every column named in ``use_cols``.

    Returns:
        A tuple ``(df, df_model, feature_names, X_scaled)``:
        ``df`` is the row-filtered frame with the engineered columns added,
        ``df_model`` is the numeric + one-hot encoded design matrix,
        ``feature_names`` is the list of ``df_model`` column names, and
        ``X_scaled`` is ``df_model`` after median imputation and
        standardisation.

    Raises:
        KeyError: If a required column is missing from ``df_raw``.
    """
    # Select useful columns for engagement & context
    use_cols = [
        "platform", "region", "category", "sound_type",
        "week_of_year", "duration_sec",
        "views", "likes", "comments", "shares", "saves",
        "engagement_rate", "engagement_share_rate",
    ]
    missing = [c for c in use_cols if c not in df_raw.columns]
    if missing:
        # Fail early with the full list instead of pandas' generic message.
        raise KeyError(f"Input data is missing required columns: {missing}")

    # Basic cleaning: keep valid duration & views. The explicit .copy() makes
    # the filtered frame an independent object, so the column assignments
    # below never write into a slice of another frame.
    df = df_raw[use_cols]
    df = df[(df["duration_sec"] > 0) & (df["views"] > 0)].copy()

    # Construct per-view rates to offset scale bias between small and large
    # videos.
    for col in ["likes", "comments", "shares", "saves"]:
        df[f"{col}_rate"] = (df[col] / df["views"].clip(lower=1)).astype(float)

    # Log-transform heavy-tailed count features
    for col in ["views", "likes", "comments", "shares", "saves", "duration_sec"]:
        df[f"log_{col}"] = np.log1p(df[col])

    # Week-of-year cyclic encoding (period of 52 weeks)
    df["week_sin"] = np.sin(2 * np.pi * df["week_of_year"] / 52.0)
    df["week_cos"] = np.cos(2 * np.pi * df["week_of_year"] / 52.0)

    # Numeric + Categorical features
    num_feats = [
        "log_views", "log_likes", "log_comments", "log_shares", "log_saves",
        "log_duration_sec",
        "likes_rate", "comments_rate", "shares_rate", "saves_rate",
        "engagement_rate", "engagement_share_rate",
        "week_sin", "week_cos",
    ]
    cat_feats = ["platform", "region", "category", "sound_type"]

    # One-hot for categoricals (drop_first to avoid perfect collinearity)
    df_model = pd.get_dummies(df[num_feats + cat_feats], drop_first=True).astype(float)
    feature_names = df_model.columns.tolist()

    # Impute (median) & scale (zero mean, unit variance)
    imp = SimpleImputer(strategy="median")
    X_num = imp.fit_transform(df_model.values)
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_num)

    return df, df_model, feature_names, X_scaled
|
| 85 |
+
|
| 86 |
+
@st.cache_resource(show_spinner=True)
def fit_pca(X_scaled: np.ndarray, var_threshold: float = 0.80):
    """Fit a full PCA and choose how many components reach ``var_threshold``.

    Args:
        X_scaled: Feature matrix passed to ``PCA.fit_transform``.
        var_threshold: Target cumulative explained-variance ratio.

    Returns:
        A tuple ``(pca, X_pca, exp, cum, k_opt)``: the fitted PCA object, the
        transformed data, the per-component explained-variance ratios, their
        cumulative sum, and the smallest number of leading components whose
        cumulative ratio is >= ``var_threshold``. If the threshold is never
        reached, ``k_opt`` is the total number of components.
    """
    pca = PCA(n_components=None, random_state=42)
    X_pca = pca.fit_transform(X_scaled)
    exp = pca.explained_variance_ratio_
    cum = np.cumsum(exp)
    # cum is non-decreasing, so searchsorted finds the first position with
    # cum >= var_threshold. When no position qualifies it returns len(cum);
    # clamping keeps "all components" in that case, whereas argmax over an
    # all-False mask would have silently produced k_opt == 1.
    first_hit = int(np.searchsorted(cum, var_threshold, side="left"))
    k_opt = min(first_hit + 1, len(cum))
    return pca, X_pca, exp, cum, k_opt
|
| 94 |
+
|
| 95 |
+
@st.cache_resource(show_spinner=True)
def fit_kmeans(X_embed: np.ndarray, n_clusters: int):
    """Cluster ``X_embed`` with KMeans and score the result.

    Returns a tuple ``(km, labels, sil)``: the fitted KMeans estimator, the
    cluster label of every row, and the silhouette score of that labelling.
    ``sil`` is ``np.nan`` when the silhouette computation raises.
    """
    estimator = KMeans(n_clusters=n_clusters, random_state=42, n_init=20)
    assignments = estimator.fit_predict(X_embed)

    # Best effort: the score is informational only, so a failure to compute
    # it must not prevent the clustering result from being returned.
    score = np.nan
    try:
        score = silhouette_score(X_embed, assignments)
    except Exception:
        score = np.nan

    return estimator, assignments, score
|
| 104 |
+
|
| 105 |
+
# Load & Prep
|
| 106 |
+
# The path constant is spelled CSV_PATH; the lowercase name is undefined and
# raised NameError at startup.
df_raw = load_data(CSV_PATH)
df_clean, df_model, feature_names, X_scaled = build_features(df_raw)
pca, X_pca, exp, cum, k_opt = fit_pca(X_scaled, var_threshold=0.80)

# Use first k_opt PCs for clustering
X_k = X_pca[:, :k_opt]

# Fit KMeans with selected k
km, labels, sil = fit_kmeans(X_k, n_clusters=k)
df_show = df_clean.copy()
# labels is positional (one entry per row of df_clean, in order).
df_show["cluster"] = labels

# Optional display filters (platform/region)
if platform_filter_on:
    platforms = ["(All)"] + sorted(df_show["platform"].dropna().unique().tolist())
    chosen_platform = st.sidebar.selectbox("Platform", platforms, index=0)
    if chosen_platform != "(All)":
        df_show = df_show[df_show["platform"] == chosen_platform]

if region_filter_on:
    # Region choices are taken from the (possibly platform-filtered) frame.
    regions = ["(All)"] + sorted(df_show["region"].dropna().unique().tolist())
    chosen_region = st.sidebar.selectbox("Region", regions, index=0)
    if chosen_region != "(All)":
        df_show = df_show[df_show["region"] == chosen_region]

# Header — Business framing
st.title("Short-Video Audience Segments (TikTok & YouTube Shorts)")
st.caption("We identify audience segments via PCA + KMeans to support content and ad strategy.")

colA, colB, colC = st.columns(3)
with colA:
    st.metric("Rows used", f"{len(df_show):,}")
with colB:
    st.metric("K (clusters)", k)
with colC:
    st.metric("Silhouette", f"{sil:.3f}" if not np.isnan(sil) else "—")

st.markdown("---")

# Cluster profile table: mean of each core metric per cluster
cluster_profile = (
    df_show.groupby("cluster")[core_metrics]
    .mean()
    .sort_index()
)

st.subheader("Cluster Profiles — Mean Metrics")
st.dataframe(cluster_profile.style.format("{:,.2f}"))

# Chart: Compare chosen metric across clusters
st.subheader(f"Compare clusters by: **{metric}**")

fig, ax = plt.subplots(figsize=(7, 4))
cluster_profile[metric].plot(kind="bar", ax=ax)
ax.set_xlabel("Cluster")
ax.set_ylabel(metric)
ax.set_title(f"{metric} by Cluster")
ax.grid(axis="y", linestyle="--", alpha=0.3)
st.pyplot(fig, use_container_width=True)
# pyplot keeps every figure it creates until it is closed; the script reruns
# on each interaction, so close the figure once it has been rendered.
plt.close(fig)
|
| 165 |
+
|
| 166 |
+
# Dynamic insights
|
| 167 |
+
def insight_text(cp: pd.DataFrame):
    """Build markdown bullet points naming the top cluster for key metrics.

    Args:
        cp: Cluster profile table, indexed by cluster id, with one column per
            metric.

    Returns:
        A newline-joined markdown string with one bullet for each of
        ``engagement_rate``, ``views`` and ``saves`` that is present in ``cp``
        and has at least one non-missing value. If none qualifies (for
        example ``cp`` is empty), a single bullet saying that no data is
        available is returned instead of raising.
    """
    # Only these metrics are reported, so only they are inspected.
    reported = (
        ("engagement_rate", "best for community/interaction"),
        ("views", "best for broad awareness"),
        ("saves", "good for knowledge/utility content"),
    )
    lines_en = []
    for name, note in reported:
        if name not in cp.columns:
            continue
        values = cp[name].dropna()
        if values.empty:
            # idxmax is undefined on an empty / all-missing column.
            continue
        top_cluster = int(values.idxmax())
        lines_en.append(f"- Highest **{name}**: Cluster {top_cluster} ({note}).")

    if not lines_en:
        return "- No cluster data available for the current selection."
    return "\n".join(lines_en)
|
| 176 |
+
|
| 177 |
+
st.markdown("### Dynamic Insights")
st.markdown(insight_text(cluster_profile))

# Optional: Diagnostics
with st.expander("Model Diagnostics"):
    st.write(f"Using first **{k_opt} PCs** to reach ≥80% cumulative explained variance.")
    # Small curve of cumulative explained variance with the 0.80 target line
    fig2, ax2 = plt.subplots(figsize=(5, 3))
    ax2.plot(range(1, len(cum) + 1), cum, marker="o")
    ax2.axhline(0.80, color="r", linestyle="--")
    ax2.set_xlabel("PCs")
    ax2.set_ylabel("Cumulative explained variance")
    ax2.set_title("PCA Explained Variance")
    ax2.grid(axis="y", linestyle="--", alpha=0.3)
    st.pyplot(fig2, use_container_width=False)
    # pyplot keeps figures alive until closed; the script reruns on every
    # interaction, so close the figure once it has been rendered.
    plt.close(fig2)

st.markdown("---")
st.caption("Note: Clustering is fitted on full data (then filtered for display).")
|