W3 / W3-assignment-streamlit.py
huohuobeixiaosile's picture
Update W3-assignment-streamlit.py
d5c89f1 verified
raw
history blame
7.85 kB
# Assume our stakeholder is a brand advertiser who wants to understand short-video audience segments.
# The clustering analysis helps identify which audience groups are best suited for different types of
# campaigns (broad awareness vs. engagement-driven vs. knowledge-sharing).
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
# Page configuration: page title and wide layout.
st.set_page_config(
    layout="wide",
    page_title="Short-Video Audience Segments",
)

# Sidebar — Controls
st.sidebar.title("Controls")
# Set CSV path (try both possible locations)
import os
from pathlib import Path
@st.cache_data(show_spinner=False)
def load_data() -> pd.DataFrame:
    """Read the trends CSV from the first default location that exists.

    Locations are checked in this order:
      1) data/youtube_shorts_tiktok_trends_2025.csv
      2) ./youtube_shorts_tiktok_trends_2025.csv

    Returns:
        The parsed CSV as a DataFrame.

    Raises:
        FileNotFoundError: if neither location exists.
    """
    filename = "youtube_shorts_tiktok_trends_2025.csv"
    candidates = [Path("data") / filename, Path(filename)]

    # First existing candidate wins; None means nothing was found.
    found = next((c for c in candidates if c.exists()), None)
    if found is None:
        raise FileNotFoundError(
            f"CSV not found in any of: {[str(p) for p in candidates]}. "
            "Make sure the file is committed to your Space."
        )
    return pd.read_csv(found, low_memory=False)
# Read the dataset once for this run of the script.
df_raw = load_data()

# Sidebar control: number of clusters (K).
k = st.sidebar.slider(
    "Number of clusters (K)",
    min_value=2,
    max_value=8,
    value=4,
    step=1,
)

# Sidebar control: metric to compare across clusters
# (defaults to engagement_rate, the last entry).
core_metrics = ["views", "likes", "comments", "shares", "saves", "engagement_rate"]
metric = st.sidebar.selectbox(
    "Metric to compare ",
    core_metrics,
    index=core_metrics.index("engagement_rate"),
)

# Optional display-only filters (the model is not refitted).
platform_filter_on = st.sidebar.checkbox("Filter by platform ", value=False)
region_filter_on = st.sidebar.checkbox("Filter by region ", value=False)
# Data loading & preprocessing
@st.cache_data(show_spinner=True)
def load_data(path=None) -> pd.DataFrame:
    """Read the trends CSV into a DataFrame.

    Args:
        path: Optional CSV location. When omitted (or ``None``), the default
            locations are tried in order:
            ``data/youtube_shorts_tiktok_trends_2025.csv`` and then
            ``youtube_shorts_tiktok_trends_2025.csv``.

    Returns:
        The parsed CSV.

    Raises:
        FileNotFoundError: if no ``path`` is given and neither default
            location exists.
    """
    # ``path`` used to be read as an undefined global, so every call raised
    # NameError. As an optional parameter it supports both zero-argument
    # calls and calls that pass an explicit location.
    if path is not None:
        return pd.read_csv(path, low_memory=False)

    candidates = [
        Path("data/youtube_shorts_tiktok_trends_2025.csv"),
        Path("youtube_shorts_tiktok_trends_2025.csv"),
    ]
    for candidate in candidates:
        if candidate.exists():
            return pd.read_csv(candidate, low_memory=False)
    raise FileNotFoundError(
        f"CSV not found in any of: {[str(p) for p in candidates]}. "
        "Make sure the file is committed to your Space."
    )
