huohuobeixiaosile commited on
Commit
0f91209
·
verified ·
1 Parent(s): 77dd6fb

Upload W3-assignment-streamlit.py

Browse files
Files changed (1) hide show
  1. W3-assignment-streamlit.py +194 -0
W3-assignment-streamlit.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Assume our stakeholder is a brand advertiser who wants to understand short-video audience segments.
2
+ # The clustering analysis helps identify which audience groups are best suited for different types of
3
+ # campaigns (broad awareness vs. engagement-driven vs. knowledge-sharing).
4
+ import streamlit as st
5
+ import pandas as pd
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+
9
+ from sklearn.impute import SimpleImputer
10
+ from sklearn.preprocessing import StandardScaler
11
+ from sklearn.decomposition import PCA
12
+ from sklearn.cluster import KMeans
13
+ from sklearn.metrics import silhouette_score
14
+
15
+ st.set_page_config(page_title="Short-Video Audience Segments", layout="wide")
16
+
17
+ # Sidebar — Controls
18
+ st.sidebar.title("Controls")
19
+
20
+ # Set CSV path here
21
+ CSV_PATH = "data/youtube_shorts_tiktok_trends_2025.csv"
22
+ df_raw = pd.read_csv(CSV_PATH, low_memory=False)
23
+
24
+ # Choose number of clusters (K)
25
+ k = st.sidebar.slider("Number of clusters (K)", min_value=2, max_value=8, value=4, step=1)
26
+
27
+ # Metric to compare across clusters
28
+ core_metrics = ["views","likes","comments","shares","saves","engagement_rate"]
29
+ metric = st.sidebar.selectbox("Metric to compare ", core_metrics, index=5)
30
+
31
+ # Optional filters for display (do not refit model)
32
+ platform_filter_on = st.sidebar.checkbox("Filter by platform ", value=False)
33
+ region_filter_on = st.sidebar.checkbox("Filter by region ", value=False)
34
+
35
+ # Data loading & preprocessing
36
+ @st.cache_data(show_spinner=True)
37
+ def load_data(path):
38
+ df = pd.read_csv(path, low_memory=False)
39
+ return df
40
+
41
+ @st.cache_data(show_spinner=True)
42
+ def build_features(df_raw: pd.DataFrame):
43
+ # Select useful columns for engagement & context
44
+ use_cols = [
45
+ "platform","region","category","sound_type",
46
+ "week_of_year","duration_sec",
47
+ "views","likes","comments","shares","saves",
48
+ "engagement_rate","engagement_share_rate"
49
+ ]
50
+ df = df_raw[use_cols].copy()
51
+ # Basic cleaning: keep valid duration & views
52
+ df = df[(df["duration_sec"] > 0) & (df["views"] > 0)]
53
+ # Construct per-view rates to offset scale bias(构造每次观看的比率,缓解规模差异)
54
+ for col in ["likes","comments","shares","saves"]:
55
+ df[f"{col}_rate"] = (df[col] / df["views"].clip(lower=1)).astype(float)
56
+
57
+ # Log-transform heavy-tailed count features
58
+ for col in ["views","likes","comments","shares","saves","duration_sec"]:
59
+ df[f"log_{col}"] = np.log1p(df[col])
60
+ # Week-of-year cyclic encoding
61
+ df["week_sin"] = np.sin(2*np.pi*df["week_of_year"]/52.0)
62
+ df["week_cos"] = np.cos(2*np.pi*df["week_of_year"]/52.0)
63
+
64
+ # Numeric + Categorical features
65
+ num_feats = [
66
+ "log_views","log_likes","log_comments","log_shares","log_saves",
67
+ "log_duration_sec",
68
+ "likes_rate","comments_rate","shares_rate","saves_rate",
69
+ "engagement_rate","engagement_share_rate",
70
+ "week_sin","week_cos"
71
+ ]
72
+ cat_feats = ["platform","region","category","sound_type"]
73
+
74
+ # One-hot for categoricals (drop_first to avoid perfect collinearity)
75
+ df_model = pd.get_dummies(df[num_feats + cat_feats], drop_first=True).astype(float)
76
+ feature_names = df_model.columns.tolist()
77
+
78
+ # Impute & Scale
79
+ imp = SimpleImputer(strategy="median")
80
+ X_num = imp.fit_transform(df_model.values)
81
+ scaler = StandardScaler()
82
+ X_scaled = scaler.fit_transform(X_num)
83
+
84
+ return df, df_model, feature_names, X_scaled
85
+
86
+ @st.cache_resource(show_spinner=True)
87
+ def fit_pca(X_scaled: np.ndarray, var_threshold: float = 0.80):
88
+ pca = PCA(n_components=None, random_state=42)
89
+ X_pca = pca.fit_transform(X_scaled)
90
+ exp = pca.explained_variance_ratio_
91
+ cum = np.cumsum(exp)
92
+ k_opt = int(np.argmax(cum >= var_threshold)) + 1
93
+ return pca, X_pca, exp, cum, k_opt
94
+
95
+ @st.cache_resource(show_spinner=True)
96
+ def fit_kmeans(X_embed: np.ndarray, n_clusters: int):
97
+ km = KMeans(n_clusters=n_clusters, random_state=42, n_init=20)
98
+ labels = km.fit_predict(X_embed)
99
+ try:
100
+ sil = silhouette_score(X_embed, labels)
101
+ except Exception:
102
+ sil = np.nan
103
+ return km, labels, sil
104
+
105
+ # Load & Prep
106
+ df_raw = load_data(csv_path)
107
+ df_clean, df_model, feature_names, X_scaled = build_features(df_raw)
108
+ pca, X_pca, exp, cum, k_opt = fit_pca(X_scaled, var_threshold=0.80)
109
+
110
+ # Use first k_opt PCs for clustering
111
+ X_k = X_pca[:, :k_opt]
112
+
113
+ # Fit KMeans with selected k
114
+ km, labels, sil = fit_kmeans(X_k, n_clusters=k)
115
+ df_show = df_clean.copy()
116
+ df_show["cluster"] = labels
117
+
118
+ # Optional display filters (platform/region)
119
+ if platform_filter_on:
120
+ platforms = ["(All)"] + sorted(df_show["platform"].dropna().unique().tolist())
121
+ chosen_platform = st.sidebar.selectbox("Platform", platforms, index=0)
122
+ if chosen_platform != "(All)":
123
+ df_show = df_show[df_show["platform"] == chosen_platform]
124
+
125
+ if region_filter_on:
126
+ regions = ["(All)"] + sorted(df_show["region"].dropna().unique().tolist())
127
+ chosen_region = st.sidebar.selectbox("Region", regions, index=0)
128
+ if chosen_region != "(All)":
129
+ df_show = df_show[df_show["region"] == chosen_region]
130
+
131
+ # Header — Business framing
132
+ st.title("Short-Video Audience Segments (TikTok & YouTube Shorts)")
133
+ st.caption("We identify audience segments via PCA + KMeans to support content and ad strategy.")
134
+
135
+ colA, colB, colC = st.columns(3)
136
+ with colA:
137
+ st.metric("Rows used", f"{len(df_show):,}")
138
+ with colB:
139
+ st.metric("K (clusters)", k)
140
+ with colC:
141
+ st.metric("Silhouette", f"{sil:.3f}" if not np.isnan(sil) else "—")
142
+
143
+ st.markdown("---")
144
+
145
+ # Cluster profile table
146
+ cluster_profile = (
147
+ df_show.groupby("cluster")[core_metrics]
148
+ .mean()
149
+ .sort_index()
150
+ )
151
+
152
+ st.subheader("Cluster Profiles — Mean Metrics")
153
+ st.dataframe(cluster_profile.style.format("{:,.2f}"))
154
+
155
+ # Chart: Compare chosen metric across clusters
156
+ st.subheader(f"Compare clusters by: **{metric}**")
157
+
158
+ fig, ax = plt.subplots(figsize=(7,4))
159
+ cluster_profile[metric].plot(kind="bar", ax=ax)
160
+ ax.set_xlabel("Cluster")
161
+ ax.set_ylabel(metric)
162
+ ax.set_title(f"{metric} by Cluster")
163
+ ax.grid(axis="y", linestyle="--", alpha=0.3)
164
+ st.pyplot(fig, use_container_width=True)
165
+
166
+ # Dynamic insights
167
+ def insight_text(cp: pd.DataFrame):
168
+ # Find clusters with max metrics
169
+ tops = {m: int(cp[m].idxmax()) for m in cp.columns}
170
+ lines_en = [
171
+ f"- Highest **engagement_rate**: Cluster {tops['engagement_rate']} (best for community/interaction).",
172
+ f"- Highest **views**: Cluster {tops['views']} (best for broad awareness).",
173
+ f"- Highest **saves**: Cluster {tops['saves']} (good for knowledge/utility content).",
174
+ ]
175
+ return "\n".join(lines_en)
176
+
177
+ st.markdown("### Dynamic Insights")
178
+ st.markdown(insight_text(cluster_profile))
179
+
180
+ # Optional: Diagnostics
181
+ with st.expander("Model Diagnostics"):
182
+ st.write(f"Using first **{k_opt} PCs** to reach ≥80% cumulative explained variance.")
183
+ # small curve
184
+ fig2, ax2 = plt.subplots(figsize=(5,3))
185
+ ax2.plot(range(1, len(cum)+1), cum, marker="o")
186
+ ax2.axhline(0.80, color="r", linestyle="--")
187
+ ax2.set_xlabel("PCs")
188
+ ax2.set_ylabel("Cumulative explained variance")
189
+ ax2.set_title("PCA Explained Variance")
190
+ ax2.grid(axis="y", linestyle="--", alpha=0.3)
191
+ st.pyplot(fig2, use_container_width=False)
192
+
193
+ st.markdown("---")
194
+ st.caption("Note: Clustering is fitted on full data (then filtered for display).")