sunbal7 committed on
Commit
395d96d
·
verified ·
1 Parent(s): 59e6f9b

Update app.py

Files changed (1)
  1. app.py +261 -200
app.py CHANGED
@@ -1,31 +1,35 @@
1
- import streamlit as st
2
- import cv2
3
  import tempfile
4
- import numpy as np
5
- from ultralytics import YOLO
6
- import plotly.graph_objects as go
7
  from collections import defaultdict
8
- import pandas as pd
9
- import os
10
 
11
- # Optional: yt-dlp for YouTube download
12
- import yt_dlp
 
 
 
 
 
13
 
14
- # --- Configuration & Initialization ---
 
 
 
 
 
15
 
16
- st.set_page_config(
17
- page_title="YOLOv8 Object Tracking & Counter",
18
- page_icon="🤖",
19
- layout="wide"
20
- )
21
 
22
  st.title("🚦 Smart Object Traffic Analyzer (YOLOv8)")
23
- st.markdown("""
24
- Process local files or YouTube links to track and count unique object crossings.
25
- Uses YOLOv8 detection with integrated ByteTrack tracking for robust multi-object analysis.
26
- """)
 
 
27
 
28
- # COCO classes (subset commonly used)
29
  COCO_CLASS_NAMES = {
30
  0: "person", 1: "bicycle", 2: "car", 3: "motorcycle", 4: "airplane",
31
  5: "bus", 6: "train", 7: "truck", 8: "boat", 9: "traffic light"
@@ -40,114 +44,147 @@ CLASS_MAPPING = {
40
  "Truck": 7,
41
  }
42
 
 
43
  if "processed_data" not in st.session_state:
44
  st.session_state.processed_data = {
45
  "total_counts": defaultdict(int),
46
  "frame_counts": [],
47
  "processed_video": None,
48
  "processing_complete": False,
49
- "tracked_objects": {}, # id -> {"class": str, "last_centroid": (x, y), "counted": bool}
50
  }
51
 
52
- # --- Sidebar ---
53
-
54
  with st.sidebar:
55
- st.header("⚙️ Configuration settings")
56
 
57
  st.subheader("Model & detection")
58
- model_name = st.selectbox(
59
- "Select YOLO model",
60
- options=["yolov8n.pt", "yolov8s.pt"],
61
- help="Nano (n) is fast; Small (s) is more accurate."
62
- )
63
 
64
- confidence = st.slider(
65
- "Detection confidence threshold",
66
- min_value=0.1, max_value=1.0, value=0.40, step=0.05,
67
- help="Minimum confidence to consider a detection valid."
68
- )
69
 
70
  st.subheader("Objects for counting")
71
  selected_classes_ui = {}
72
- for name, cid in CLASS_MAPPING.items():
73
  default_val = name in ["Person", "Car"]
74
  selected_classes_ui[name] = st.checkbox(name, value=default_val)
75
 
76
  st.subheader("Counting line settings")
77
  show_line = st.checkbox("Show crossing line", value=True)
78
- line_position = st.slider(
79
- "Line position (vertical % from left)",
80
- min_value=10, max_value=90, value=50,
81
- help="Place the vertical line at a percentage of frame width."
82
- )
83
 
84
  st.subheader("Performance options")
85
- process_every_nth = st.slider(
86
- "Frame skip (process every Nth frame)",
87
- min_value=1, max_value=10, value=2,
88
- help="Higher values speed up processing but reduce tracking smoothness."
89
- )
90
 
91
- max_frames = st.number_input(
92
- "Maximum frames to analyze",
93
- min_value=10, max_value=5000, value=500,
94
- help="Limit processing for long videos. Use large values for full videos."
95
- )
96
 
97
  # --- Helpers ---
98
 
99
  @st.cache_resource
100
  def load_model(model_path: str):
 
101
  return YOLO(model_path)
102
 
103
- def get_selected_class_ids():
104
- return [CLASS_MAPPING[name] for name, is_selected in selected_classes_ui.items() if is_selected]
 
 
 
105
 
106
  @st.cache_data
107
- def download_youtube_video(youtube_url: str) -> str | None:
108
  """
109
- Download a YouTube video using yt-dlp to a temporary MP4 file and return its path.
110
- Attempts to select an MP4-compatible format; falls back to best available.
111
  """
112
  try:
113
- # Create temp directory and output path pattern
114
  temp_dir = tempfile.mkdtemp()
115
  output_template = os.path.join(temp_dir, "video.%(ext)s")
116
 
117
- # Prefer MP4 H.264/AAC for broad compatibility
118
  ydl_opts = {
119
- "format": "best[ext=mp4]/best", # prefer mp4
120
  "outtmpl": output_template,
121
  "noplaylist": True,
122
  "quiet": True,
123
  "no_warnings": True,
124
- "retries": 3,
125
- "http_chunk_size": 10485760, # 10MB chunks for stability
126
  "merge_output_format": "mp4",
127
  }
128
 
129
  with yt_dlp.YoutubeDL(ydl_opts) as ydl:
130
  info = ydl.extract_info(youtube_url, download=True)
131
- # Resolve final filename
132
  filename = ydl.prepare_filename(info)
133
- # If not mp4, try merged file name with mp4
134
  if not filename.endswith(".mp4"):
135
  mp4_candidate = os.path.splitext(filename)[0] + ".mp4"
136
  if os.path.exists(mp4_candidate):
137
  filename = mp4_candidate
138
  if os.path.exists(filename):
139
- return filename
140
  else:
141
- return None
142
  except Exception as e:
143
- st.error(f"Failed to download YouTube video: {e}")
144
- return None
145
 
146
- # --- Core processing ---
147
 
148
- def process_video(video_path: str, selected_class_ids: list[int], model_path: str):
 
 
 
 
 
 
149
  model = load_model(model_path)
 
150
  cap = cv2.VideoCapture(video_path)
 
 
 
151
 
152
  fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
153
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 640
@@ -155,12 +192,12 @@ def process_video(video_path: str, selected_class_ids: list[int], model_path: st
155
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
156
 
157
  if total_frames > max_frames:
158
- st.warning(f"Video will be processed for the first {max_frames} frames only.")
159
 
160
  temp_output = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
161
  output_path = temp_output.name
162
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
163
- out = cv2.VideoWriter(output_path, fourcc, max(fps / process_every_nth, 1), (width, height))
164
 
165
  state = st.session_state.processed_data
166
  state["total_counts"] = defaultdict(int)
@@ -184,73 +221,94 @@ def process_video(video_path: str, selected_class_ids: list[int], model_path: st
184
  if frame_idx % process_every_nth != 0:
185
  continue
186
 
187
- results = model.track(
188
- frame,
189
- conf=confidence,
190
- classes=selected_class_ids if selected_class_ids else None,
191
- persist=True,
192
- tracker="bytetrack.yaml",
193
- verbose=False
194
- )
 
 
 
 
 
195
 
196
  annotated = frame.copy()
197
  frame_counts = defaultdict(int)
198
 
199
- if results and hasattr(results[0], "boxes") and results[0].boxes.id is not None:
200
- boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)
201
- track_ids = results[0].boxes.id.cpu().numpy().astype(int)
202
- class_ids = results[0].boxes.cls.cpu().numpy().astype(int)
203
-
204
- for box, tid, cid in zip(boxes, track_ids, class_ids):
205
- x1, y1, x2, y2 = box
206
- cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
207
- centroid = (cx, cy)
208
-
209
- cls_name = COCO_CLASS_NAMES.get(cid, "Unknown")
210
- frame_counts[cls_name.lower()] += 1
211
-
212
- if tid not in state["tracked_objects"]:
213
- state["tracked_objects"][tid] = {
214
- "class": cls_name,
215
- "last_centroid": centroid,
216
- "counted": False
217
- }
218
- else:
219
- obj = state["tracked_objects"][tid]
220
- prev_x = obj["last_centroid"][0]
221
- if not obj["counted"]:
222
- crossed_right = prev_x < line_x and cx >= line_x
223
- crossed_left = prev_x > line_x and cx <= line_x
224
- if crossed_right or crossed_left:
225
- state["total_counts"][cls_name] += 1
226
- obj["counted"] = True
227
- obj["last_centroid"] = centroid
228
-
229
- cv2.rectangle(annotated, (x1, y1), (x2, y2), (255, 0, 0), 2)
230
- cv2.circle(annotated, (cx, cy), 5, (0, 0, 255), -1)
231
- cv2.putText(
232
- annotated, f"ID:{tid} {cls_name}", (x1, max(10, y1 - 10)),
233
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2
234
- )
235
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  if show_line:
237
  line_color = (0, 255, 255)
238
  cv2.line(annotated, (line_x, 0), (line_x, height), line_color, 2)
239
- cv2.putText(
240
- annotated, "COUNTING LINE", (min(width - 180, line_x + 5), 20),
241
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, line_color, 2
242
- )
243
 
244
  y_offset = 30
245
  for obj_type, count in state["total_counts"].items():
246
- cv2.putText(
247
- annotated,
248
- f"TOTAL {obj_type.upper()}: {count}",
249
- (max(10, width - 320), y_offset),
250
- cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2
251
- )
252
  y_offset += 35
253
 
 
254
  frame_data = {"frame": processed_frames * process_every_nth}
255
  for name in CLASS_MAPPING.keys():
256
  frame_data[name.lower()] = frame_counts.get(name.lower(), 0)
@@ -261,7 +319,7 @@ def process_video(video_path: str, selected_class_ids: list[int], model_path: st
261
 
262
  progress = min(processed_frames / max_frames, 1.0)
263
  progress_bar.progress(progress)
264
- status_text.text(f"Analyzing Frame {frame_idx}/{total_frames or 'unknown'} (Processed {processed_frames})")
265
 
266
  cap.release()
267
  out.release()
@@ -272,40 +330,35 @@ def process_video(video_path: str, selected_class_ids: list[int], model_path: st
272
 
273
  return output_path
274
 
275
- # --- UI ---
276
 
 
277
  tab1, tab2, tab3 = st.tabs(["📹 Video input", "📊 Analysis & results", "ℹ️ Documentation"])
278
 
279
  with tab1:
280
  col1, col2 = st.columns(2)
281
- video_path = None
282
 
283
  with col1:
284
  st.subheader("📁 Upload video file")
285
- uploaded_file = st.file_uploader(
286
- "Choose a video file",
287
- type=["mp4", "avi", "mov", "mkv"],
288
- help="Supported formats. For large files, consider shorter clips."
289
- )
290
  if uploaded_file is not None:
291
  tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
292
  tfile.write(uploaded_file.getbuffer())
 
293
  video_path = tfile.name
294
  st.info(f"Video ready: {uploaded_file.name}")
295
  st.video(uploaded_file)
296
 
297
  with col2:
298
- st.subheader("🎥 Process a YouTube link")
299
- youtube_url = st.text_input(
300
- "Enter a YouTube video URL",
301
- placeholder="https://www.youtube.com/watch?v=..."
302
- )
303
- if st.button("⬇️ Download from YouTube", use_container_width=True) and youtube_url:
304
- st.info("Downloading YouTube video...")
305
- yt_path = download_youtube_video(youtube_url)
306
- if yt_path:
307
- video_path = yt_path
308
- st.success("Video downloaded and ready for processing.")
309
  try:
310
  cap = cv2.VideoCapture(video_path)
311
  ret, frame = cap.read()
@@ -315,7 +368,30 @@ with tab1:
315
  except Exception:
316
  st.warning("Could not display video preview.")
317
  else:
318
- st.error("Failed to download the YouTube video. Try another link or check your network.")
319
 
320
  st.markdown("---")
321
 
@@ -327,12 +403,15 @@ with tab1:
327
  else:
328
  try:
329
  with st.spinner(f"Analyzing video with {model_name}..."):
330
- process_video(video_path, selected_class_ids, model_name)
331
- st.success("Analysis complete! See results in the 'Analysis & results' tab.")
 
 
 
332
  except Exception as e:
333
  st.error(f"An error occurred during video processing: {e}")
334
  else:
335
- st.info("Upload a video or provide a YouTube link to begin.")
336
 
337
  with tab2:
338
  data = st.session_state.processed_data
@@ -343,17 +422,14 @@ with tab2:
343
 
344
  with col1:
345
  st.subheader("🎥 Analyzed video output")
346
- with open(data["processed_video"], "rb") as video_file:
347
- video_bytes = video_file.read()
348
- st.video(video_bytes)
349
-
350
- st.download_button(
351
- label="📥 Download annotated video (MP4)",
352
- data=video_bytes,
353
- file_name="analyzed_tracking_video.mp4",
354
- mime="video/mp4",
355
- use_container_width=True
356
- )
357
 
358
  with col2:
359
  st.subheader("✅ Object crossing totals")
@@ -366,58 +442,43 @@ with tab2:
366
  st.subheader("📊 Object presence over processed frames")
367
  if data["frame_counts"]:
368
  df = pd.DataFrame(data["frame_counts"]).fillna(0)
369
-
370
  fig = go.Figure()
371
  for column in df.columns:
372
  if column != "frame":
373
- fig.add_trace(go.Scatter(
374
- x=df["frame"],
375
- y=df[column],
376
- name=column.capitalize(),
377
- mode="lines+markers"
378
- ))
379
-
380
- fig.update_layout(
381
- title="Count of objects present per processed frame",
382
- xaxis_title="Frame number (processed frames)",
383
- yaxis_title="Instance count",
384
- hovermode="x unified",
385
- height=400
386
- )
387
  st.plotly_chart(fig, use_container_width=True)
388
 
389
  st.subheader("Data export")
390
  st.dataframe(df.tail(10), use_container_width=True, height=200)
391
-
392
  csv = df.to_csv(index=False).encode("utf-8")
393
- st.download_button(
394
- label="⬇️ Download frame-by-frame data (CSV)",
395
- data=csv,
396
- file_name="object_count_data.csv",
397
- mime="text/csv",
398
- )
399
  else:
400
  st.warning("No tracking data available. Process a video first.")
401
  else:
402
  st.info("Process a video in the 'Video input' tab to view analysis results.")
403
 
404
  with tab3:
405
- st.header("Documentation: Smart Object Traffic Analyzer")
406
- st.markdown("""
407
- This app supports local videos and YouTube links. It uses YOLOv8 detection and ByteTrack tracking to count unique crossings across a vertical line.
408
-
409
- ### 🔑 Core technology
410
- - YOLOv8 for detection (`yolov8n.pt` for speed, `yolov8s.pt` for accuracy).
411
- - ByteTrack via Ultralytics for robust multi-object tracking and persistent IDs.
412
- - Streamlit front-end for interactive, shareable demos.
413
- - yt-dlp for downloading YouTube sources.
414
-
415
- ### ⚙️ Counting logic
416
- - Persistent Track IDs across frames.
417
- - Centroid-based crossing detection against a vertical line.
418
- - Each object is counted once when crossing from either side.
419
-
420
- ### 🚀 Tips for Spaces
421
- - If runtime is limited, reduce `max_frames` or increase `process_every_nth`.
422
- - Choose `yolov8n.pt` for faster turnaround.
423
- """)
 
 
1
+ import os
 
2
  import tempfile
 
 
 
3
  from collections import defaultdict
4
+ from typing import Optional, Tuple, List
 
5
 
6
+ import cv2
7
+ import numpy as np
8
+ import pandas as pd
9
+ import plotly.graph_objects as go
10
+ import requests
11
+ import streamlit as st
12
+ from ultralytics import YOLO
13
 
14
+ # Try to import yt_dlp; if it is unavailable, show a helpful message when the user tries a YouTube link
15
+ try:
16
+ import yt_dlp # type: ignore
17
+ _YT_DLP_AVAILABLE = True
18
+ except Exception:
19
+ _YT_DLP_AVAILABLE = False
20
 
21
+ # --- Page config ---
22
+ st.set_page_config(page_title="YOLOv8 Object Tracking & Counter", page_icon="🤖", layout="wide")
 
 
 
23
 
24
  st.title("🚦 Smart Object Traffic Analyzer (YOLOv8)")
25
+ st.markdown(
26
+ """
27
+ Process local videos, direct public video URLs, or YouTube links to track and count unique object crossings.
28
+ Uses YOLOv8 detection and ByteTrack (when available) for robust multi-object tracking.
29
+ """
30
+ )
31
 
32
+ # --- Class mappings (subset of COCO) ---
33
  COCO_CLASS_NAMES = {
34
  0: "person", 1: "bicycle", 2: "car", 3: "motorcycle", 4: "airplane",
35
  5: "bus", 6: "train", 7: "truck", 8: "boat", 9: "traffic light"
 
44
  "Truck": 7,
45
  }
46
 
47
+ # --- Session state initialization ---
48
  if "processed_data" not in st.session_state:
49
  st.session_state.processed_data = {
50
  "total_counts": defaultdict(int),
51
  "frame_counts": [],
52
  "processed_video": None,
53
  "processing_complete": False,
54
+ "tracked_objects": {},
55
  }
56
 
57
+ # --- Sidebar: configuration ---
 
58
  with st.sidebar:
59
+ st.header("⚙️ Configuration")
60
 
61
  st.subheader("Model & detection")
62
+ model_name = st.selectbox("Select YOLO model", options=["yolov8n.pt", "yolov8s.pt"],
63
+ help="Nano (n) is fast; Small (s) is more accurate.")
 
 
 
64
 
65
+ confidence = st.slider("Detection confidence threshold", min_value=0.1, max_value=1.0,
66
+ value=0.40, step=0.05, help="Minimum confidence to consider a detection valid.")
 
 
 
67
 
68
  st.subheader("Objects for counting")
69
  selected_classes_ui = {}
70
+ for name in CLASS_MAPPING.keys():
71
  default_val = name in ["Person", "Car"]
72
  selected_classes_ui[name] = st.checkbox(name, value=default_val)
73
 
74
  st.subheader("Counting line settings")
75
  show_line = st.checkbox("Show crossing line", value=True)
76
+ line_position = st.slider("Line position (vertical % from left)", min_value=10, max_value=90, value=50,
77
+ help="Place the vertical counting line as a percentage of frame width.")
 
 
 
78
 
79
  st.subheader("Performance options")
80
+ process_every_nth = st.slider("Frame skip (process every Nth frame)", min_value=1, max_value=10, value=2,
81
+ help="Higher values speed up processing but reduce tracking smoothness.")
 
 
 
82
 
83
+ max_frames = st.number_input("Maximum frames to analyze", min_value=10, max_value=5000, value=500,
84
+ help="Limit processing for long videos. Increase for full videos.")
 
 
 
85
 
86
  # --- Helpers ---
87
 
88
  @st.cache_resource
89
  def load_model(model_path: str):
90
+ """Load and cache YOLO model."""
91
  return YOLO(model_path)
92
 
93
+
94
+ def get_selected_class_ids() -> List[int]:
95
+ """Return list of selected COCO class IDs."""
96
+ return [CLASS_MAPPING[name] for name, selected in selected_classes_ui.items() if selected]
97
+
98
 
99
  @st.cache_data
100
+ def download_direct_url(url: str, timeout: int = 30) -> Tuple[Optional[str], Optional[str]]:
101
  """
102
+ Download a direct video URL (mp4/mov/etc.) to a temporary file.
103
+ Returns (file_path, error_message). On success error_message is None.
104
  """
105
  try:
106
+ resp = requests.get(url, stream=True, timeout=timeout)
107
+ resp.raise_for_status()
108
+
109
+ # Write to a .mp4 temp file; OpenCV's FFmpeg backend probes the actual container from the data
110
+ suffix = ".mp4"
111
+
112
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
113
+ for chunk in resp.iter_content(chunk_size=8192):
114
+ if not chunk:
115
+ continue
116
+ temp_file.write(chunk)
117
+ temp_file.close()
118
+ return temp_file.name, None
119
+ except requests.exceptions.RequestException as e:
120
+ return None, f"Failed to download direct URL: {e}. Check the URL and network access."
121
+ except Exception as e:
122
+ return None, f"Unexpected error while downloading direct URL: {e}"
123
+
124
+
125
+ @st.cache_data
126
+ def download_youtube_video(youtube_url: str) -> Tuple[Optional[str], Optional[str]]:
127
+ """
128
+ Attempt to download a YouTube video using yt-dlp.
129
+ Returns (file_path, error_message). If download succeeds, error_message is None.
130
+ """
131
+ if not _YT_DLP_AVAILABLE:
132
+ return None, "yt-dlp is not available in this environment. Install yt-dlp or use a direct URL / upload."
133
+
134
+ try:
135
  temp_dir = tempfile.mkdtemp()
136
  output_template = os.path.join(temp_dir, "video.%(ext)s")
137
 
 
138
  ydl_opts = {
139
+ "format": "best[ext=mp4]/best",
140
  "outtmpl": output_template,
141
  "noplaylist": True,
142
  "quiet": True,
143
  "no_warnings": True,
144
+ "retries": 2,
 
145
  "merge_output_format": "mp4",
146
  }
147
 
148
  with yt_dlp.YoutubeDL(ydl_opts) as ydl:
149
  info = ydl.extract_info(youtube_url, download=True)
 
150
  filename = ydl.prepare_filename(info)
151
+ # prefer .mp4 if merged
152
  if not filename.endswith(".mp4"):
153
  mp4_candidate = os.path.splitext(filename)[0] + ".mp4"
154
  if os.path.exists(mp4_candidate):
155
  filename = mp4_candidate
156
  if os.path.exists(filename):
157
+ return filename, None
158
  else:
159
+ return None, "Download completed but output file not found."
160
+ except yt_dlp.utils.DownloadError as e:
161
+ # Likely network or availability issue
162
+ guidance = (
163
+ "yt-dlp failed to download the YouTube video. This can happen if the runtime has no outbound network access "
164
+ "or YouTube is blocked. Alternatives:\n"
165
+ "β€’ Upload the video file directly using the uploader.\n"
166
+ "β€’ Provide a direct public MP4 URL (use the Direct URL option).\n"
167
+ "β€’ Host the video in the Space repository or on the Hugging Face Hub and provide the path.\n"
168
+ "β€’ Run the app locally where internet access is available."
169
+ )
170
+ return None, f"{e}\n\n{guidance}"
171
  except Exception as e:
172
+ return None, f"Unexpected error while downloading YouTube video: {e}"
 
173
 
 
174
 
175
+ # --- Core processing function ---
176
+
177
+ def process_video(video_path: str, selected_class_ids: List[int], model_path: str) -> Optional[str]:
178
+ """
179
+ Process the video, perform detection + tracking, count crossings, and write an annotated output video.
180
+ Returns path to annotated video on success, otherwise None.
181
+ """
182
  model = load_model(model_path)
183
+
184
  cap = cv2.VideoCapture(video_path)
185
+ if not cap.isOpened():
186
+ st.error("Could not open the video file. The file may be corrupted or in an unsupported format.")
187
+ return None
188
 
189
  fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
190
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 640
 
192
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
193
 
194
  if total_frames > max_frames:
195
+ st.warning(f"Video will be processed for the first {max_frames} frames only (sidebar setting).")
196
 
197
  temp_output = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
198
  output_path = temp_output.name
199
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
200
+ out = cv2.VideoWriter(output_path, fourcc, max(int(fps / process_every_nth), 1), (width, height))
201
 
202
  state = st.session_state.processed_data
203
  state["total_counts"] = defaultdict(int)
 
221
  if frame_idx % process_every_nth != 0:
222
  continue
223
 
224
+ # Run YOLOv8 tracking (ByteTrack if available in ultralytics)
225
+ try:
226
+ results = model.track(
227
+ frame,
228
+ conf=confidence,
229
+ classes=selected_class_ids if selected_class_ids else None,
230
+ persist=True,
231
+ tracker="bytetrack.yaml",
232
+ verbose=False
233
+ )
234
+ except Exception:
235
+ # Fallback to detection-only if tracker config not available
236
+ results = model(frame, conf=confidence, classes=selected_class_ids if selected_class_ids else None)
237
 
238
  annotated = frame.copy()
239
  frame_counts = defaultdict(int)
240
 
241
+ # Parse results (works for both track and detect outputs)
242
+ if results and hasattr(results[0], "boxes"):
243
+ boxes_obj = results[0].boxes
244
+ # Some detect-only outputs may not have ids
245
+ ids_attr = getattr(boxes_obj, "id", None)
246
+ try:
247
+ boxes = boxes_obj.xyxy.cpu().numpy().astype(int)
248
+ class_ids = boxes_obj.cls.cpu().numpy().astype(int)
249
+ except Exception:
250
+ boxes = []
251
+ class_ids = []
252
+
253
+ ids = None
254
+ if ids_attr is not None:
255
+ try:
256
+ ids = ids_attr.cpu().numpy().astype(int)
257
+ except Exception:
258
+ ids = None
259
+
260
+ if len(boxes) > 0:
261
+ for i, box in enumerate(boxes):
262
+ x1, y1, x2, y2 = box
263
+ cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
264
+ cls_id = int(class_ids[i]) if i < len(class_ids) else -1
265
+ cls_name = COCO_CLASS_NAMES.get(cls_id, "Unknown")
266
+ frame_counts[cls_name.lower()] += 1
267
+
268
+ track_id = int(ids[i]) if (ids is not None and i < len(ids)) else None
269
+
270
+ if track_id is None:
271
+ # Use a synthetic id based on bbox and frame to avoid counting duplicates across frames
272
+ track_id = hash((x1, y1, x2, y2, frame_idx)) & 0x7FFFFFFF
273
+
274
+ if track_id not in state["tracked_objects"]:
275
+ state["tracked_objects"][track_id] = {
276
+ "class": cls_name,
277
+ "last_centroid": (cx, cy),
278
+ "counted": False
279
+ }
280
+ else:
281
+ obj = state["tracked_objects"][track_id]
282
+ prev_x = obj["last_centroid"][0]
283
+ if not obj["counted"]:
284
+ crossed_right = prev_x < line_x and cx >= line_x
285
+ crossed_left = prev_x > line_x and cx <= line_x
286
+ if crossed_right or crossed_left:
287
+ state["total_counts"][cls_name] += 1
288
+ obj["counted"] = True
289
+ obj["last_centroid"] = (cx, cy)
290
+
291
+ # Draw annotations
292
+ cv2.rectangle(annotated, (x1, y1), (x2, y2), (255, 0, 0), 2)
293
+ cv2.circle(annotated, (cx, cy), 5, (0, 0, 255), -1)
294
+ label = f"ID:{track_id} {cls_name}"
295
+ cv2.putText(annotated, label, (x1, max(10, y1 - 10)),
296
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
297
+
298
+ # Draw counting line and totals
299
  if show_line:
300
  line_color = (0, 255, 255)
301
  cv2.line(annotated, (line_x, 0), (line_x, height), line_color, 2)
302
+ cv2.putText(annotated, "COUNTING LINE", (min(width - 180, line_x + 5), 20),
303
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, line_color, 2)
 
 
304
 
305
  y_offset = 30
306
  for obj_type, count in state["total_counts"].items():
307
+ cv2.putText(annotated, f"TOTAL {obj_type.upper()}: {count}", (max(10, width - 320), y_offset),
308
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
 
 
 
 
309
  y_offset += 35
310
 
311
+ # Save frame counts
312
  frame_data = {"frame": processed_frames * process_every_nth}
313
  for name in CLASS_MAPPING.keys():
314
  frame_data[name.lower()] = frame_counts.get(name.lower(), 0)
 
319
 
320
  progress = min(processed_frames / max_frames, 1.0)
321
  progress_bar.progress(progress)
322
+ status_text.text(f"Analyzing frame {frame_idx}/{total_frames or 'unknown'} (Processed {processed_frames})")
323
 
324
  cap.release()
325
  out.release()
 
330
 
331
  return output_path
332
 
 
333
 
334
+ # --- UI layout: tabs ---
335
  tab1, tab2, tab3 = st.tabs(["📹 Video input", "📊 Analysis & results", "ℹ️ Documentation"])
336
 
337
  with tab1:
338
  col1, col2 = st.columns(2)
339
+ video_path: Optional[str] = None
340
 
341
  with col1:
342
  st.subheader("📁 Upload video file")
343
+ uploaded_file = st.file_uploader("Choose a video file", type=["mp4", "avi", "mov", "mkv"],
344
+ help="Supported formats. For large files, consider shorter clips.")
 
 
 
345
  if uploaded_file is not None:
346
  tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
347
  tfile.write(uploaded_file.getbuffer())
348
+ tfile.close()
349
  video_path = tfile.name
350
  st.info(f"Video ready: {uploaded_file.name}")
351
  st.video(uploaded_file)
352
 
353
  with col2:
354
+ st.subheader("🌐 Direct public video URL")
355
+ direct_url = st.text_input("Enter a direct public video URL (e.g., .mp4)", placeholder="https://example.com/video.mp4")
356
+ if st.button("⬇️ Download from URL", use_container_width=True) and direct_url:
357
+ st.info("Attempting to download the direct video URL...")
358
+ path, err = download_direct_url(direct_url)
359
+ if path:
360
+ video_path = path
361
+ st.success("Direct URL downloaded and ready for processing.")
 
 
 
362
  try:
363
  cap = cv2.VideoCapture(video_path)
364
  ret, frame = cap.read()
 
368
  except Exception:
369
  st.warning("Could not display video preview.")
370
  else:
371
+ st.error(err)
372
+
373
+ st.markdown("---")
374
+ st.subheader("🎥 YouTube link (optional)")
375
+ youtube_url = st.text_input("Enter a YouTube video URL", placeholder="https://www.youtube.com/watch?v=...")
376
+ if st.button("⬇️ Download from YouTube", use_container_width=True) and youtube_url:
377
+ if not _YT_DLP_AVAILABLE:
378
+ st.error("yt-dlp is not installed in this environment. Use a direct URL or upload the file.")
379
+ else:
380
+ st.info("Attempting to download YouTube video...")
381
+ path, err = download_youtube_video(youtube_url)
382
+ if path:
383
+ video_path = path
384
+ st.success("YouTube video downloaded and ready for processing.")
385
+ try:
386
+ cap = cv2.VideoCapture(video_path)
387
+ ret, frame = cap.read()
388
+ if ret:
389
+ st.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), caption="Video preview", use_column_width=True)
390
+ cap.release()
391
+ except Exception:
392
+ st.warning("Could not display video preview.")
393
+ else:
394
+ st.error(err)
395
 
396
  st.markdown("---")
397
 
 
403
  else:
404
  try:
405
  with st.spinner(f"Analyzing video with {model_name}..."):
406
+ out_path = process_video(video_path, selected_class_ids, model_name)
407
+ if out_path:
408
+ st.success("Analysis complete! See results in the 'Analysis & results' tab.")
409
+ else:
410
+ st.error("Processing failed. Check the logs and input file.")
411
  except Exception as e:
412
  st.error(f"An error occurred during video processing: {e}")
413
  else:
414
+ st.info("Upload a video, provide a direct URL, or a YouTube link to begin.")
415
 
416
  with tab2:
417
  data = st.session_state.processed_data
 
422
 
423
  with col1:
424
  st.subheader("🎥 Analyzed video output")
425
+ try:
426
+ with open(data["processed_video"], "rb") as video_file:
427
+ video_bytes = video_file.read()
428
+ st.video(video_bytes)
429
+ st.download_button(label="📥 Download annotated video (MP4)", data=video_bytes,
430
+ file_name="analyzed_tracking_video.mp4", mime="video/mp4", use_container_width=True)
431
+ except Exception:
432
+ st.error("Could not load the processed video file.")
 
 
 
433
 
434
  with col2:
435
  st.subheader("✅ Object crossing totals")
 
442
  st.subheader("📊 Object presence over processed frames")
443
  if data["frame_counts"]:
444
  df = pd.DataFrame(data["frame_counts"]).fillna(0)
 
445
  fig = go.Figure()
446
  for column in df.columns:
447
  if column != "frame":
448
+ fig.add_trace(go.Scatter(x=df["frame"], y=df[column], name=column.capitalize(), mode="lines+markers"))
449
+ fig.update_layout(title="Count of objects present per processed frame",
450
+ xaxis_title="Frame number (processed frames)",
451
+ yaxis_title="Instance count", hovermode="x unified", height=400)
452
  st.plotly_chart(fig, use_container_width=True)
453
 
454
  st.subheader("Data export")
455
  st.dataframe(df.tail(10), use_container_width=True, height=200)
 
456
  csv = df.to_csv(index=False).encode("utf-8")
457
+ st.download_button(label="⬇️ Download frame-by-frame data (CSV)", data=csv,
458
+ file_name="object_count_data.csv", mime="text/csv")
 
 
 
 
459
  else:
460
  st.warning("No tracking data available. Process a video first.")
461
  else:
462
  st.info("Process a video in the 'Video input' tab to view analysis results.")
463
 
464
  with tab3:
465
+ st.header("Documentation & Notes")
466
+ st.markdown(
467
+ """
468
+ **Supported inputs**
469
+ - Local upload (recommended for Spaces demos).
470
+ - Direct public video URL (MP4 preferred).
471
+ - YouTube link (requires `yt-dlp` and outbound network access).
472
+
473
+ **Why YouTube downloads may fail in Spaces**
474
+ Hugging Face Spaces may restrict outbound network access or DNS resolution. If YouTube download fails, use a direct URL or upload the file. Running the app locally will allow YouTube downloads if your machine has internet access.
475
+
476
+ **Performance tips**
477
+ - Use `yolov8n.pt` for faster processing.
478
+ - Increase `Frame skip` (process every Nth frame) to speed up long videos.
479
+ - Reduce `Maximum frames` for quick demos.
480
+
481
+ **System packages**
482
+ This app uses `opencv-python-headless` to avoid GUI dependencies. You generally do not need a `setup.sh` that installs `libgl1-mesa-glx` or `libglib2.0-0`. Remove `setup.sh` unless you switch to non-headless OpenCV or require specific system libraries.
483
+ """
484
+ )