@st.cache_data(show_spinner=True)
def build_features(df_raw: pd.DataFrame):
    """Clean the raw table and build the scaled feature matrix for clustering.

    Steps: select the modelling columns, keep rows with positive duration and
    views, add per-view rates, log1p transforms and a sin/cos encoding of
    ``week_of_year``, one-hot encode the categoricals, then median-impute and
    standardize.

    Args:
        df_raw: Raw data; must contain every column listed in ``use_cols``.

    Returns:
        A tuple ``(df, df_model, feature_names, X_scaled)`` where ``df`` is
        the cleaned frame including the engineered columns, ``df_model`` is
        the float design matrix before imputation, ``feature_names`` is its
        column list, and ``X_scaled`` is the imputed and standardized array
        with one row per row of ``df``.
    """
    # Select useful columns for engagement & context
    use_cols = [
        "platform", "region", "category", "sound_type",
        "week_of_year", "duration_sec",
        "views", "likes", "comments", "shares", "saves",
        "engagement_rate", "engagement_share_rate",
    ]
    selected = df_raw[use_cols]

    # Basic cleaning: keep valid duration & views. Comparisons against NaN
    # are False, so rows missing either value are dropped as well.
    valid = (selected["duration_sec"] > 0) & (selected["views"] > 0)
    # Copy AFTER filtering so the column assignments below write to an
    # independent frame instead of a slice (avoids SettingWithCopyWarning).
    df = selected.loc[valid].copy()

    # Construct per-view rates to offset scale bias between small and large videos
    for col in ["likes", "comments", "shares", "saves"]:
        df[f"{col}_rate"] = (df[col] / df["views"].clip(lower=1)).astype(float)

    # Log-transform heavy-tailed count features
    for col in ["views", "likes", "comments", "shares", "saves", "duration_sec"]:
        df[f"log_{col}"] = np.log1p(df[col])

    # Week-of-year cyclic encoding
    week_angle = 2 * np.pi * df["week_of_year"] / 52.0
    df["week_sin"] = np.sin(week_angle)
    df["week_cos"] = np.cos(week_angle)

    # Numeric + Categorical features
    num_feats = [
        "log_views", "log_likes", "log_comments", "log_shares", "log_saves",
        "log_duration_sec",
        "likes_rate", "comments_rate", "shares_rate", "saves_rate",
        "engagement_rate", "engagement_share_rate",
        "week_sin", "week_cos",
    ]
    cat_feats = ["platform", "region", "category", "sound_type"]

    # One-hot for categoricals (drop_first to avoid perfect collinearity)
    df_model = pd.get_dummies(df[num_feats + cat_feats], drop_first=True).astype(float)
    feature_names = df_model.columns.tolist()

    # Impute & Scale
    imp = SimpleImputer(strategy="median")
    X_num = imp.fit_transform(df_model.values)
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_num)
    return df, df_model, feature_names, X_scaled
@st.cache_resource(show_spinner=True)
def fit_pca(X_scaled: np.ndarray, var_threshold: float = 0.80):
    """Fit a full PCA and find how many components reach a variance target.

    Args:
        X_scaled: Feature matrix to decompose.
        var_threshold: Cumulative explained-variance ratio to reach.

    Returns:
        A tuple ``(pca, X_pca, exp, cum, k_opt)``: the fitted PCA, the
        transformed data, the per-component explained-variance ratios, their
        cumulative sum, and the smallest number of leading components whose
        cumulative ratio is >= ``var_threshold``. If the threshold is never
        reached, ``k_opt`` is the total number of components.
    """
    pca = PCA(n_components=None, random_state=42)
    X_pca = pca.fit_transform(X_scaled)
    exp = pca.explained_variance_ratio_
    cum = np.cumsum(exp)

    # np.argmax over an all-False mask returns 0, which would silently select
    # a single component when the threshold is never reached (e.g. a
    # threshold above the final cumulative value). Fall back to all
    # components in that case.
    reached = cum >= var_threshold
    if reached.any():
        k_opt = int(np.argmax(reached)) + 1
    else:
        k_opt = len(cum)
    return pca, X_pca, exp, cum, k_opt
@st.cache_resource(show_spinner=True)
def fit_kmeans(X_embed: np.ndarray, n_clusters: int):
    """Cluster the embedded data with KMeans and score the result.

    Args:
        X_embed: Data to cluster.
        n_clusters: Number of clusters to fit.

    Returns:
        A tuple ``(km, labels, sil)``: the fitted KMeans model, the cluster
        label of each row, and the silhouette score (NaN if scoring raised).
    """
    model = KMeans(n_clusters=n_clusters, n_init=20, random_state=42)
    assignments = model.fit_predict(X_embed)

    # Scoring is best-effort: any failure is reported as NaN instead of
    # propagating to the caller.
    try:
        score = silhouette_score(X_embed, assignments)
    except Exception:
        score = np.nan
    return model, assignments, score
# Load & Prep
# ``df_raw`` was already loaded near the top of the script, so it is reused
# here. The previous ``load_data(csv_path)`` call raised NameError because
# ``csv_path`` is never defined.
df_clean, df_model, feature_names, X_scaled = build_features(df_raw)
pca, X_pca, exp, cum, k_opt = fit_pca(X_scaled, var_threshold=0.80)

# Use first k_opt PCs for clustering
X_k = X_pca[:, :k_opt]

# Fit KMeans with selected k
km, labels, sil = fit_kmeans(X_k, n_clusters=k)

# Attach the labels to a copy of the cleaned frame for display
df_show = df_clean.copy()
df_show["cluster"] = labels
# Optional display filters (platform/region)
def _filter_by_choice(frame, column, label):
    """Show a sidebar selectbox for ``column`` and return the matching rows.

    The options are "(All)" followed by the sorted non-null values of the
    column; choosing "(All)" returns ``frame`` unchanged.
    """
    options = ["(All)"] + sorted(frame[column].dropna().unique().tolist())
    choice = st.sidebar.selectbox(label, options, index=0)
    if choice == "(All)":
        return frame
    return frame[frame[column] == choice]


# Platform is applied first, so the region options reflect the platform choice.
if platform_filter_on:
    df_show = _filter_by_choice(df_show, "platform", "Platform")
if region_filter_on:
    df_show = _filter_by_choice(df_show, "region", "Region")
# Header — Business framing
st.title("Short-Video Audience Segments (TikTok & YouTube Shorts)")
st.caption("We identify audience segments via PCA + KMeans to support content and ad strategy.")

# Headline numbers, one per column: row count, K, silhouette score.
colA, colB, colC = st.columns(3)
colA.metric("Rows used", f"{len(df_show):,}")
colB.metric("K (clusters)", k)
# Show a dash when the silhouette score is NaN.
colC.metric("Silhouette", "—" if np.isnan(sil) else f"{sil:.3f}")

st.markdown("---")
# Cluster profile table: mean of each core metric per cluster, ordered by
# cluster id.
by_cluster = df_show.groupby("cluster")
cluster_profile = by_cluster[core_metrics].mean().sort_index()

st.subheader("Cluster Profiles — Mean Metrics")
st.dataframe(cluster_profile.style.format("{:,.2f}"))
# Chart: Compare chosen metric across clusters
st.subheader(f"Compare clusters by: **{metric}**")
fig, ax = plt.subplots(figsize=(7, 4))
try:
    cluster_profile[metric].plot(kind="bar", ax=ax)
    ax.set_xlabel("Cluster")
    ax.set_ylabel(metric)
    ax.set_title(f"{metric} by Cluster")
    ax.grid(axis="y", linestyle="--", alpha=0.3)
    st.pyplot(fig, use_container_width=True)
finally:
    # pyplot keeps every figure it creates alive until it is closed; close
    # it once rendered so repeated script runs do not accumulate figures.
    plt.close(fig)
# Dynamic insights
def insight_text(cp: pd.DataFrame) -> str:
    """Build Markdown bullets naming the top cluster for the key metrics.

    Args:
        cp: Cluster profile table indexed by cluster id, one column per
            metric.

    Returns:
        Newline-joined Markdown bullets for ``engagement_rate``, ``views``
        and ``saves`` (in that order). A metric is skipped when its column is
        absent or has no non-missing values; if no bullet can be produced, a
        single placeholder line is returned instead of raising.
    """
    # Only the metrics that are actually reported are evaluated, so an
    # unrelated column that is empty or all-NaN cannot break the summary.
    reported = (
        ("engagement_rate", "best for community/interaction"),
        ("views", "best for broad awareness"),
        ("saves", "good for knowledge/utility content"),
    )
    lines_en = []
    for metric_name, note in reported:
        if metric_name not in cp.columns:
            continue
        values = cp[metric_name].dropna()
        if values.empty:
            continue
        top_cluster = int(values.idxmax())
        lines_en.append(
            f"- Highest **{metric_name}**: Cluster {top_cluster} ({note})."
        )
    if not lines_en:
        return "_No cluster insights available for the current selection._"
    return "\n".join(lines_en)
st.markdown("### Dynamic Insights")
st.markdown(insight_text(cluster_profile))

# Optional: Diagnostics
with st.expander("Model Diagnostics"):
    st.write(f"Using first **{k_opt} PCs** to reach ≥80% cumulative explained variance.")
    # Small curve of cumulative explained variance with the 80% target line
    fig2, ax2 = plt.subplots(figsize=(5, 3))
    try:
        ax2.plot(range(1, len(cum) + 1), cum, marker="o")
        ax2.axhline(0.80, color="r", linestyle="--")
        ax2.set_xlabel("PCs")
        ax2.set_ylabel("Cumulative explained variance")
        ax2.set_title("PCA Explained Variance")
        ax2.grid(axis="y", linestyle="--", alpha=0.3)
        st.pyplot(fig2, use_container_width=False)
    finally:
        # pyplot keeps every figure it creates alive until it is closed;
        # close it once rendered so repeated script runs do not accumulate
        # figures.
        plt.close(fig2)

st.markdown("---")
st.caption("Note: Clustering is fitted on full data (then filtered for display).")