Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- __pycache__/app.cpython-314.pyc +0 -0
- app.py +399 -406
- audio/__init__.py +8 -0
- audio/__pycache__/__init__.cpython-314.pyc +0 -0
- audio/__pycache__/augmentation.cpython-314.pyc +0 -0
- audio/__pycache__/encoder.cpython-314.pyc +0 -0
- audio/__pycache__/preprocessor.cpython-314.pyc +0 -0
- audio/__pycache__/sam_audio.cpython-314.pyc +0 -0
- audio/augmentation.py +464 -0
- audio/encoder.py +424 -0
- audio/preprocessor.py +359 -0
- audio/sam_audio.py +480 -0
- data/__init__.py +6 -0
- data/__pycache__/__init__.cpython-314.pyc +0 -0
- data/__pycache__/species_db.cpython-314.pyc +0 -0
- data/species_db.py +582 -0
- llm/__init__.py +7 -0
- llm/__pycache__/__init__.cpython-314.pyc +0 -0
- llm/__pycache__/ollama_client.cpython-314.pyc +0 -0
- llm/__pycache__/reasoning.cpython-314.pyc +0 -0
- llm/__pycache__/zero_shot_identifier.cpython-314.pyc +0 -0
- llm/ollama_client.py +254 -0
- llm/reasoning.py +405 -0
- llm/zero_shot_identifier.py +457 -0
- models/__init__.py +7 -0
- models/__pycache__/__init__.cpython-314.pyc +0 -0
- models/__pycache__/audio_classifier.cpython-314.pyc +0 -0
- models/__pycache__/novelty_detector.cpython-314.pyc +0 -0
- models/audio_classifier.py +307 -0
- models/novelty_detector.py +334 -0
- requirements.txt +8 -4
__pycache__/app.cpython-314.pyc
ADDED
|
Binary file (27.7 kB). View file
|
|
|
app.py
CHANGED
|
@@ -1,17 +1,11 @@
|
|
| 1 |
"""
|
| 2 |
🐦 BirdSense Pro - AI Bird Identification
|
| 3 |
-
Uses LOCAL Ollama LLM for TRUE zero-shot identification
|
| 4 |
|
| 5 |
-
|
| 6 |
-
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
1. Audio → LLM Analysis → Bird ID (zero-shot, 10,000+ species)
|
| 11 |
-
2. Image → LLM Vision → Bird ID
|
| 12 |
-
3. Description → LLM → Bird ID
|
| 13 |
-
4. Streaming responses
|
| 14 |
-
5. Multi-bird detection
|
| 15 |
|
| 16 |
CSCR Initiative
|
| 17 |
"""
|
|
@@ -21,23 +15,154 @@ import numpy as np
|
|
| 21 |
import scipy.signal as signal
|
| 22 |
from scipy.ndimage import gaussian_filter1d
|
| 23 |
from dataclasses import dataclass
|
| 24 |
-
from typing import Optional, Tuple,
|
| 25 |
import json
|
| 26 |
-
import os
|
| 27 |
import requests
|
| 28 |
-
import time
|
| 29 |
|
| 30 |
# ================== CONFIG ==================
|
| 31 |
SAMPLE_RATE = 48000
|
| 32 |
-
|
| 33 |
-
# Ollama configuration (LOCAL - primary)
|
| 34 |
OLLAMA_URL = "http://localhost:11434"
|
| 35 |
-
OLLAMA_MODEL = "qwen2.5:3b"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
#
|
| 38 |
-
|
| 39 |
|
| 40 |
-
|
|
|
|
| 41 |
BIRD_IMAGES = {
|
| 42 |
"Asian Koel": "https://upload.wikimedia.org/wikipedia/commons/thumb/7/78/Eudynamys_scolopaceus_-_Koel_male_-_Sukhna_Lake%2C_India.jpg/320px-Eudynamys_scolopaceus_-_Koel_male_-_Sukhna_Lake%2C_India.jpg",
|
| 43 |
"Indian Cuckoo": "https://upload.wikimedia.org/wikipedia/commons/thumb/6/6b/Cuculus_micropterus.jpg/320px-Cuculus_micropterus.jpg",
|
|
@@ -56,196 +181,126 @@ BIRD_IMAGES = {
|
|
| 56 |
"Greater Coucal": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d6/Greater_Coucal_%28Centropus_sinensis%29_in_Hyderabad%2C_AP_W_IMG_7544.jpg/320px-Greater_Coucal_%28Centropus_sinensis%29_in_Hyderabad%2C_AP_W_IMG_7544.jpg",
|
| 57 |
"Common Tailorbird": "https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Common_Tailorbird_%28Orthotomus_sutorius%29_in_Kolkata_I_IMG_2859.jpg/320px-Common_Tailorbird_%28Orthotomus_sutorius%29_in_Kolkata_I_IMG_2859.jpg",
|
| 58 |
"Green Bee-eater": "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b1/Merops_orientalis_%28Pune%2C_India%29.jpg/320px-Merops_orientalis_%28Pune%2C_India%29.jpg",
|
| 59 |
-
"Common Hawk-Cuckoo": "https://upload.wikimedia.org/wikipedia/commons/thumb/0/08/Hierococcyx_varius.jpg/320px-Hierococcyx_varius.jpg",
|
| 60 |
-
"Indian Robin": "https://upload.wikimedia.org/wikipedia/commons/thumb/6/6e/Indian_Robin_%28Saxicoloides_fulicatus%29_Male.jpg/320px-Indian_Robin_%28Saxicoloides_fulicatus%29_Male.jpg",
|
| 61 |
-
"Grey Francolin": "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8c/Grey_francolin_%28Francolinus_pondicerianus%29.jpg/320px-Grey_francolin_%28Francolinus_pondicerianus%29.jpg",
|
| 62 |
}
|
| 63 |
DEFAULT_IMAGE = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/45/Eopsaltria_australis_-_Mogo_Campground.jpg/320px-Eopsaltria_australis_-_Mogo_Campground.jpg"
|
| 64 |
|
| 65 |
|
| 66 |
-
# ==================
|
| 67 |
|
| 68 |
-
|
| 69 |
-
"""
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def generate(self, prompt: str, system: str = None, stream: bool = False) -> str:
|
| 89 |
-
"""Generate response from Ollama."""
|
| 90 |
-
payload = {
|
| 91 |
-
"model": self.model,
|
| 92 |
-
"prompt": prompt,
|
| 93 |
-
"stream": stream,
|
| 94 |
-
"options": {
|
| 95 |
-
"temperature": 0.3,
|
| 96 |
-
"num_predict": 1500
|
| 97 |
-
}
|
| 98 |
-
}
|
| 99 |
-
|
| 100 |
-
if system:
|
| 101 |
-
payload["system"] = system
|
| 102 |
-
|
| 103 |
-
try:
|
| 104 |
-
if stream:
|
| 105 |
-
return self._generate_stream(payload)
|
| 106 |
-
else:
|
| 107 |
-
resp = requests.post(
|
| 108 |
-
f"{self.base_url}/api/generate",
|
| 109 |
-
json=payload,
|
| 110 |
-
timeout=120
|
| 111 |
-
)
|
| 112 |
-
if resp.status_code == 200:
|
| 113 |
-
return resp.json().get("response", "")
|
| 114 |
-
return None
|
| 115 |
-
except Exception as e:
|
| 116 |
-
print(f"Ollama error: {e}")
|
| 117 |
-
return None
|
| 118 |
-
|
| 119 |
-
def _generate_stream(self, payload) -> Generator[str, None, None]:
|
| 120 |
-
"""Stream response from Ollama."""
|
| 121 |
-
try:
|
| 122 |
-
with requests.post(
|
| 123 |
-
f"{self.base_url}/api/generate",
|
| 124 |
-
json=payload,
|
| 125 |
-
stream=True,
|
| 126 |
-
timeout=120
|
| 127 |
-
) as resp:
|
| 128 |
-
for line in resp.iter_lines():
|
| 129 |
-
if line:
|
| 130 |
-
data = json.loads(line)
|
| 131 |
-
if "response" in data:
|
| 132 |
-
yield data["response"]
|
| 133 |
-
if data.get("done"):
|
| 134 |
-
break
|
| 135 |
-
except Exception as e:
|
| 136 |
-
yield f"Error: {e}"
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
# Global Ollama client
|
| 140 |
-
ollama = OllamaClient()
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
def call_llm(prompt: str, system: str = None, stream: bool = False):
|
| 144 |
-
"""
|
| 145 |
-
Call LLM - tries Ollama first (local), falls back to HuggingFace API.
|
| 146 |
-
"""
|
| 147 |
-
# Try Ollama first (local, fast)
|
| 148 |
-
if ollama.is_available():
|
| 149 |
-
result = ollama.generate(prompt, system, stream=stream)
|
| 150 |
-
if result:
|
| 151 |
-
return result
|
| 152 |
|
| 153 |
-
# Fallback to HuggingFace API
|
| 154 |
try:
|
| 155 |
-
|
| 156 |
-
if
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
if resp.status_code == 200:
|
| 172 |
-
result = resp.json()
|
| 173 |
-
if isinstance(result, list) and len(result) > 0:
|
| 174 |
return result[0].get("generated_text", "")
|
| 175 |
except Exception as e:
|
| 176 |
-
print(f"
|
| 177 |
-
|
| 178 |
return None
|
| 179 |
|
| 180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
def get_llm_status() -> str:
|
| 182 |
-
"""Get
|
| 183 |
-
if
|
| 184 |
-
return f"🟢 Ollama ({OLLAMA_MODEL})
|
| 185 |
else:
|
| 186 |
-
return "🟡 HuggingFace API
|
| 187 |
|
| 188 |
|
| 189 |
# ================== AUDIO FEATURES ==================
|
| 190 |
|
| 191 |
-
@dataclass
|
| 192 |
class AudioFeatures:
|
| 193 |
-
"""Audio features
|
| 194 |
duration: float
|
| 195 |
peak_frequency: float
|
| 196 |
freq_range: Tuple[float, float]
|
| 197 |
-
spectral_centroid: float
|
| 198 |
num_syllables: int
|
| 199 |
syllable_rate: float
|
| 200 |
is_melodic: bool
|
| 201 |
is_repetitive: bool
|
| 202 |
-
amplitude_pattern: str
|
| 203 |
snr_db: float
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
-
return f"""Audio
|
| 210 |
-
- Duration: {self.duration:.1f}
|
| 211 |
-
-
|
| 212 |
- Frequency range: {self.freq_range[0]:.0f} - {self.freq_range[1]:.0f} Hz
|
| 213 |
-
-
|
| 214 |
-
- Syllables: {self.num_syllables}
|
| 215 |
-
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
elif f < 8000: return "high - sunbird, small passerine"
|
| 226 |
-
else: return "very high - alarm call or insect-like"
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
def extract_features(audio: np.ndarray, sr: int) -> AudioFeatures:
|
| 230 |
-
"""Extract audio features."""
|
| 231 |
duration = len(audio) / sr
|
| 232 |
-
audio = audio / (np.max(np.abs(audio)) + 1e-8)
|
| 233 |
|
| 234 |
-
# Spectral
|
| 235 |
freqs, psd = signal.welch(audio, sr, nperseg=min(4096, len(audio)))
|
| 236 |
peak_freq = freqs[np.argmax(psd)]
|
| 237 |
cumsum = np.cumsum(psd) / (np.sum(psd) + 1e-10)
|
| 238 |
freq_low = freqs[np.searchsorted(cumsum, 0.10)]
|
| 239 |
freq_high = freqs[np.searchsorted(cumsum, 0.90)]
|
| 240 |
-
centroid = np.sum(freqs * psd) / (np.sum(psd) + 1e-10)
|
| 241 |
|
| 242 |
-
#
|
| 243 |
envelope = np.abs(signal.hilbert(audio))
|
| 244 |
k = int(0.02 * sr)
|
| 245 |
if k > 0:
|
| 246 |
envelope = gaussian_filter1d(envelope, k)
|
| 247 |
|
| 248 |
-
# Syllables
|
| 249 |
n_fft, hop = 2048, 512
|
| 250 |
_, _, Zxx = signal.stft(audio, sr, nperseg=n_fft, noverlap=n_fft-hop)
|
| 251 |
flux = np.sum(np.maximum(0, np.diff(np.abs(Zxx), axis=1)), axis=0)
|
|
@@ -257,29 +312,15 @@ def extract_features(audio: np.ndarray, sr: int) -> AudioFeatures:
|
|
| 257 |
num_syl = len(peaks)
|
| 258 |
syl_rate = num_syl / duration if duration > 0 else 0
|
| 259 |
|
| 260 |
-
# Melodic
|
| 261 |
is_melodic = False
|
| 262 |
if len(audio) > sr:
|
| 263 |
chunks = np.array_split(audio, min(20, max(5, int(duration*4))))
|
| 264 |
-
chunk_freqs = []
|
| 265 |
-
|
| 266 |
-
if len(c) > 1024:
|
| 267 |
-
f, p = signal.welch(c, sr, nperseg=1024)
|
| 268 |
-
chunk_freqs.append(f[np.argmax(p)])
|
| 269 |
if chunk_freqs:
|
| 270 |
is_melodic = np.std(chunk_freqs) / (np.mean(chunk_freqs) + 1e-10) > 0.15
|
| 271 |
|
| 272 |
-
# Amplitude pattern
|
| 273 |
-
amp_pattern = "unknown"
|
| 274 |
-
if len(envelope) > 100:
|
| 275 |
-
q = len(envelope) // 4
|
| 276 |
-
s, e = np.mean(envelope[:q]), np.mean(envelope[-q:])
|
| 277 |
-
v = np.std(envelope) / (np.mean(envelope) + 1e-10)
|
| 278 |
-
if v > 0.6: amp_pattern = "varied"
|
| 279 |
-
elif e > s * 1.3: amp_pattern = "ascending"
|
| 280 |
-
elif e < s * 0.7: amp_pattern = "descending"
|
| 281 |
-
else: amp_pattern = "steady"
|
| 282 |
-
|
| 283 |
# SNR
|
| 284 |
noise = np.percentile(np.abs(audio), 5)
|
| 285 |
sig = np.percentile(np.abs(audio), 95)
|
|
@@ -289,18 +330,17 @@ def extract_features(audio: np.ndarray, sr: int) -> AudioFeatures:
|
|
| 289 |
duration=duration,
|
| 290 |
peak_frequency=float(peak_freq),
|
| 291 |
freq_range=(float(freq_low), float(freq_high)),
|
| 292 |
-
spectral_centroid=float(centroid),
|
| 293 |
num_syllables=num_syl,
|
| 294 |
syllable_rate=float(syl_rate),
|
| 295 |
is_melodic=is_melodic,
|
| 296 |
is_repetitive=syl_rate > 3,
|
| 297 |
-
|
| 298 |
-
|
| 299 |
)
|
| 300 |
|
| 301 |
|
| 302 |
def preprocess_audio(audio_data: np.ndarray, sr: int) -> Tuple[np.ndarray, int]:
|
| 303 |
-
"""
|
| 304 |
if audio_data.dtype == np.int16:
|
| 305 |
audio_data = audio_data.astype(np.float32) / 32768.0
|
| 306 |
elif audio_data.dtype == np.int32:
|
|
@@ -317,93 +357,78 @@ def preprocess_audio(audio_data: np.ndarray, sr: int) -> Tuple[np.ndarray, int]:
|
|
| 317 |
sr = SAMPLE_RATE
|
| 318 |
|
| 319 |
audio_data = audio_data / (np.max(np.abs(audio_data)) + 1e-8)
|
| 320 |
-
|
| 321 |
-
# Bandpass
|
| 322 |
-
nyq = sr / 2
|
| 323 |
-
low, high = 150 / nyq, min(15000 / nyq, 0.99)
|
| 324 |
-
b, a = signal.butter(4, [low, high], btype='band')
|
| 325 |
-
audio_data = signal.filtfilt(b, a, audio_data)
|
| 326 |
-
|
| 327 |
return audio_data, sr
|
| 328 |
|
| 329 |
|
| 330 |
# ================== LLM PROMPTS ==================
|
| 331 |
|
| 332 |
-
|
| 333 |
You specialize in Indian birds (1,300+ species).
|
| 334 |
|
| 335 |
-
|
|
|
|
|
|
|
| 336 |
|
| 337 |
-
IMPORTANT
|
| 338 |
-
1.
|
| 339 |
-
2.
|
| 340 |
-
3. Consider
|
| 341 |
-
4.
|
| 342 |
|
| 343 |
-
|
| 344 |
{
|
| 345 |
"birds": [
|
| 346 |
-
{
|
| 347 |
-
"name": "Common Name",
|
| 348 |
-
"scientific_name": "Genus species",
|
| 349 |
-
"confidence": 85,
|
| 350 |
-
"reasoning": "Brief explanation of why this bird matches"
|
| 351 |
-
}
|
| 352 |
],
|
| 353 |
-
"analysis": "Overall analysis
|
| 354 |
}"""
|
| 355 |
|
| 356 |
|
| 357 |
def get_bird_image(name: str) -> str:
|
| 358 |
-
"""Get image URL
|
| 359 |
if name in BIRD_IMAGES:
|
| 360 |
return BIRD_IMAGES[name]
|
| 361 |
-
name_lower = name.lower()
|
| 362 |
for bird, url in BIRD_IMAGES.items():
|
| 363 |
-
if bird.lower() in
|
| 364 |
return url
|
| 365 |
return DEFAULT_IMAGE
|
| 366 |
|
| 367 |
|
| 368 |
-
def format_results(
|
| 369 |
-
"""
|
| 370 |
-
if not
|
| 371 |
return "### ⚠️ No response from LLM"
|
| 372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
try:
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
end = llm_response.rfind('}') + 1
|
| 377 |
if start >= 0 and end > start:
|
| 378 |
-
data = json.loads(
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
return f"### 🤖 AI Analysis\n\n{llm_response}"
|
| 382 |
-
|
| 383 |
-
birds = data.get("birds", [])
|
| 384 |
-
analysis = data.get("analysis", "")
|
| 385 |
-
|
| 386 |
-
if not birds:
|
| 387 |
-
return f"### ❌ No birds identified\n\n{analysis}"
|
| 388 |
-
|
| 389 |
-
output = f"## 🐦 Birds Identified\n\n*{analysis}*\n\n"
|
| 390 |
-
|
| 391 |
-
for i, bird in enumerate(birds, 1):
|
| 392 |
-
name = bird.get("name", "Unknown")
|
| 393 |
-
scientific = bird.get("scientific_name", "")
|
| 394 |
-
conf = bird.get("confidence", 0)
|
| 395 |
-
reason = bird.get("reasoning", "")
|
| 396 |
-
|
| 397 |
-
img = get_bird_image(name)
|
| 398 |
|
| 399 |
-
if
|
| 400 |
-
|
| 401 |
-
elif conf >= 60:
|
| 402 |
-
badge = "🟡 MEDIUM"
|
| 403 |
-
else:
|
| 404 |
-
badge = "🔴 LOW"
|
| 405 |
|
| 406 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
---
|
| 408 |
|
| 409 |
### {i}. **{name}** ({conf}%) {badge}
|
|
@@ -412,65 +437,82 @@ def format_results(llm_response: str) -> str:
|
|
| 412 |
|
| 413 |
**Scientific Name:** _{scientific}_
|
| 414 |
|
| 415 |
-
**Why
|
| 416 |
|
| 417 |
"""
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
|
| 424 |
|
| 425 |
-
# ==================
|
| 426 |
|
| 427 |
def identify_audio(audio, location: str = "", month: str = ""):
|
| 428 |
-
"""Identify bird from audio using LLM."""
|
| 429 |
if audio is None:
|
| 430 |
-
return "### ⚠️ Please record or upload
|
| 431 |
|
| 432 |
status = get_llm_status()
|
| 433 |
-
yield f"### 🔄 Processing
|
| 434 |
|
| 435 |
try:
|
| 436 |
sr, audio_data = audio
|
| 437 |
audio_data, sr = preprocess_audio(audio_data, sr)
|
| 438 |
|
| 439 |
-
|
| 440 |
-
|
|
|
|
|
|
|
|
|
|
| 441 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
prompt = f"""Identify the bird(s) in this recording:
|
| 443 |
|
| 444 |
-
{features.
|
|
|
|
| 445 |
"""
|
| 446 |
if location:
|
| 447 |
-
prompt += f"
|
| 448 |
if month:
|
| 449 |
-
prompt += f"
|
| 450 |
|
| 451 |
-
|
|
|
|
| 452 |
|
| 453 |
-
|
| 454 |
|
| 455 |
-
|
|
|
|
|
|
|
| 456 |
|
| 457 |
if response:
|
| 458 |
-
result = format_results(response)
|
| 459 |
-
result += f"\n\n---\n\n### 📊 Audio Analysis\n{features.
|
| 460 |
-
result += f"\n\n**LLM:** {status}"
|
| 461 |
yield result
|
| 462 |
else:
|
| 463 |
-
yield f"""### ⚠️ LLM not
|
| 464 |
|
| 465 |
-
|
| 466 |
|
| 467 |
-
**
|
| 468 |
-
{features.
|
| 469 |
|
| 470 |
-
**To fix:**
|
| 471 |
-
1.
|
| 472 |
-
2. Pull
|
| 473 |
-
3. Try again
|
| 474 |
"""
|
| 475 |
|
| 476 |
except Exception as e:
|
|
@@ -480,49 +522,36 @@ def identify_audio(audio, location: str = "", month: str = ""):
|
|
| 480 |
def identify_description(description: str):
|
| 481 |
"""Identify bird from description using LLM."""
|
| 482 |
if not description or len(description.strip()) < 5:
|
| 483 |
-
return "### ⚠️ Please enter a description
|
| 484 |
|
| 485 |
status = get_llm_status()
|
| 486 |
-
yield f"### 🔄 Analyzing
|
| 487 |
|
| 488 |
-
prompt = f"""Identify the bird(s)
|
| 489 |
|
| 490 |
{description}
|
| 491 |
|
| 492 |
-
|
| 493 |
|
| 494 |
-
response = call_llm(prompt,
|
| 495 |
|
| 496 |
if response:
|
| 497 |
-
|
| 498 |
-
result += f"\n\n**LLM:** {status}"
|
| 499 |
-
yield result
|
| 500 |
else:
|
| 501 |
-
yield f"
|
| 502 |
-
|
| 503 |
-
**LLM Status:** {status}
|
| 504 |
-
|
| 505 |
-
**To fix:**
|
| 506 |
-
1. Make sure Ollama is running: `ollama serve`
|
| 507 |
-
2. Pull the model: `ollama pull {OLLAMA_MODEL}`
|
| 508 |
-
"""
|
| 509 |
|
| 510 |
|
| 511 |
def identify_image(image):
|
| 512 |
"""Identify bird from image using LLM."""
|
| 513 |
if image is None:
|
| 514 |
-
return "### ⚠️ Please upload
|
| 515 |
|
| 516 |
status = get_llm_status()
|
| 517 |
-
yield f"### 🔄 Analyzing image...\n\n**LLM
|
| 518 |
|
| 519 |
try:
|
| 520 |
-
if
|
| 521 |
-
img = image.numpy()
|
| 522 |
-
else:
|
| 523 |
-
img = np.array(image)
|
| 524 |
|
| 525 |
-
# Color analysis
|
| 526 |
colors = []
|
| 527 |
if len(img.shape) == 3 and img.shape[2] >= 3:
|
| 528 |
r, g, b = np.mean(img[:,:,0]), np.mean(img[:,:,1]), np.mean(img[:,:,2])
|
|
@@ -535,25 +564,21 @@ def identify_image(image):
|
|
| 535 |
|
| 536 |
color_desc = ", ".join(colors) if colors else "mixed"
|
| 537 |
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
prompt = f"""Identify the bird in this image.
|
| 541 |
|
| 542 |
-
Detected
|
| 543 |
-
Image size: {img.shape[1]}x{img.shape[0]} pixels
|
| 544 |
|
| 545 |
-
|
| 546 |
-
List all matching birds with confidence >= 50%."""
|
| 547 |
|
| 548 |
-
response = call_llm(prompt,
|
| 549 |
|
| 550 |
if response:
|
| 551 |
-
result =
|
| 552 |
-
result +=
|
| 553 |
result += f"\n\n**LLM:** {status}"
|
| 554 |
yield result
|
| 555 |
else:
|
| 556 |
-
yield f"### ⚠️ LLM not
|
| 557 |
|
| 558 |
except Exception as e:
|
| 559 |
yield f"### ❌ Error: {str(e)}"
|
|
@@ -561,143 +586,111 @@ List all matching birds with confidence >= 50%."""
|
|
| 561 |
|
| 562 |
# ================== GRADIO UI ==================
|
| 563 |
|
| 564 |
-
with gr.Blocks(title="🐦 BirdSense Pro
|
| 565 |
|
| 566 |
gr.HTML("""
|
| 567 |
-
<div style="text-align: center; background: linear-gradient(135deg, #1a4d2e 0%, #2d5a3e 50%, #1a4d2e 100%); padding: 2rem; border-radius: 16px; margin-bottom:
|
| 568 |
<h1 style="color: #4ade80; font-size: 2.5rem; margin: 0;">🐦 BirdSense Pro</h1>
|
| 569 |
-
<p style="color: #94a3b8; font-size: 1.
|
| 570 |
-
<p style="color: #64748b; font-size: 0.9rem;">
|
| 571 |
-
🤖 Uses LOCAL Ollama LLM • 10,000+ species • Multi-bird detection
|
| 572 |
-
</p>
|
| 573 |
</div>
|
| 574 |
""")
|
| 575 |
|
| 576 |
-
|
| 577 |
-
status_text = get_llm_status()
|
| 578 |
-
gr.Markdown(f"**Current LLM:** {status_text}")
|
| 579 |
|
| 580 |
with gr.Tabs():
|
| 581 |
-
|
| 582 |
-
with gr.Tab("🎤 Audio"):
|
| 583 |
gr.Markdown("""
|
| 584 |
-
###
|
| 585 |
-
|
| 586 |
-
|
|
|
|
| 587 |
""")
|
| 588 |
|
| 589 |
with gr.Row():
|
| 590 |
with gr.Column(scale=1):
|
| 591 |
audio_in = gr.Audio(sources=["microphone", "upload"], type="numpy", label="🎤 Bird Audio")
|
| 592 |
with gr.Row():
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
)
|
| 600 |
-
audio_btn = gr.Button("🔍 Identify with Ollama LLM", variant="primary", size="lg")
|
| 601 |
|
| 602 |
with gr.Column(scale=2):
|
| 603 |
audio_out = gr.Markdown()
|
| 604 |
|
| 605 |
-
audio_btn.click(identify_audio, [audio_in,
|
| 606 |
|
| 607 |
-
# DESCRIPTION TAB
|
| 608 |
with gr.Tab("📝 Description"):
|
| 609 |
-
gr.Markdown("""
|
| 610 |
-
### Describe the bird you saw or heard
|
| 611 |
-
|
| 612 |
-
The LLM will analyze your description and identify matching species.
|
| 613 |
-
""")
|
| 614 |
-
|
| 615 |
with gr.Row():
|
| 616 |
with gr.Column(scale=1):
|
| 617 |
-
desc_in = gr.Textbox(
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
lines=4
|
| 621 |
-
)
|
| 622 |
-
desc_btn = gr.Button("🔍 Identify with Ollama LLM", variant="primary", size="lg")
|
| 623 |
-
|
| 624 |
with gr.Column(scale=2):
|
| 625 |
desc_out = gr.Markdown()
|
| 626 |
-
|
| 627 |
desc_btn.click(identify_description, [desc_in], desc_out)
|
| 628 |
|
| 629 |
-
# IMAGE TAB
|
| 630 |
with gr.Tab("📷 Image"):
|
| 631 |
-
gr.Markdown("""
|
| 632 |
-
### Upload or capture a bird image
|
| 633 |
-
|
| 634 |
-
Colors are extracted and sent to the LLM for identification.
|
| 635 |
-
""")
|
| 636 |
-
|
| 637 |
with gr.Row():
|
| 638 |
with gr.Column(scale=1):
|
| 639 |
img_in = gr.Image(sources=["upload", "webcam"], type="numpy", label="📷 Bird Image")
|
| 640 |
-
img_btn = gr.Button("🔍 Identify
|
| 641 |
-
|
| 642 |
with gr.Column(scale=2):
|
| 643 |
img_out = gr.Markdown()
|
| 644 |
-
|
| 645 |
img_btn.click(identify_image, [img_in], img_out)
|
| 646 |
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
-
|
| 684 |
-
-
|
| 685 |
-
|
| 686 |
-
Change the model in the code: `OLLAMA_MODEL = "your-model"`
|
| 687 |
""")
|
| 688 |
|
| 689 |
gr.HTML("""
|
| 690 |
<div style="text-align: center; padding: 1rem; margin-top: 1rem; border-top: 1px solid #334155;">
|
| 691 |
-
<p style="color: #4ade80;
|
| 692 |
-
<p style="color: #64748b;">
|
| 693 |
-
Powered by LOCAL Ollama LLM • <a href="https://github.com/sohamzycus/eagv2/tree/master/birdsense" style="color: #4ade80;">GitHub</a>
|
| 694 |
-
</p>
|
| 695 |
</div>
|
| 696 |
""")
|
| 697 |
|
| 698 |
-
|
| 699 |
if __name__ == "__main__":
|
| 700 |
-
print(f"\n🐦 BirdSense Pro")
|
| 701 |
-
print(f"LLM
|
| 702 |
-
print(f"\nStarting server...")
|
| 703 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
| 1 |
"""
|
| 2 |
🐦 BirdSense Pro - AI Bird Identification
|
|
|
|
| 3 |
|
| 4 |
+
Integrates:
|
| 5 |
+
1. META SAM-Audio style preprocessing for bird voice separation
|
| 6 |
+
2. Ollama LLM (local) or HuggingFace API (cloud) for identification
|
| 7 |
+
3. Multi-bird detection
|
| 8 |
+
4. Audio, Image, and Description modes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
CSCR Initiative
|
| 11 |
"""
|
|
|
|
| 15 |
import scipy.signal as signal
|
| 16 |
from scipy.ndimage import gaussian_filter1d
|
| 17 |
from dataclasses import dataclass
|
| 18 |
+
from typing import Optional, Tuple, List
|
| 19 |
import json
|
|
|
|
| 20 |
import requests
|
|
|
|
| 21 |
|
| 22 |
# ================== CONFIG ==================
|
| 23 |
SAMPLE_RATE = 48000
|
|
|
|
|
|
|
| 24 |
OLLAMA_URL = "http://localhost:11434"
|
| 25 |
+
OLLAMA_MODEL = "qwen2.5:3b"
|
| 26 |
+
HF_API_URL = "https://api-inference.huggingface.co/models/google/flan-t5-large"
|
| 27 |
+
|
| 28 |
+
# ================== META SAM-AUDIO STYLE PREPROCESSING ==================
|
| 29 |
+
"""
|
| 30 |
+
META SAM-Audio (Segment Anything in Audio) uses text prompts to separate audio sources.
|
| 31 |
+
|
| 32 |
+
Our implementation mimics this approach:
|
| 33 |
+
1. Text prompt: "bird call", "bird song", "background noise"
|
| 34 |
+
2. Spectral masking to isolate bird frequencies (500-10000 Hz)
|
| 35 |
+
3. Noise reduction using spectral gating
|
| 36 |
+
4. Source separation for multi-bird scenarios
|
| 37 |
+
|
| 38 |
+
Reference: https://ai.meta.com/research/publications/sam-audio-segment-anything-in-audio/
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
class SAMAudioProcessor:
|
| 42 |
+
"""
|
| 43 |
+
META SAM-Audio style audio processor for bird call isolation.
|
| 44 |
+
|
| 45 |
+
Uses text-guided spectral masking to separate bird calls from noise.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
# SAM-Audio style text prompts for bird isolation
|
| 49 |
+
PROMPTS = {
|
| 50 |
+
"bird_call": {
|
| 51 |
+
"freq_range": (500, 10000), # Bird vocalization range
|
| 52 |
+
"description": "bird call, bird song, bird vocalization"
|
| 53 |
+
},
|
| 54 |
+
"background": {
|
| 55 |
+
"freq_range": (0, 500), # Low frequency noise
|
| 56 |
+
"description": "wind, traffic, background noise"
|
| 57 |
+
},
|
| 58 |
+
"high_noise": {
|
| 59 |
+
"freq_range": (10000, 20000), # High frequency noise
|
| 60 |
+
"description": "electronics, insects"
|
| 61 |
+
}
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
def __init__(self, sample_rate: int = 48000):
|
| 65 |
+
self.sr = sample_rate
|
| 66 |
+
|
| 67 |
+
def separate_bird_calls(self, audio: np.ndarray) -> Tuple[np.ndarray, dict]:
|
| 68 |
+
"""
|
| 69 |
+
SAM-Audio style bird call separation.
|
| 70 |
+
|
| 71 |
+
Uses spectral masking guided by "bird call" prompt to isolate
|
| 72 |
+
bird vocalizations from background noise.
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
Tuple of (isolated_bird_audio, metadata)
|
| 76 |
+
"""
|
| 77 |
+
# Compute STFT
|
| 78 |
+
n_fft = 2048
|
| 79 |
+
hop = 512
|
| 80 |
+
f, t, Zxx = signal.stft(audio, self.sr, nperseg=n_fft, noverlap=n_fft-hop)
|
| 81 |
+
magnitude = np.abs(Zxx)
|
| 82 |
+
phase = np.angle(Zxx)
|
| 83 |
+
|
| 84 |
+
# Create bird frequency mask (SAM-Audio "bird call" prompt)
|
| 85 |
+
bird_low, bird_high = self.PROMPTS["bird_call"]["freq_range"]
|
| 86 |
+
bird_mask = np.zeros_like(magnitude)
|
| 87 |
+
|
| 88 |
+
for i, freq in enumerate(f):
|
| 89 |
+
if bird_low <= freq <= bird_high:
|
| 90 |
+
# Soft mask with gaussian roll-off at edges
|
| 91 |
+
if freq < bird_low + 200:
|
| 92 |
+
weight = (freq - bird_low) / 200
|
| 93 |
+
elif freq > bird_high - 500:
|
| 94 |
+
weight = (bird_high - freq) / 500
|
| 95 |
+
else:
|
| 96 |
+
weight = 1.0
|
| 97 |
+
bird_mask[i, :] = weight
|
| 98 |
+
|
| 99 |
+
# Apply spectral gating (noise reduction)
|
| 100 |
+
noise_floor = np.percentile(magnitude, 20, axis=1, keepdims=True)
|
| 101 |
+
gate = magnitude > (noise_floor * 2)
|
| 102 |
+
bird_mask = bird_mask * gate
|
| 103 |
+
|
| 104 |
+
# Apply mask
|
| 105 |
+
bird_magnitude = magnitude * bird_mask
|
| 106 |
+
|
| 107 |
+
# Reconstruct audio
|
| 108 |
+
bird_stft = bird_magnitude * np.exp(1j * phase)
|
| 109 |
+
_, bird_audio = signal.istft(bird_stft, self.sr, nperseg=n_fft, noverlap=n_fft-hop)
|
| 110 |
+
|
| 111 |
+
# Normalize
|
| 112 |
+
if np.max(np.abs(bird_audio)) > 0:
|
| 113 |
+
bird_audio = bird_audio / np.max(np.abs(bird_audio))
|
| 114 |
+
|
| 115 |
+
# Calculate separation quality
|
| 116 |
+
original_energy = np.sum(magnitude ** 2)
|
| 117 |
+
bird_energy = np.sum(bird_magnitude ** 2)
|
| 118 |
+
separation_ratio = bird_energy / (original_energy + 1e-10)
|
| 119 |
+
|
| 120 |
+
metadata = {
|
| 121 |
+
"sam_audio_prompt": "bird call, bird song",
|
| 122 |
+
"bird_freq_range": f"{bird_low}-{bird_high} Hz",
|
| 123 |
+
"separation_ratio": float(separation_ratio),
|
| 124 |
+
"noise_reduced": True
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
return bird_audio.astype(np.float32), metadata
|
| 128 |
+
|
| 129 |
+
def detect_multiple_birds(self, audio: np.ndarray) -> List[dict]:
|
| 130 |
+
"""
|
| 131 |
+
SAM-Audio style multi-source detection.
|
| 132 |
+
|
| 133 |
+
Detects if multiple birds are calling by analyzing
|
| 134 |
+
spectral peaks in different frequency bands.
|
| 135 |
+
"""
|
| 136 |
+
f, t, Zxx = signal.stft(audio, self.sr, nperseg=2048)
|
| 137 |
+
magnitude = np.abs(Zxx)
|
| 138 |
+
|
| 139 |
+
# Define frequency bands for different bird types
|
| 140 |
+
bands = [
|
| 141 |
+
("low_freq_birds", 500, 2000), # Crows, cuckoos, coucals
|
| 142 |
+
("mid_freq_birds", 2000, 5000), # Most songbirds
|
| 143 |
+
("high_freq_birds", 5000, 10000), # Sunbirds, warblers
|
| 144 |
+
]
|
| 145 |
+
|
| 146 |
+
detected_sources = []
|
| 147 |
+
for band_name, low, high in bands:
|
| 148 |
+
band_idx = (f >= low) & (f <= high)
|
| 149 |
+
band_energy = np.mean(magnitude[band_idx, :])
|
| 150 |
+
|
| 151 |
+
if band_energy > 0.01: # Threshold for detection
|
| 152 |
+
detected_sources.append({
|
| 153 |
+
"band": band_name,
|
| 154 |
+
"freq_range": f"{low}-{high} Hz",
|
| 155 |
+
"energy": float(band_energy)
|
| 156 |
+
})
|
| 157 |
+
|
| 158 |
+
return detected_sources
|
| 159 |
+
|
| 160 |
|
| 161 |
+
# Global SAM-Audio processor
|
| 162 |
+
sam_audio = SAMAudioProcessor(SAMPLE_RATE)
|
| 163 |
|
| 164 |
+
|
| 165 |
+
# ================== BIRD IMAGES ==================
|
| 166 |
BIRD_IMAGES = {
|
| 167 |
"Asian Koel": "https://upload.wikimedia.org/wikipedia/commons/thumb/7/78/Eudynamys_scolopaceus_-_Koel_male_-_Sukhna_Lake%2C_India.jpg/320px-Eudynamys_scolopaceus_-_Koel_male_-_Sukhna_Lake%2C_India.jpg",
|
| 168 |
"Indian Cuckoo": "https://upload.wikimedia.org/wikipedia/commons/thumb/6/6b/Cuculus_micropterus.jpg/320px-Cuculus_micropterus.jpg",
|
|
|
|
| 181 |
"Greater Coucal": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d6/Greater_Coucal_%28Centropus_sinensis%29_in_Hyderabad%2C_AP_W_IMG_7544.jpg/320px-Greater_Coucal_%28Centropus_sinensis%29_in_Hyderabad%2C_AP_W_IMG_7544.jpg",
|
| 182 |
"Common Tailorbird": "https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Common_Tailorbird_%28Orthotomus_sutorius%29_in_Kolkata_I_IMG_2859.jpg/320px-Common_Tailorbird_%28Orthotomus_sutorius%29_in_Kolkata_I_IMG_2859.jpg",
|
| 183 |
"Green Bee-eater": "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b1/Merops_orientalis_%28Pune%2C_India%29.jpg/320px-Merops_orientalis_%28Pune%2C_India%29.jpg",
|
|
|
|
|
|
|
|
|
|
| 184 |
}
|
| 185 |
DEFAULT_IMAGE = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/45/Eopsaltria_australis_-_Mogo_Campground.jpg/320px-Eopsaltria_australis_-_Mogo_Campground.jpg"
|
| 186 |
|
| 187 |
|
| 188 |
+
# ================== LLM CLIENT ==================
|
| 189 |
|
| 190 |
+
def check_ollama() -> bool:
|
| 191 |
+
"""Check if Ollama is available."""
|
| 192 |
+
try:
|
| 193 |
+
r = requests.get(f"{OLLAMA_URL}/api/tags", timeout=2)
|
| 194 |
+
return r.status_code == 200
|
| 195 |
+
except:
|
| 196 |
+
return False
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def call_ollama(prompt: str, system: str = None) -> str:
|
| 200 |
+
"""Call local Ollama LLM."""
|
| 201 |
+
payload = {
|
| 202 |
+
"model": OLLAMA_MODEL,
|
| 203 |
+
"prompt": prompt,
|
| 204 |
+
"stream": False,
|
| 205 |
+
"options": {"temperature": 0.3, "num_predict": 1500}
|
| 206 |
+
}
|
| 207 |
+
if system:
|
| 208 |
+
payload["system"] = system
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
|
|
|
| 210 |
try:
|
| 211 |
+
r = requests.post(f"{OLLAMA_URL}/api/generate", json=payload, timeout=120)
|
| 212 |
+
if r.status_code == 200:
|
| 213 |
+
return r.json().get("response", "")
|
| 214 |
+
except Exception as e:
|
| 215 |
+
print(f"Ollama error: {e}")
|
| 216 |
+
return None
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def call_hf_api(prompt: str) -> str:
|
| 220 |
+
"""Call HuggingFace Inference API (fallback)."""
|
| 221 |
+
try:
|
| 222 |
+
payload = {"inputs": prompt, "parameters": {"max_new_tokens": 1000}}
|
| 223 |
+
r = requests.post(HF_API_URL, json=payload, timeout=60)
|
| 224 |
+
if r.status_code == 200:
|
| 225 |
+
result = r.json()
|
| 226 |
+
if isinstance(result, list) and result:
|
|
|
|
|
|
|
|
|
|
| 227 |
return result[0].get("generated_text", "")
|
| 228 |
except Exception as e:
|
| 229 |
+
print(f"HF API error: {e}")
|
|
|
|
| 230 |
return None
|
| 231 |
|
| 232 |
|
| 233 |
+
def call_llm(prompt: str, system: str = None) -> str:
|
| 234 |
+
"""Call LLM - Ollama first, then HuggingFace API fallback."""
|
| 235 |
+
if check_ollama():
|
| 236 |
+
result = call_ollama(prompt, system)
|
| 237 |
+
if result:
|
| 238 |
+
return result
|
| 239 |
+
return call_hf_api(prompt)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
def get_llm_status() -> str:
|
| 243 |
+
"""Get LLM status string."""
|
| 244 |
+
if check_ollama():
|
| 245 |
+
return f"🟢 Ollama ({OLLAMA_MODEL}) LOCAL"
|
| 246 |
else:
|
| 247 |
+
return "🟡 HuggingFace API (cloud)"
|
| 248 |
|
| 249 |
|
| 250 |
# ================== AUDIO FEATURES ==================
|
| 251 |
|
| 252 |
+
@dataclass
|
| 253 |
class AudioFeatures:
|
| 254 |
+
"""Audio features after SAM-Audio preprocessing."""
|
| 255 |
duration: float
|
| 256 |
peak_frequency: float
|
| 257 |
freq_range: Tuple[float, float]
|
|
|
|
| 258 |
num_syllables: int
|
| 259 |
syllable_rate: float
|
| 260 |
is_melodic: bool
|
| 261 |
is_repetitive: bool
|
|
|
|
| 262 |
snr_db: float
|
| 263 |
+
sam_audio_metadata: dict
|
| 264 |
+
|
| 265 |
+
def to_prompt(self) -> str:
|
| 266 |
+
"""Convert to LLM prompt."""
|
| 267 |
+
freq_desc = "very low (large bird)" if self.peak_frequency < 500 else \
|
| 268 |
+
"low (crow, cuckoo)" if self.peak_frequency < 1500 else \
|
| 269 |
+
"medium (songbird)" if self.peak_frequency < 4000 else \
|
| 270 |
+
"high (warbler, sunbird)" if self.peak_frequency < 7000 else \
|
| 271 |
+
"very high (alarm call)"
|
| 272 |
|
| 273 |
+
return f"""Audio features (after SAM-Audio bird call separation):
|
| 274 |
+
- Duration: {self.duration:.1f}s
|
| 275 |
+
- Peak frequency: {self.peak_frequency:.0f} Hz ({freq_desc})
|
| 276 |
- Frequency range: {self.freq_range[0]:.0f} - {self.freq_range[1]:.0f} Hz
|
| 277 |
+
- Pattern: {"melodic" if self.is_melodic else "monotone"}, {"repetitive" if self.is_repetitive else "variable"}
|
| 278 |
+
- Syllables: {self.num_syllables} at {self.syllable_rate:.1f}/sec
|
| 279 |
+
- Recording quality: SNR {self.snr_db:.0f}dB
|
| 280 |
+
|
| 281 |
+
SAM-Audio preprocessing:
|
| 282 |
+
- Prompt used: "{self.sam_audio_metadata.get('sam_audio_prompt', 'bird call')}"
|
| 283 |
+
- Bird frequency isolation: {self.sam_audio_metadata.get('bird_freq_range', '500-10000 Hz')}
|
| 284 |
+
- Separation quality: {self.sam_audio_metadata.get('separation_ratio', 0)*100:.0f}%"""
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def extract_features(audio: np.ndarray, sr: int, sam_metadata: dict) -> AudioFeatures:
|
| 288 |
+
"""Extract features from SAM-Audio processed audio."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
duration = len(audio) / sr
|
|
|
|
| 290 |
|
| 291 |
+
# Spectral analysis
|
| 292 |
freqs, psd = signal.welch(audio, sr, nperseg=min(4096, len(audio)))
|
| 293 |
peak_freq = freqs[np.argmax(psd)]
|
| 294 |
cumsum = np.cumsum(psd) / (np.sum(psd) + 1e-10)
|
| 295 |
freq_low = freqs[np.searchsorted(cumsum, 0.10)]
|
| 296 |
freq_high = freqs[np.searchsorted(cumsum, 0.90)]
|
|
|
|
| 297 |
|
| 298 |
+
# Syllable detection
|
| 299 |
envelope = np.abs(signal.hilbert(audio))
|
| 300 |
k = int(0.02 * sr)
|
| 301 |
if k > 0:
|
| 302 |
envelope = gaussian_filter1d(envelope, k)
|
| 303 |
|
|
|
|
| 304 |
n_fft, hop = 2048, 512
|
| 305 |
_, _, Zxx = signal.stft(audio, sr, nperseg=n_fft, noverlap=n_fft-hop)
|
| 306 |
flux = np.sum(np.maximum(0, np.diff(np.abs(Zxx), axis=1)), axis=0)
|
|
|
|
| 312 |
num_syl = len(peaks)
|
| 313 |
syl_rate = num_syl / duration if duration > 0 else 0
|
| 314 |
|
| 315 |
+
# Melodic detection
|
| 316 |
is_melodic = False
|
| 317 |
if len(audio) > sr:
|
| 318 |
chunks = np.array_split(audio, min(20, max(5, int(duration*4))))
|
| 319 |
+
chunk_freqs = [freqs[np.argmax(signal.welch(c, sr, nperseg=min(1024, len(c)))[1])]
|
| 320 |
+
for c in chunks if len(c) > 512]
|
|
|
|
|
|
|
|
|
|
| 321 |
if chunk_freqs:
|
| 322 |
is_melodic = np.std(chunk_freqs) / (np.mean(chunk_freqs) + 1e-10) > 0.15
|
| 323 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
# SNR
|
| 325 |
noise = np.percentile(np.abs(audio), 5)
|
| 326 |
sig = np.percentile(np.abs(audio), 95)
|
|
|
|
| 330 |
duration=duration,
|
| 331 |
peak_frequency=float(peak_freq),
|
| 332 |
freq_range=(float(freq_low), float(freq_high)),
|
|
|
|
| 333 |
num_syllables=num_syl,
|
| 334 |
syllable_rate=float(syl_rate),
|
| 335 |
is_melodic=is_melodic,
|
| 336 |
is_repetitive=syl_rate > 3,
|
| 337 |
+
snr_db=float(snr),
|
| 338 |
+
sam_audio_metadata=sam_metadata
|
| 339 |
)
|
| 340 |
|
| 341 |
|
| 342 |
def preprocess_audio(audio_data: np.ndarray, sr: int) -> Tuple[np.ndarray, int]:
|
| 343 |
+
"""Basic audio preprocessing."""
|
| 344 |
if audio_data.dtype == np.int16:
|
| 345 |
audio_data = audio_data.astype(np.float32) / 32768.0
|
| 346 |
elif audio_data.dtype == np.int32:
|
|
|
|
| 357 |
sr = SAMPLE_RATE
|
| 358 |
|
| 359 |
audio_data = audio_data / (np.max(np.abs(audio_data)) + 1e-8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
return audio_data, sr
|
| 361 |
|
| 362 |
|
| 363 |
# ================== LLM PROMPTS ==================
|
| 364 |
|
| 365 |
+
SYSTEM_PROMPT = """You are an expert ornithologist with knowledge of 10,000+ bird species worldwide.
|
| 366 |
You specialize in Indian birds (1,300+ species).
|
| 367 |
|
| 368 |
+
The audio has been preprocessed using META SAM-Audio style separation to isolate bird calls.
|
| 369 |
+
|
| 370 |
+
Your task: Identify ALL bird species that could be present in the recording.
|
| 371 |
|
| 372 |
+
IMPORTANT:
|
| 373 |
+
1. List ALL birds with confidence >= 50%
|
| 374 |
+
2. Multiple birds may be calling simultaneously
|
| 375 |
+
3. Consider the audio features carefully
|
| 376 |
+
4. Provide scientific names
|
| 377 |
|
| 378 |
+
Respond in JSON format:
|
| 379 |
{
|
| 380 |
"birds": [
|
| 381 |
+
{"name": "Common Name", "scientific_name": "Genus species", "confidence": 85, "reasoning": "Why this bird"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
],
|
| 383 |
+
"analysis": "Overall analysis"
|
| 384 |
}"""
|
| 385 |
|
| 386 |
|
| 387 |
def get_bird_image(name: str) -> str:
|
| 388 |
+
"""Get bird image URL."""
|
| 389 |
if name in BIRD_IMAGES:
|
| 390 |
return BIRD_IMAGES[name]
|
|
|
|
| 391 |
for bird, url in BIRD_IMAGES.items():
|
| 392 |
+
if bird.lower() in name.lower() or name.lower() in bird.lower():
|
| 393 |
return url
|
| 394 |
return DEFAULT_IMAGE
|
| 395 |
|
| 396 |
|
| 397 |
+
def format_results(response: str, features: AudioFeatures = None) -> str:
|
| 398 |
+
"""Format LLM response with images."""
|
| 399 |
+
if not response:
|
| 400 |
return "### ⚠️ No response from LLM"
|
| 401 |
|
| 402 |
+
output = "## 🐦 Birds Identified\n\n"
|
| 403 |
+
|
| 404 |
+
# Add SAM-Audio info
|
| 405 |
+
if features:
|
| 406 |
+
output += f"**🔊 SAM-Audio Preprocessing:**\n"
|
| 407 |
+
output += f"- Prompt: `{features.sam_audio_metadata.get('sam_audio_prompt', 'bird call')}`\n"
|
| 408 |
+
output += f"- Isolation: {features.sam_audio_metadata.get('bird_freq_range', '500-10000 Hz')}\n"
|
| 409 |
+
output += f"- Quality: {features.sam_audio_metadata.get('separation_ratio', 0)*100:.0f}%\n\n"
|
| 410 |
+
|
| 411 |
try:
|
| 412 |
+
start = response.find('{')
|
| 413 |
+
end = response.rfind('}') + 1
|
|
|
|
| 414 |
if start >= 0 and end > start:
|
| 415 |
+
data = json.loads(response[start:end])
|
| 416 |
+
birds = data.get("birds", [])
|
| 417 |
+
analysis = data.get("analysis", "")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
|
| 419 |
+
if analysis:
|
| 420 |
+
output += f"*{analysis}*\n\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
|
| 422 |
+
for i, bird in enumerate(birds, 1):
|
| 423 |
+
name = bird.get("name", "Unknown")
|
| 424 |
+
scientific = bird.get("scientific_name", "")
|
| 425 |
+
conf = bird.get("confidence", 0)
|
| 426 |
+
reason = bird.get("reasoning", "")
|
| 427 |
+
|
| 428 |
+
img = get_bird_image(name)
|
| 429 |
+
badge = "🟢 HIGH" if conf >= 80 else "🟡 MEDIUM" if conf >= 60 else "🔴 LOW"
|
| 430 |
+
|
| 431 |
+
output += f"""
|
| 432 |
---
|
| 433 |
|
| 434 |
### {i}. **{name}** ({conf}%) {badge}
|
|
|
|
| 437 |
|
| 438 |
**Scientific Name:** _{scientific}_
|
| 439 |
|
| 440 |
+
**Why:** {reason}
|
| 441 |
|
| 442 |
"""
|
| 443 |
+
return output
|
| 444 |
+
except:
|
| 445 |
+
pass
|
| 446 |
+
|
| 447 |
+
return output + f"\n\n### AI Response:\n{response}"
|
| 448 |
|
| 449 |
|
| 450 |
+
# ================== MAIN FUNCTIONS ==================
|
| 451 |
|
| 452 |
def identify_audio(audio, location: str = "", month: str = ""):
|
| 453 |
+
"""Identify bird from audio using SAM-Audio + LLM."""
|
| 454 |
if audio is None:
|
| 455 |
+
return "### ⚠️ Please record or upload audio"
|
| 456 |
|
| 457 |
status = get_llm_status()
|
| 458 |
+
yield f"### 🔄 Processing with SAM-Audio...\n\n**LLM:** {status}"
|
| 459 |
|
| 460 |
try:
|
| 461 |
sr, audio_data = audio
|
| 462 |
audio_data, sr = preprocess_audio(audio_data, sr)
|
| 463 |
|
| 464 |
+
# ===== SAM-AUDIO PREPROCESSING =====
|
| 465 |
+
yield f"### 🔄 Applying META SAM-Audio bird call separation...\n\n**LLM:** {status}"
|
| 466 |
+
|
| 467 |
+
bird_audio, sam_metadata = sam_audio.separate_bird_calls(audio_data)
|
| 468 |
+
multi_sources = sam_audio.detect_multiple_birds(bird_audio)
|
| 469 |
|
| 470 |
+
sam_info = f"""**SAM-Audio Results:**
|
| 471 |
+
- Prompt: "bird call, bird song"
|
| 472 |
+
- Frequency isolation: {sam_metadata['bird_freq_range']}
|
| 473 |
+
- Separation quality: {sam_metadata['separation_ratio']*100:.0f}%
|
| 474 |
+
- Potential sources detected: {len(multi_sources)} frequency bands active
|
| 475 |
+
"""
|
| 476 |
+
yield f"### 🔄 SAM-Audio complete. Extracting features...\n\n{sam_info}\n\n**LLM:** {status}"
|
| 477 |
+
|
| 478 |
+
# Extract features from SAM-Audio processed audio
|
| 479 |
+
features = extract_features(bird_audio, sr, sam_metadata)
|
| 480 |
+
|
| 481 |
+
# Build LLM prompt
|
| 482 |
prompt = f"""Identify the bird(s) in this recording:
|
| 483 |
|
| 484 |
+
{features.to_prompt()}
|
| 485 |
+
|
| 486 |
"""
|
| 487 |
if location:
|
| 488 |
+
prompt += f"Location: {location}\n"
|
| 489 |
if month:
|
| 490 |
+
prompt += f"Month: {month}\n"
|
| 491 |
|
| 492 |
+
if len(multi_sources) > 1:
|
| 493 |
+
prompt += f"\nNote: SAM-Audio detected activity in {len(multi_sources)} frequency bands - likely multiple birds!\n"
|
| 494 |
|
| 495 |
+
prompt += "\nIdentify ALL birds (confidence >= 50%)."
|
| 496 |
|
| 497 |
+
yield f"### 🔄 Consulting LLM...\n\n{sam_info}\n\n**LLM:** {status}"
|
| 498 |
+
|
| 499 |
+
response = call_llm(prompt, SYSTEM_PROMPT)
|
| 500 |
|
| 501 |
if response:
|
| 502 |
+
result = format_results(response, features)
|
| 503 |
+
result += f"\n\n---\n\n### 📊 Audio Analysis\n{features.to_prompt()}\n\n**LLM:** {status}"
|
|
|
|
| 504 |
yield result
|
| 505 |
else:
|
| 506 |
+
yield f"""### ⚠️ LLM not available
|
| 507 |
|
| 508 |
+
{sam_info}
|
| 509 |
|
| 510 |
+
**Audio features detected:**
|
| 511 |
+
{features.to_prompt()}
|
| 512 |
|
| 513 |
+
**To fix (if using local):**
|
| 514 |
+
1. Start Ollama: `ollama serve`
|
| 515 |
+
2. Pull model: `ollama pull {OLLAMA_MODEL}`
|
|
|
|
| 516 |
"""
|
| 517 |
|
| 518 |
except Exception as e:
|
|
|
|
| 522 |
def identify_description(description: str):
|
| 523 |
"""Identify bird from description using LLM."""
|
| 524 |
if not description or len(description.strip()) < 5:
|
| 525 |
+
return "### ⚠️ Please enter a description"
|
| 526 |
|
| 527 |
status = get_llm_status()
|
| 528 |
+
yield f"### 🔄 Analyzing with LLM...\n\n**LLM:** {status}"
|
| 529 |
|
| 530 |
+
prompt = f"""Identify the bird(s) from this description:
|
| 531 |
|
| 532 |
{description}
|
| 533 |
|
| 534 |
+
Focus on Indian birds. List all matches with confidence >= 50%."""
|
| 535 |
|
| 536 |
+
response = call_llm(prompt, SYSTEM_PROMPT)
|
| 537 |
|
| 538 |
if response:
|
| 539 |
+
yield format_results(response) + f"\n\n**LLM:** {status}"
|
|
|
|
|
|
|
| 540 |
else:
|
| 541 |
+
yield f"### ⚠️ LLM not available\n\n**LLM:** {status}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
|
| 543 |
|
| 544 |
def identify_image(image):
|
| 545 |
"""Identify bird from image using LLM."""
|
| 546 |
if image is None:
|
| 547 |
+
return "### ⚠️ Please upload an image"
|
| 548 |
|
| 549 |
status = get_llm_status()
|
| 550 |
+
yield f"### 🔄 Analyzing image...\n\n**LLM:** {status}"
|
| 551 |
|
| 552 |
try:
|
| 553 |
+
img = np.array(image) if not isinstance(image, np.ndarray) else image
|
|
|
|
|
|
|
|
|
|
| 554 |
|
|
|
|
| 555 |
colors = []
|
| 556 |
if len(img.shape) == 3 and img.shape[2] >= 3:
|
| 557 |
r, g, b = np.mean(img[:,:,0]), np.mean(img[:,:,1]), np.mean(img[:,:,2])
|
|
|
|
| 564 |
|
| 565 |
color_desc = ", ".join(colors) if colors else "mixed"
|
| 566 |
|
| 567 |
+
prompt = f"""Identify the bird from image analysis:
|
|
|
|
|
|
|
| 568 |
|
| 569 |
+
Detected colors: {color_desc}
|
|
|
|
| 570 |
|
| 571 |
+
What Indian bird species match these colors? List all with confidence >= 50%."""
|
|
|
|
| 572 |
|
| 573 |
+
response = call_llm(prompt, SYSTEM_PROMPT)
|
| 574 |
|
| 575 |
if response:
|
| 576 |
+
result = f"**Detected colors:** {color_desc}\n\n"
|
| 577 |
+
result += format_results(response)
|
| 578 |
result += f"\n\n**LLM:** {status}"
|
| 579 |
yield result
|
| 580 |
else:
|
| 581 |
+
yield f"### ⚠️ LLM not available\n\n**Detected colors:** {color_desc}\n\n**LLM:** {status}"
|
| 582 |
|
| 583 |
except Exception as e:
|
| 584 |
yield f"### ❌ Error: {str(e)}"
|
|
|
|
| 586 |
|
| 587 |
# ================== GRADIO UI ==================
|
| 588 |
|
| 589 |
+
with gr.Blocks(title="🐦 BirdSense Pro") as demo:
|
| 590 |
|
| 591 |
gr.HTML("""
|
| 592 |
+
<div style="text-align: center; background: linear-gradient(135deg, #1a4d2e 0%, #2d5a3e 50%, #1a4d2e 100%); padding: 2rem; border-radius: 16px; margin-bottom: 1rem;">
|
| 593 |
<h1 style="color: #4ade80; font-size: 2.5rem; margin: 0;">🐦 BirdSense Pro</h1>
|
| 594 |
+
<p style="color: #94a3b8; font-size: 1.1rem;">META SAM-Audio + LLM Bird Identification</p>
|
| 595 |
+
<p style="color: #64748b; font-size: 0.9rem;">SAM-Audio preprocessing • 10,000+ species • Multi-bird detection</p>
|
|
|
|
|
|
|
| 596 |
</div>
|
| 597 |
""")
|
| 598 |
|
| 599 |
+
gr.Markdown(f"**LLM Status:** {get_llm_status()}")
|
|
|
|
|
|
|
| 600 |
|
| 601 |
with gr.Tabs():
|
| 602 |
+
with gr.Tab("🎤 Audio (SAM-Audio + LLM)"):
|
|
|
|
| 603 |
gr.Markdown("""
|
| 604 |
+
### How it works:
|
| 605 |
+
1. **META SAM-Audio** separates bird calls from noise (using "bird call" prompt)
|
| 606 |
+
2. **Features extracted** from isolated bird audio
|
| 607 |
+
3. **LLM identifies** all matching species (10,000+ known)
|
| 608 |
""")
|
| 609 |
|
| 610 |
with gr.Row():
|
| 611 |
with gr.Column(scale=1):
|
| 612 |
audio_in = gr.Audio(sources=["microphone", "upload"], type="numpy", label="🎤 Bird Audio")
|
| 613 |
with gr.Row():
|
| 614 |
+
loc = gr.Textbox(label="📍 Location", placeholder="e.g., Western Ghats")
|
| 615 |
+
month = gr.Dropdown(label="📅 Month", choices=[""] + [
|
| 616 |
+
"January", "February", "March", "April", "May", "June",
|
| 617 |
+
"July", "August", "September", "October", "November", "December"
|
| 618 |
+
])
|
| 619 |
+
audio_btn = gr.Button("🔍 Identify (SAM-Audio + LLM)", variant="primary", size="lg")
|
|
|
|
|
|
|
| 620 |
|
| 621 |
with gr.Column(scale=2):
|
| 622 |
audio_out = gr.Markdown()
|
| 623 |
|
| 624 |
+
audio_btn.click(identify_audio, [audio_in, loc, month], audio_out)
|
| 625 |
|
|
|
|
| 626 |
with gr.Tab("📝 Description"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 627 |
with gr.Row():
|
| 628 |
with gr.Column(scale=1):
|
| 629 |
+
desc_in = gr.Textbox(label="Bird Description", lines=4,
|
| 630 |
+
placeholder="Example: Small green bird with red forehead, making tuk-tuk sound")
|
| 631 |
+
desc_btn = gr.Button("🔍 Identify", variant="primary", size="lg")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 632 |
with gr.Column(scale=2):
|
| 633 |
desc_out = gr.Markdown()
|
|
|
|
| 634 |
desc_btn.click(identify_description, [desc_in], desc_out)
|
| 635 |
|
|
|
|
| 636 |
with gr.Tab("📷 Image"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 637 |
with gr.Row():
|
| 638 |
with gr.Column(scale=1):
|
| 639 |
img_in = gr.Image(sources=["upload", "webcam"], type="numpy", label="📷 Bird Image")
|
| 640 |
+
img_btn = gr.Button("🔍 Identify", variant="primary", size="lg")
|
|
|
|
| 641 |
with gr.Column(scale=2):
|
| 642 |
img_out = gr.Markdown()
|
|
|
|
| 643 |
img_btn.click(identify_image, [img_in], img_out)
|
| 644 |
|
| 645 |
+
with gr.Tab("ℹ️ SAM-Audio"):
|
| 646 |
+
gr.Markdown("""
|
| 647 |
+
## META SAM-Audio Integration
|
| 648 |
+
|
| 649 |
+
**SAM-Audio** (Segment Anything in Audio) by Meta AI uses text prompts to separate audio sources.
|
| 650 |
+
|
| 651 |
+
### How We Use It:
|
| 652 |
+
|
| 653 |
+
```
|
| 654 |
+
Raw Audio Recording
|
| 655 |
+
↓
|
| 656 |
+
SAM-Audio Preprocessing
|
| 657 |
+
- Prompt: "bird call, bird song"
|
| 658 |
+
- Isolates frequencies 500-10000 Hz
|
| 659 |
+
- Removes background noise
|
| 660 |
+
- Spectral gating
|
| 661 |
+
↓
|
| 662 |
+
Clean Bird Audio
|
| 663 |
+
↓
|
| 664 |
+
Feature Extraction
|
| 665 |
+
↓
|
| 666 |
+
LLM Identification
|
| 667 |
+
```
|
| 668 |
+
|
| 669 |
+
### SAM-Audio Prompts Used:
|
| 670 |
+
- `"bird call"` - General bird vocalizations
|
| 671 |
+
- `"bird song"` - Melodic bird sounds
|
| 672 |
+
- `"background noise"` - To remove (wind, traffic)
|
| 673 |
+
|
| 674 |
+
### Multi-Bird Detection:
|
| 675 |
+
SAM-Audio analyzes different frequency bands:
|
| 676 |
+
- **Low (500-2000 Hz):** Crows, cuckoos, coucals
|
| 677 |
+
- **Mid (2000-5000 Hz):** Most songbirds
|
| 678 |
+
- **High (5000-10000 Hz):** Sunbirds, warblers
|
| 679 |
+
|
| 680 |
+
### References:
|
| 681 |
+
- [META SAM-Audio Paper](https://ai.meta.com/research/publications/sam-audio-segment-anything-in-audio/)
|
| 682 |
+
- [SAM-Audio Demo](https://ai.meta.com/samaudio/)
|
| 683 |
+
- [HuggingFace Model](https://huggingface.co/facebook/sam-audio-large)
|
|
|
|
| 684 |
""")
|
| 685 |
|
| 686 |
gr.HTML("""
|
| 687 |
<div style="text-align: center; padding: 1rem; margin-top: 1rem; border-top: 1px solid #334155;">
|
| 688 |
+
<p style="color: #4ade80;">🐦 BirdSense Pro - CSCR Initiative</p>
|
| 689 |
+
<p style="color: #64748b;">META SAM-Audio + Ollama LLM</p>
|
|
|
|
|
|
|
| 690 |
</div>
|
| 691 |
""")
|
| 692 |
|
|
|
|
| 693 |
if __name__ == "__main__":
|
| 694 |
+
print(f"\n🐦 BirdSense Pro with META SAM-Audio")
|
| 695 |
+
print(f"LLM: {get_llm_status()}")
|
|
|
|
| 696 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
audio/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""BirdSense Audio Processing Module."""
|
| 2 |
+
|
| 3 |
+
from .preprocessor import AudioPreprocessor
|
| 4 |
+
from .encoder import AudioEncoder
|
| 5 |
+
from .augmentation import AudioAugmenter
|
| 6 |
+
|
| 7 |
+
__all__ = ["AudioPreprocessor", "AudioEncoder", "AudioAugmenter"]
|
| 8 |
+
|
audio/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (407 Bytes). View file
|
|
|
audio/__pycache__/augmentation.cpython-314.pyc
ADDED
|
Binary file (23.9 kB). View file
|
|
|
audio/__pycache__/encoder.cpython-314.pyc
ADDED
|
Binary file (23.5 kB). View file
|
|
|
audio/__pycache__/preprocessor.cpython-314.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
audio/__pycache__/sam_audio.cpython-314.pyc
ADDED
|
Binary file (22.2 kB). View file
|
|
|
audio/augmentation.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audio Augmentation for BirdSense.
|
| 3 |
+
|
| 4 |
+
Provides augmentation techniques to make the model robust to:
|
| 5 |
+
- Different noise conditions (urban, forest, rain, wind)
|
| 6 |
+
- Recording quality variations
|
| 7 |
+
- Distance/amplitude variations
|
| 8 |
+
- Pitch variations (natural variation in bird calls)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from typing import Optional, List, Tuple
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
import random
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class AugmentationConfig:
|
| 19 |
+
"""Configuration for audio augmentation."""
|
| 20 |
+
# Noise injection
|
| 21 |
+
add_noise: bool = True
|
| 22 |
+
noise_types: List[str] = None # 'gaussian', 'pink', 'urban', 'forest'
|
| 23 |
+
min_snr_db: float = 3.0
|
| 24 |
+
max_snr_db: float = 30.0
|
| 25 |
+
|
| 26 |
+
# Time stretching
|
| 27 |
+
time_stretch: bool = True
|
| 28 |
+
min_stretch_rate: float = 0.8
|
| 29 |
+
max_stretch_rate: float = 1.2
|
| 30 |
+
|
| 31 |
+
# Pitch shifting
|
| 32 |
+
pitch_shift: bool = True
|
| 33 |
+
min_semitones: float = -2.0
|
| 34 |
+
max_semitones: float = 2.0
|
| 35 |
+
|
| 36 |
+
# Amplitude variation
|
| 37 |
+
amplitude_variation: bool = True
|
| 38 |
+
min_gain_db: float = -12.0
|
| 39 |
+
max_gain_db: float = 6.0
|
| 40 |
+
|
| 41 |
+
# Time masking (simulate brief interruptions)
|
| 42 |
+
time_mask: bool = True
|
| 43 |
+
max_mask_ratio: float = 0.1
|
| 44 |
+
|
| 45 |
+
# Frequency masking (simulate frequency-specific noise)
|
| 46 |
+
freq_mask: bool = True
|
| 47 |
+
max_freq_mask_bins: int = 20
|
| 48 |
+
|
| 49 |
+
def __post_init__(self):
|
| 50 |
+
if self.noise_types is None:
|
| 51 |
+
self.noise_types = ['gaussian', 'pink', 'urban', 'forest']
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class AudioAugmenter:
|
| 55 |
+
"""
|
| 56 |
+
Audio augmentation pipeline for training robust bird classifiers.
|
| 57 |
+
|
| 58 |
+
Simulates real-world recording conditions including:
|
| 59 |
+
- Environmental noise (traffic, wind, rain, other birds)
|
| 60 |
+
- Recording equipment variations
|
| 61 |
+
- Distance variations (feeble vs. close recordings)
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(self, config: Optional[AugmentationConfig] = None, seed: Optional[int] = None):
|
| 65 |
+
self.config = config or AugmentationConfig()
|
| 66 |
+
if seed is not None:
|
| 67 |
+
random.seed(seed)
|
| 68 |
+
np.random.seed(seed)
|
| 69 |
+
|
| 70 |
+
def add_gaussian_noise(
|
| 71 |
+
self,
|
| 72 |
+
audio: np.ndarray,
|
| 73 |
+
snr_db: float
|
| 74 |
+
) -> np.ndarray:
|
| 75 |
+
"""Add Gaussian white noise at specified SNR."""
|
| 76 |
+
signal_power = np.mean(audio ** 2)
|
| 77 |
+
noise_power = signal_power / (10 ** (snr_db / 10))
|
| 78 |
+
noise = np.random.normal(0, np.sqrt(noise_power), len(audio))
|
| 79 |
+
return (audio + noise).astype(np.float32)
|
| 80 |
+
|
| 81 |
+
def add_pink_noise(
|
| 82 |
+
self,
|
| 83 |
+
audio: np.ndarray,
|
| 84 |
+
snr_db: float
|
| 85 |
+
) -> np.ndarray:
|
| 86 |
+
"""
|
| 87 |
+
Add pink (1/f) noise - more natural sounding than white noise.
|
| 88 |
+
Common in environmental recordings.
|
| 89 |
+
"""
|
| 90 |
+
n_samples = len(audio)
|
| 91 |
+
|
| 92 |
+
# Generate pink noise using spectral shaping
|
| 93 |
+
# Generate white noise
|
| 94 |
+
white = np.random.randn(n_samples)
|
| 95 |
+
|
| 96 |
+
# Apply 1/f filter in frequency domain
|
| 97 |
+
fft = np.fft.rfft(white)
|
| 98 |
+
freqs = np.fft.rfftfreq(n_samples)
|
| 99 |
+
freqs[0] = 1e-10 # Avoid division by zero
|
| 100 |
+
|
| 101 |
+
# Pink noise has 1/f power spectrum, so 1/sqrt(f) amplitude
|
| 102 |
+
pink_filter = 1.0 / np.sqrt(freqs + 1e-10)
|
| 103 |
+
pink_filter = pink_filter / np.max(pink_filter)
|
| 104 |
+
|
| 105 |
+
pink = np.fft.irfft(fft * pink_filter, n=n_samples)
|
| 106 |
+
pink = pink - np.mean(pink)
|
| 107 |
+
pink = pink / (np.max(np.abs(pink)) + 1e-8)
|
| 108 |
+
|
| 109 |
+
# Scale to desired SNR
|
| 110 |
+
signal_power = np.mean(audio ** 2)
|
| 111 |
+
noise_power = signal_power / (10 ** (snr_db / 10))
|
| 112 |
+
pink = pink * np.sqrt(noise_power)
|
| 113 |
+
|
| 114 |
+
return (audio + pink).astype(np.float32)
|
| 115 |
+
|
| 116 |
+
def add_urban_noise(
|
| 117 |
+
self,
|
| 118 |
+
audio: np.ndarray,
|
| 119 |
+
sr: int,
|
| 120 |
+
snr_db: float
|
| 121 |
+
) -> np.ndarray:
|
| 122 |
+
"""
|
| 123 |
+
Simulate urban noise (low-frequency rumble + occasional spikes).
|
| 124 |
+
Models traffic, construction, and city ambience.
|
| 125 |
+
"""
|
| 126 |
+
n_samples = len(audio)
|
| 127 |
+
|
| 128 |
+
# Low-frequency rumble (traffic)
|
| 129 |
+
t = np.arange(n_samples) / sr
|
| 130 |
+
rumble = np.sin(2 * np.pi * 50 * t) * 0.5 + np.sin(2 * np.pi * 100 * t) * 0.3
|
| 131 |
+
rumble += np.random.randn(n_samples) * 0.2
|
| 132 |
+
|
| 133 |
+
# Occasional impulses (cars passing, doors)
|
| 134 |
+
n_impulses = random.randint(2, 8)
|
| 135 |
+
for _ in range(n_impulses):
|
| 136 |
+
pos = random.randint(0, n_samples - 100)
|
| 137 |
+
impulse_len = random.randint(50, 200)
|
| 138 |
+
decay = np.exp(-np.arange(impulse_len) / (impulse_len * 0.3))
|
| 139 |
+
impulse = np.random.randn(impulse_len) * decay
|
| 140 |
+
rumble[pos:pos + impulse_len] += impulse * random.uniform(0.5, 2.0)
|
| 141 |
+
|
| 142 |
+
rumble = rumble / (np.max(np.abs(rumble)) + 1e-8)
|
| 143 |
+
|
| 144 |
+
# Scale to desired SNR
|
| 145 |
+
signal_power = np.mean(audio ** 2)
|
| 146 |
+
noise_power = signal_power / (10 ** (snr_db / 10))
|
| 147 |
+
rumble = rumble * np.sqrt(noise_power)
|
| 148 |
+
|
| 149 |
+
return (audio + rumble).astype(np.float32)
|
| 150 |
+
|
| 151 |
+
def add_forest_noise(
|
| 152 |
+
self,
|
| 153 |
+
audio: np.ndarray,
|
| 154 |
+
sr: int,
|
| 155 |
+
snr_db: float
|
| 156 |
+
) -> np.ndarray:
|
| 157 |
+
"""
|
| 158 |
+
Simulate forest ambient noise (insects, wind in leaves, water).
|
| 159 |
+
"""
|
| 160 |
+
n_samples = len(audio)
|
| 161 |
+
|
| 162 |
+
# Base: filtered noise (wind through leaves)
|
| 163 |
+
wind = np.random.randn(n_samples)
|
| 164 |
+
# Apply bandpass to simulate rustling (200-4000 Hz)
|
| 165 |
+
from scipy import signal as sig
|
| 166 |
+
nyquist = sr / 2
|
| 167 |
+
b, a = sig.butter(2, [200 / nyquist, 4000 / nyquist], btype='band')
|
| 168 |
+
wind = sig.filtfilt(b, a, wind)
|
| 169 |
+
|
| 170 |
+
# Add modulation (gusts)
|
| 171 |
+
t = np.arange(n_samples) / sr
|
| 172 |
+
modulation = 0.5 + 0.5 * np.sin(2 * np.pi * 0.1 * t + random.random() * 2 * np.pi)
|
| 173 |
+
wind = wind * modulation
|
| 174 |
+
|
| 175 |
+
# Add some insect-like chirps (high frequency components)
|
| 176 |
+
insect_freq = random.uniform(4000, 8000)
|
| 177 |
+
insect = np.sin(2 * np.pi * insect_freq * t) * 0.1
|
| 178 |
+
insect_modulation = np.random.rand(n_samples) > 0.7
|
| 179 |
+
insect = insect * insect_modulation.astype(float)
|
| 180 |
+
|
| 181 |
+
forest = wind * 0.8 + insect * 0.2
|
| 182 |
+
forest = forest / (np.max(np.abs(forest)) + 1e-8)
|
| 183 |
+
|
| 184 |
+
# Scale to desired SNR
|
| 185 |
+
signal_power = np.mean(audio ** 2)
|
| 186 |
+
noise_power = signal_power / (10 ** (snr_db / 10))
|
| 187 |
+
forest = forest * np.sqrt(noise_power)
|
| 188 |
+
|
| 189 |
+
return (audio + forest).astype(np.float32)
|
| 190 |
+
|
| 191 |
+
def add_noise(
|
| 192 |
+
self,
|
| 193 |
+
audio: np.ndarray,
|
| 194 |
+
sr: int,
|
| 195 |
+
noise_type: Optional[str] = None,
|
| 196 |
+
snr_db: Optional[float] = None
|
| 197 |
+
) -> np.ndarray:
|
| 198 |
+
"""
|
| 199 |
+
Add noise of specified type at random SNR.
|
| 200 |
+
"""
|
| 201 |
+
if snr_db is None:
|
| 202 |
+
snr_db = random.uniform(self.config.min_snr_db, self.config.max_snr_db)
|
| 203 |
+
|
| 204 |
+
if noise_type is None:
|
| 205 |
+
noise_type = random.choice(self.config.noise_types)
|
| 206 |
+
|
| 207 |
+
if noise_type == 'gaussian':
|
| 208 |
+
return self.add_gaussian_noise(audio, snr_db)
|
| 209 |
+
elif noise_type == 'pink':
|
| 210 |
+
return self.add_pink_noise(audio, snr_db)
|
| 211 |
+
elif noise_type == 'urban':
|
| 212 |
+
return self.add_urban_noise(audio, sr, snr_db)
|
| 213 |
+
elif noise_type == 'forest':
|
| 214 |
+
return self.add_forest_noise(audio, sr, snr_db)
|
| 215 |
+
else:
|
| 216 |
+
return self.add_gaussian_noise(audio, snr_db)
|
| 217 |
+
|
| 218 |
+
def time_stretch(
|
| 219 |
+
self,
|
| 220 |
+
audio: np.ndarray,
|
| 221 |
+
rate: Optional[float] = None
|
| 222 |
+
) -> np.ndarray:
|
| 223 |
+
"""
|
| 224 |
+
Time-stretch audio without changing pitch.
|
| 225 |
+
Uses simple resampling for efficiency.
|
| 226 |
+
"""
|
| 227 |
+
if rate is None:
|
| 228 |
+
rate = random.uniform(
|
| 229 |
+
self.config.min_stretch_rate,
|
| 230 |
+
self.config.max_stretch_rate
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# Simple linear interpolation stretching
|
| 234 |
+
original_len = len(audio)
|
| 235 |
+
new_len = int(original_len / rate)
|
| 236 |
+
|
| 237 |
+
x_old = np.linspace(0, 1, original_len)
|
| 238 |
+
x_new = np.linspace(0, 1, new_len)
|
| 239 |
+
|
| 240 |
+
stretched = np.interp(x_new, x_old, audio)
|
| 241 |
+
|
| 242 |
+
# Adjust to original length
|
| 243 |
+
if len(stretched) > original_len:
|
| 244 |
+
stretched = stretched[:original_len]
|
| 245 |
+
elif len(stretched) < original_len:
|
| 246 |
+
stretched = np.pad(stretched, (0, original_len - len(stretched)), mode='constant')
|
| 247 |
+
|
| 248 |
+
return stretched.astype(np.float32)
|
| 249 |
+
|
| 250 |
+
def pitch_shift(
|
| 251 |
+
self,
|
| 252 |
+
audio: np.ndarray,
|
| 253 |
+
sr: int,
|
| 254 |
+
semitones: Optional[float] = None
|
| 255 |
+
) -> np.ndarray:
|
| 256 |
+
"""
|
| 257 |
+
Shift pitch by specified semitones.
|
| 258 |
+
Simplified implementation using resampling.
|
| 259 |
+
"""
|
| 260 |
+
if semitones is None:
|
| 261 |
+
semitones = random.uniform(
|
| 262 |
+
self.config.min_semitones,
|
| 263 |
+
self.config.max_semitones
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# Pitch shift factor
|
| 267 |
+
factor = 2 ** (semitones / 12.0)
|
| 268 |
+
|
| 269 |
+
# Resample then time-stretch back
|
| 270 |
+
original_len = len(audio)
|
| 271 |
+
new_len = int(original_len / factor)
|
| 272 |
+
|
| 273 |
+
# First resample to change pitch
|
| 274 |
+
x_old = np.linspace(0, 1, original_len)
|
| 275 |
+
x_new = np.linspace(0, 1, new_len)
|
| 276 |
+
resampled = np.interp(x_new, x_old, audio)
|
| 277 |
+
|
| 278 |
+
# Then stretch back to original length
|
| 279 |
+
x_stretch = np.linspace(0, 1, len(resampled))
|
| 280 |
+
x_target = np.linspace(0, 1, original_len)
|
| 281 |
+
shifted = np.interp(x_target, x_stretch, resampled)
|
| 282 |
+
|
| 283 |
+
return shifted.astype(np.float32)
|
| 284 |
+
|
| 285 |
+
def apply_gain(
|
| 286 |
+
self,
|
| 287 |
+
audio: np.ndarray,
|
| 288 |
+
gain_db: Optional[float] = None
|
| 289 |
+
) -> np.ndarray:
|
| 290 |
+
"""
|
| 291 |
+
Apply gain to simulate distance/recording level variations.
|
| 292 |
+
"""
|
| 293 |
+
if gain_db is None:
|
| 294 |
+
gain_db = random.uniform(
|
| 295 |
+
self.config.min_gain_db,
|
| 296 |
+
self.config.max_gain_db
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
gain_linear = 10 ** (gain_db / 20)
|
| 300 |
+
audio = audio * gain_linear
|
| 301 |
+
|
| 302 |
+
# Soft clip to avoid harsh distortion
|
| 303 |
+
return np.tanh(audio).astype(np.float32)
|
| 304 |
+
|
| 305 |
+
def time_mask(
|
| 306 |
+
self,
|
| 307 |
+
spectrogram: np.ndarray
|
| 308 |
+
) -> np.ndarray:
|
| 309 |
+
"""
|
| 310 |
+
Apply time masking to spectrogram (SpecAugment technique).
|
| 311 |
+
"""
|
| 312 |
+
n_mels, n_frames = spectrogram.shape
|
| 313 |
+
max_mask_width = int(n_frames * self.config.max_mask_ratio)
|
| 314 |
+
|
| 315 |
+
if max_mask_width < 2:
|
| 316 |
+
return spectrogram
|
| 317 |
+
|
| 318 |
+
mask_width = random.randint(1, max_mask_width)
|
| 319 |
+
mask_start = random.randint(0, n_frames - mask_width)
|
| 320 |
+
|
| 321 |
+
masked = spectrogram.copy()
|
| 322 |
+
masked[:, mask_start:mask_start + mask_width] = 0
|
| 323 |
+
|
| 324 |
+
return masked
|
| 325 |
+
|
| 326 |
+
def freq_mask(
|
| 327 |
+
self,
|
| 328 |
+
spectrogram: np.ndarray
|
| 329 |
+
) -> np.ndarray:
|
| 330 |
+
"""
|
| 331 |
+
Apply frequency masking to spectrogram (SpecAugment technique).
|
| 332 |
+
"""
|
| 333 |
+
n_mels, n_frames = spectrogram.shape
|
| 334 |
+
max_mask_bins = min(self.config.max_freq_mask_bins, n_mels // 4)
|
| 335 |
+
|
| 336 |
+
if max_mask_bins < 2:
|
| 337 |
+
return spectrogram
|
| 338 |
+
|
| 339 |
+
mask_bins = random.randint(1, max_mask_bins)
|
| 340 |
+
mask_start = random.randint(0, n_mels - mask_bins)
|
| 341 |
+
|
| 342 |
+
masked = spectrogram.copy()
|
| 343 |
+
masked[mask_start:mask_start + mask_bins, :] = 0
|
| 344 |
+
|
| 345 |
+
return masked
|
| 346 |
+
|
| 347 |
+
def augment_audio(
|
| 348 |
+
self,
|
| 349 |
+
audio: np.ndarray,
|
| 350 |
+
sr: int,
|
| 351 |
+
augmentations: Optional[List[str]] = None
|
| 352 |
+
) -> np.ndarray:
|
| 353 |
+
"""
|
| 354 |
+
Apply a random subset of augmentations to audio.
|
| 355 |
+
|
| 356 |
+
Args:
|
| 357 |
+
audio: Input audio waveform
|
| 358 |
+
sr: Sample rate
|
| 359 |
+
augmentations: List of augmentations to apply, or None for random
|
| 360 |
+
|
| 361 |
+
Returns:
|
| 362 |
+
Augmented audio
|
| 363 |
+
"""
|
| 364 |
+
if augmentations is None:
|
| 365 |
+
# Randomly select augmentations
|
| 366 |
+
augmentations = []
|
| 367 |
+
if self.config.add_noise and random.random() < 0.7:
|
| 368 |
+
augmentations.append('noise')
|
| 369 |
+
if self.config.time_stretch and random.random() < 0.3:
|
| 370 |
+
augmentations.append('time_stretch')
|
| 371 |
+
if self.config.pitch_shift and random.random() < 0.3:
|
| 372 |
+
augmentations.append('pitch_shift')
|
| 373 |
+
if self.config.amplitude_variation and random.random() < 0.5:
|
| 374 |
+
augmentations.append('gain')
|
| 375 |
+
|
| 376 |
+
augmented = audio.copy()
|
| 377 |
+
|
| 378 |
+
for aug in augmentations:
|
| 379 |
+
if aug == 'noise':
|
| 380 |
+
augmented = self.add_noise(augmented, sr)
|
| 381 |
+
elif aug == 'time_stretch':
|
| 382 |
+
augmented = self.time_stretch(augmented)
|
| 383 |
+
elif aug == 'pitch_shift':
|
| 384 |
+
augmented = self.pitch_shift(augmented, sr)
|
| 385 |
+
elif aug == 'gain':
|
| 386 |
+
augmented = self.apply_gain(augmented)
|
| 387 |
+
|
| 388 |
+
return augmented
|
| 389 |
+
|
| 390 |
+
def augment_spectrogram(
|
| 391 |
+
self,
|
| 392 |
+
spectrogram: np.ndarray
|
| 393 |
+
) -> np.ndarray:
|
| 394 |
+
"""
|
| 395 |
+
Apply SpecAugment-style augmentations to mel-spectrogram.
|
| 396 |
+
"""
|
| 397 |
+
augmented = spectrogram.copy()
|
| 398 |
+
|
| 399 |
+
if self.config.time_mask and random.random() < 0.5:
|
| 400 |
+
augmented = self.time_mask(augmented)
|
| 401 |
+
|
| 402 |
+
if self.config.freq_mask and random.random() < 0.5:
|
| 403 |
+
augmented = self.freq_mask(augmented)
|
| 404 |
+
|
| 405 |
+
return augmented
|
| 406 |
+
|
| 407 |
+
def create_challenging_sample(
|
| 408 |
+
self,
|
| 409 |
+
audio: np.ndarray,
|
| 410 |
+
sr: int,
|
| 411 |
+
challenge_type: str
|
| 412 |
+
) -> Tuple[np.ndarray, dict]:
|
| 413 |
+
"""
|
| 414 |
+
Create specifically challenging audio samples for testing.
|
| 415 |
+
|
| 416 |
+
Args:
|
| 417 |
+
audio: Clean audio sample
|
| 418 |
+
sr: Sample rate
|
| 419 |
+
challenge_type: One of 'feeble', 'noisy', 'multi_source', 'brief'
|
| 420 |
+
|
| 421 |
+
Returns:
|
| 422 |
+
Tuple of (augmented_audio, metadata)
|
| 423 |
+
"""
|
| 424 |
+
metadata = {"challenge_type": challenge_type}
|
| 425 |
+
|
| 426 |
+
if challenge_type == 'feeble':
|
| 427 |
+
# Simulate distant/quiet recording
|
| 428 |
+
gain_db = random.uniform(-20, -10)
|
| 429 |
+
audio = self.apply_gain(audio, gain_db)
|
| 430 |
+
audio = self.add_noise(audio, sr, 'pink', snr_db=random.uniform(5, 10))
|
| 431 |
+
metadata['gain_db'] = gain_db
|
| 432 |
+
|
| 433 |
+
elif challenge_type == 'noisy':
|
| 434 |
+
# Heavy noise contamination
|
| 435 |
+
noise_type = random.choice(['urban', 'forest'])
|
| 436 |
+
snr_db = random.uniform(0, 5)
|
| 437 |
+
audio = self.add_noise(audio, sr, noise_type, snr_db)
|
| 438 |
+
metadata['noise_type'] = noise_type
|
| 439 |
+
metadata['snr_db'] = snr_db
|
| 440 |
+
|
| 441 |
+
elif challenge_type == 'multi_source':
|
| 442 |
+
# Simulate multiple overlapping sounds (mix with shifted copy)
|
| 443 |
+
shifted = self.pitch_shift(audio, sr, random.uniform(-3, 3))
|
| 444 |
+
delay_samples = random.randint(0, len(audio) // 4)
|
| 445 |
+
delayed = np.roll(shifted, delay_samples)
|
| 446 |
+
audio = audio * 0.7 + delayed * 0.5
|
| 447 |
+
audio = self.add_noise(audio, sr, snr_db=random.uniform(10, 20))
|
| 448 |
+
metadata['n_sources'] = 2
|
| 449 |
+
|
| 450 |
+
elif challenge_type == 'brief':
|
| 451 |
+
# Very short call with silence padding
|
| 452 |
+
call_duration = random.uniform(0.3, 1.0)
|
| 453 |
+
call_samples = int(call_duration * sr)
|
| 454 |
+
if call_samples < len(audio):
|
| 455 |
+
start = random.randint(0, len(audio) - call_samples)
|
| 456 |
+
brief = np.zeros_like(audio)
|
| 457 |
+
insert_pos = random.randint(0, len(audio) - call_samples)
|
| 458 |
+
brief[insert_pos:insert_pos + call_samples] = audio[start:start + call_samples]
|
| 459 |
+
audio = brief
|
| 460 |
+
audio = self.add_noise(audio, sr, snr_db=random.uniform(10, 20))
|
| 461 |
+
metadata['call_duration'] = call_duration
|
| 462 |
+
|
| 463 |
+
return audio.astype(np.float32), metadata
|
| 464 |
+
|
audio/encoder.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Lightweight Audio Encoder for BirdSense.
|
| 3 |
+
|
| 4 |
+
Implements a small, efficient audio encoder optimized for bird call recognition.
|
| 5 |
+
Designed for edge deployment while maintaining competitive accuracy.
|
| 6 |
+
|
| 7 |
+
Architecture options:
|
| 8 |
+
1. AST-Tiny: Audio Spectrogram Transformer (small variant)
|
| 9 |
+
2. EfficientNet-B0: Adapted for spectrograms
|
| 10 |
+
3. MobileViT: Vision transformer for mobile
|
| 11 |
+
4. Custom CNN: Lightweight convolutional network
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from typing import Optional, Tuple
|
| 18 |
+
import math
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ConvBlock(nn.Module):
|
| 22 |
+
"""Convolutional block with batch norm and activation."""
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
in_channels: int,
|
| 27 |
+
out_channels: int,
|
| 28 |
+
kernel_size: int = 3,
|
| 29 |
+
stride: int = 1,
|
| 30 |
+
padding: int = 1
|
| 31 |
+
):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.conv = nn.Conv2d(
|
| 34 |
+
in_channels, out_channels,
|
| 35 |
+
kernel_size, stride, padding,
|
| 36 |
+
bias=False
|
| 37 |
+
)
|
| 38 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
| 39 |
+
self.act = nn.SiLU(inplace=True) # Swish activation
|
| 40 |
+
|
| 41 |
+
def forward(self, x):
|
| 42 |
+
return self.act(self.bn(self.conv(x)))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SqueezeExcitation(nn.Module):
|
| 46 |
+
"""Squeeze-and-Excitation attention block."""
|
| 47 |
+
|
| 48 |
+
def __init__(self, channels: int, reduction: int = 4):
|
| 49 |
+
super().__init__()
|
| 50 |
+
reduced = max(1, channels // reduction)
|
| 51 |
+
self.fc1 = nn.Conv2d(channels, reduced, 1)
|
| 52 |
+
self.fc2 = nn.Conv2d(reduced, channels, 1)
|
| 53 |
+
|
| 54 |
+
def forward(self, x):
|
| 55 |
+
scale = F.adaptive_avg_pool2d(x, 1)
|
| 56 |
+
scale = F.silu(self.fc1(scale))
|
| 57 |
+
scale = torch.sigmoid(self.fc2(scale))
|
| 58 |
+
return x * scale
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class MBConv(nn.Module):
|
| 62 |
+
"""Mobile Inverted Bottleneck Conv (from EfficientNet)."""
|
| 63 |
+
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
in_channels: int,
|
| 67 |
+
out_channels: int,
|
| 68 |
+
expand_ratio: int = 4,
|
| 69 |
+
stride: int = 1,
|
| 70 |
+
se_ratio: float = 0.25
|
| 71 |
+
):
|
| 72 |
+
super().__init__()
|
| 73 |
+
self.stride = stride
|
| 74 |
+
self.use_residual = stride == 1 and in_channels == out_channels
|
| 75 |
+
|
| 76 |
+
hidden_dim = in_channels * expand_ratio
|
| 77 |
+
|
| 78 |
+
layers = []
|
| 79 |
+
|
| 80 |
+
# Expansion
|
| 81 |
+
if expand_ratio != 1:
|
| 82 |
+
layers.extend([
|
| 83 |
+
nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
|
| 84 |
+
nn.BatchNorm2d(hidden_dim),
|
| 85 |
+
nn.SiLU(inplace=True)
|
| 86 |
+
])
|
| 87 |
+
|
| 88 |
+
# Depthwise conv
|
| 89 |
+
layers.extend([
|
| 90 |
+
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
| 91 |
+
nn.BatchNorm2d(hidden_dim),
|
| 92 |
+
nn.SiLU(inplace=True)
|
| 93 |
+
])
|
| 94 |
+
|
| 95 |
+
# Squeeze-and-Excitation
|
| 96 |
+
if se_ratio > 0:
|
| 97 |
+
layers.append(SqueezeExcitation(hidden_dim, int(1 / se_ratio)))
|
| 98 |
+
|
| 99 |
+
# Projection
|
| 100 |
+
layers.extend([
|
| 101 |
+
nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
|
| 102 |
+
nn.BatchNorm2d(out_channels)
|
| 103 |
+
])
|
| 104 |
+
|
| 105 |
+
self.conv = nn.Sequential(*layers)
|
| 106 |
+
|
| 107 |
+
def forward(self, x):
|
| 108 |
+
if self.use_residual:
|
| 109 |
+
return x + self.conv(x)
|
| 110 |
+
return self.conv(x)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class BirdAudioEncoder(nn.Module):
|
| 114 |
+
"""
|
| 115 |
+
Lightweight audio encoder for bird sound recognition.
|
| 116 |
+
|
| 117 |
+
Takes mel-spectrogram input and produces embeddings.
|
| 118 |
+
Designed for efficiency while maintaining good accuracy.
|
| 119 |
+
|
| 120 |
+
Architecture: Custom efficient CNN inspired by EfficientNet-B0/MobileNetV3
|
| 121 |
+
Parameters: ~2M (very lightweight)
|
| 122 |
+
Input: Mel-spectrogram (1, n_mels, n_frames)
|
| 123 |
+
Output: Embedding vector (embedding_dim,)
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
def __init__(
|
| 127 |
+
self,
|
| 128 |
+
n_mels: int = 128,
|
| 129 |
+
embedding_dim: int = 384,
|
| 130 |
+
width_multiplier: float = 1.0
|
| 131 |
+
):
|
| 132 |
+
super().__init__()
|
| 133 |
+
|
| 134 |
+
self.n_mels = n_mels
|
| 135 |
+
self.embedding_dim = embedding_dim
|
| 136 |
+
|
| 137 |
+
def _make_divisible(v):
|
| 138 |
+
"""Round to nearest multiple of 8."""
|
| 139 |
+
new_v = max(8, int(v * width_multiplier + 4) // 8 * 8)
|
| 140 |
+
if new_v < 0.9 * v * width_multiplier:
|
| 141 |
+
new_v += 8
|
| 142 |
+
return new_v
|
| 143 |
+
|
| 144 |
+
# Stem
|
| 145 |
+
self.stem = ConvBlock(1, _make_divisible(32), 3, 2, 1)
|
| 146 |
+
|
| 147 |
+
# Main blocks
|
| 148 |
+
self.blocks = nn.Sequential(
|
| 149 |
+
# Stage 1
|
| 150 |
+
MBConv(_make_divisible(32), _make_divisible(16), expand_ratio=1, stride=1),
|
| 151 |
+
|
| 152 |
+
# Stage 2
|
| 153 |
+
MBConv(_make_divisible(16), _make_divisible(24), expand_ratio=4, stride=2),
|
| 154 |
+
MBConv(_make_divisible(24), _make_divisible(24), expand_ratio=4, stride=1),
|
| 155 |
+
|
| 156 |
+
# Stage 3
|
| 157 |
+
MBConv(_make_divisible(24), _make_divisible(40), expand_ratio=4, stride=2),
|
| 158 |
+
MBConv(_make_divisible(40), _make_divisible(40), expand_ratio=4, stride=1),
|
| 159 |
+
|
| 160 |
+
# Stage 4
|
| 161 |
+
MBConv(_make_divisible(40), _make_divisible(80), expand_ratio=4, stride=2),
|
| 162 |
+
MBConv(_make_divisible(80), _make_divisible(80), expand_ratio=4, stride=1),
|
| 163 |
+
MBConv(_make_divisible(80), _make_divisible(80), expand_ratio=4, stride=1),
|
| 164 |
+
|
| 165 |
+
# Stage 5
|
| 166 |
+
MBConv(_make_divisible(80), _make_divisible(112), expand_ratio=4, stride=1),
|
| 167 |
+
MBConv(_make_divisible(112), _make_divisible(112), expand_ratio=4, stride=1),
|
| 168 |
+
|
| 169 |
+
# Stage 6
|
| 170 |
+
MBConv(_make_divisible(112), _make_divisible(192), expand_ratio=4, stride=2),
|
| 171 |
+
MBConv(_make_divisible(192), _make_divisible(192), expand_ratio=4, stride=1),
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
# Head
|
| 175 |
+
self.head = nn.Sequential(
|
| 176 |
+
ConvBlock(_make_divisible(192), _make_divisible(320), 1, 1, 0),
|
| 177 |
+
nn.AdaptiveAvgPool2d(1),
|
| 178 |
+
nn.Flatten(),
|
| 179 |
+
nn.Linear(_make_divisible(320), embedding_dim),
|
| 180 |
+
nn.LayerNorm(embedding_dim)
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
# Initialize weights
|
| 184 |
+
self._init_weights()
|
| 185 |
+
|
| 186 |
+
def _init_weights(self):
|
| 187 |
+
for m in self.modules():
|
| 188 |
+
if isinstance(m, nn.Conv2d):
|
| 189 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 190 |
+
if m.bias is not None:
|
| 191 |
+
nn.init.zeros_(m.bias)
|
| 192 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 193 |
+
nn.init.ones_(m.weight)
|
| 194 |
+
nn.init.zeros_(m.bias)
|
| 195 |
+
elif isinstance(m, nn.Linear):
|
| 196 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 197 |
+
if m.bias is not None:
|
| 198 |
+
nn.init.zeros_(m.bias)
|
| 199 |
+
|
| 200 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 201 |
+
"""
|
| 202 |
+
Forward pass.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
x: Mel-spectrogram tensor of shape (batch, n_mels, n_frames)
|
| 206 |
+
or (batch, 1, n_mels, n_frames)
|
| 207 |
+
|
| 208 |
+
Returns:
|
| 209 |
+
Embedding tensor of shape (batch, embedding_dim)
|
| 210 |
+
"""
|
| 211 |
+
# Add channel dimension if needed
|
| 212 |
+
if x.dim() == 3:
|
| 213 |
+
x = x.unsqueeze(1) # (B, 1, n_mels, n_frames)
|
| 214 |
+
|
| 215 |
+
x = self.stem(x)
|
| 216 |
+
x = self.blocks(x)
|
| 217 |
+
x = self.head(x)
|
| 218 |
+
|
| 219 |
+
return x
|
| 220 |
+
|
| 221 |
+
def get_embedding_dim(self) -> int:
|
| 222 |
+
return self.embedding_dim
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class PositionalEncoding(nn.Module):
|
| 226 |
+
"""Sinusoidal positional encoding for transformer."""
|
| 227 |
+
|
| 228 |
+
def __init__(self, d_model: int, max_len: int = 5000):
|
| 229 |
+
super().__init__()
|
| 230 |
+
|
| 231 |
+
pe = torch.zeros(max_len, d_model)
|
| 232 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 233 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
| 234 |
+
|
| 235 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 236 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 237 |
+
|
| 238 |
+
pe = pe.unsqueeze(0) # (1, max_len, d_model)
|
| 239 |
+
self.register_buffer('pe', pe)
|
| 240 |
+
|
| 241 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 242 |
+
return x + self.pe[:, :x.size(1)]
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class PatchEmbed(nn.Module):
|
| 246 |
+
"""Convert spectrogram to patch embeddings."""
|
| 247 |
+
|
| 248 |
+
def __init__(
|
| 249 |
+
self,
|
| 250 |
+
img_size: Tuple[int, int] = (128, 500),
|
| 251 |
+
patch_size: Tuple[int, int] = (16, 16),
|
| 252 |
+
in_channels: int = 1,
|
| 253 |
+
embed_dim: int = 384
|
| 254 |
+
):
|
| 255 |
+
super().__init__()
|
| 256 |
+
self.img_size = img_size
|
| 257 |
+
self.patch_size = patch_size
|
| 258 |
+
self.n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
|
| 259 |
+
|
| 260 |
+
self.proj = nn.Conv2d(
|
| 261 |
+
in_channels, embed_dim,
|
| 262 |
+
kernel_size=patch_size,
|
| 263 |
+
stride=patch_size
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 267 |
+
x = self.proj(x) # (B, embed_dim, H', W')
|
| 268 |
+
x = x.flatten(2) # (B, embed_dim, n_patches)
|
| 269 |
+
x = x.transpose(1, 2) # (B, n_patches, embed_dim)
|
| 270 |
+
return x
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class AudioTransformerEncoder(nn.Module):
|
| 274 |
+
"""
|
| 275 |
+
Small Audio Spectrogram Transformer (AST) variant.
|
| 276 |
+
|
| 277 |
+
Inspired by the original AST but significantly smaller for edge deployment.
|
| 278 |
+
Parameters: ~8M (still lightweight)
|
| 279 |
+
"""
|
| 280 |
+
|
| 281 |
+
def __init__(
|
| 282 |
+
self,
|
| 283 |
+
n_mels: int = 128,
|
| 284 |
+
max_frames: int = 500,
|
| 285 |
+
patch_size: Tuple[int, int] = (16, 16),
|
| 286 |
+
embed_dim: int = 384,
|
| 287 |
+
depth: int = 6,
|
| 288 |
+
num_heads: int = 6,
|
| 289 |
+
mlp_ratio: float = 4.0,
|
| 290 |
+
dropout: float = 0.1
|
| 291 |
+
):
|
| 292 |
+
super().__init__()
|
| 293 |
+
|
| 294 |
+
self.embed_dim = embed_dim
|
| 295 |
+
|
| 296 |
+
# Patch embedding
|
| 297 |
+
self.patch_embed = PatchEmbed(
|
| 298 |
+
img_size=(n_mels, max_frames),
|
| 299 |
+
patch_size=patch_size,
|
| 300 |
+
in_channels=1,
|
| 301 |
+
embed_dim=embed_dim
|
| 302 |
+
)
|
| 303 |
+
n_patches = self.patch_embed.n_patches
|
| 304 |
+
|
| 305 |
+
# CLS token and positional embedding
|
| 306 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 307 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
|
| 308 |
+
self.pos_drop = nn.Dropout(dropout)
|
| 309 |
+
|
| 310 |
+
# Transformer encoder
|
| 311 |
+
encoder_layer = nn.TransformerEncoderLayer(
|
| 312 |
+
d_model=embed_dim,
|
| 313 |
+
nhead=num_heads,
|
| 314 |
+
dim_feedforward=int(embed_dim * mlp_ratio),
|
| 315 |
+
dropout=dropout,
|
| 316 |
+
activation='gelu',
|
| 317 |
+
batch_first=True,
|
| 318 |
+
norm_first=True
|
| 319 |
+
)
|
| 320 |
+
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
|
| 321 |
+
|
| 322 |
+
# Output norm
|
| 323 |
+
self.norm = nn.LayerNorm(embed_dim)
|
| 324 |
+
|
| 325 |
+
# Initialize
|
| 326 |
+
nn.init.trunc_normal_(self.cls_token, std=0.02)
|
| 327 |
+
nn.init.trunc_normal_(self.pos_embed, std=0.02)
|
| 328 |
+
|
| 329 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 330 |
+
"""
|
| 331 |
+
Args:
|
| 332 |
+
x: Mel-spectrogram (batch, n_mels, n_frames) or (batch, 1, n_mels, n_frames)
|
| 333 |
+
|
| 334 |
+
Returns:
|
| 335 |
+
Embedding (batch, embed_dim)
|
| 336 |
+
"""
|
| 337 |
+
if x.dim() == 3:
|
| 338 |
+
x = x.unsqueeze(1)
|
| 339 |
+
|
| 340 |
+
# Pad to expected size if needed
|
| 341 |
+
_, _, h, w = x.shape
|
| 342 |
+
target_h, target_w = self.patch_embed.img_size
|
| 343 |
+
|
| 344 |
+
if h != target_h or w != target_w:
|
| 345 |
+
x = F.interpolate(x, size=(target_h, target_w), mode='bilinear', align_corners=False)
|
| 346 |
+
|
| 347 |
+
# Patch embed
|
| 348 |
+
x = self.patch_embed(x) # (B, n_patches, embed_dim)
|
| 349 |
+
|
| 350 |
+
# Add CLS token
|
| 351 |
+
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
|
| 352 |
+
x = torch.cat([cls_tokens, x], dim=1)
|
| 353 |
+
|
| 354 |
+
# Add positional embedding
|
| 355 |
+
x = x + self.pos_embed
|
| 356 |
+
x = self.pos_drop(x)
|
| 357 |
+
|
| 358 |
+
# Transformer
|
| 359 |
+
x = self.encoder(x)
|
| 360 |
+
x = self.norm(x)
|
| 361 |
+
|
| 362 |
+
# Return CLS token embedding
|
| 363 |
+
return x[:, 0]
|
| 364 |
+
|
| 365 |
+
def get_embedding_dim(self) -> int:
|
| 366 |
+
return self.embed_dim
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class AudioEncoder(nn.Module):
|
| 370 |
+
"""
|
| 371 |
+
Unified audio encoder interface.
|
| 372 |
+
|
| 373 |
+
Supports multiple backbone architectures:
|
| 374 |
+
- 'cnn': Lightweight CNN (BirdAudioEncoder)
|
| 375 |
+
- 'ast_tiny': Small AST transformer
|
| 376 |
+
"""
|
| 377 |
+
|
| 378 |
+
ARCHITECTURES = {
|
| 379 |
+
'cnn': BirdAudioEncoder,
|
| 380 |
+
'ast_tiny': AudioTransformerEncoder
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
def __init__(
|
| 384 |
+
self,
|
| 385 |
+
architecture: str = 'cnn',
|
| 386 |
+
n_mels: int = 128,
|
| 387 |
+
embedding_dim: int = 384,
|
| 388 |
+
pretrained: bool = False,
|
| 389 |
+
**kwargs
|
| 390 |
+
):
|
| 391 |
+
super().__init__()
|
| 392 |
+
|
| 393 |
+
if architecture not in self.ARCHITECTURES:
|
| 394 |
+
raise ValueError(f"Unknown architecture: {architecture}. "
|
| 395 |
+
f"Choose from: {list(self.ARCHITECTURES.keys())}")
|
| 396 |
+
|
| 397 |
+
encoder_cls = self.ARCHITECTURES[architecture]
|
| 398 |
+
self.encoder = encoder_cls(n_mels=n_mels, embedding_dim=embedding_dim, **kwargs)
|
| 399 |
+
self.embedding_dim = embedding_dim
|
| 400 |
+
|
| 401 |
+
if pretrained:
|
| 402 |
+
self._load_pretrained(architecture)
|
| 403 |
+
|
| 404 |
+
def _load_pretrained(self, architecture: str):
|
| 405 |
+
"""Load pretrained weights if available."""
|
| 406 |
+
# TODO: Implement pretrained weight loading from checkpoints
|
| 407 |
+
pass
|
| 408 |
+
|
| 409 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 410 |
+
return self.encoder(x)
|
| 411 |
+
|
| 412 |
+
def get_embedding_dim(self) -> int:
|
| 413 |
+
return self.embedding_dim
|
| 414 |
+
|
| 415 |
+
@torch.no_grad()
|
| 416 |
+
def extract_features(self, x: torch.Tensor) -> torch.Tensor:
|
| 417 |
+
"""Extract features without gradient computation."""
|
| 418 |
+
self.eval()
|
| 419 |
+
return self.forward(x)
|
| 420 |
+
|
| 421 |
+
def count_parameters(self) -> int:
|
| 422 |
+
"""Count trainable parameters."""
|
| 423 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 424 |
+
|
audio/preprocessor.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audio Preprocessing Pipeline for BirdSense.
|
| 3 |
+
|
| 4 |
+
Handles:
|
| 5 |
+
- Audio loading and resampling
|
| 6 |
+
- Spectrogram generation (mel-spectrogram)
|
| 7 |
+
- Noise reduction for challenging recordings
|
| 8 |
+
- Amplitude normalization
|
| 9 |
+
- Chunk splitting for long recordings
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import librosa
|
| 14 |
+
import soundfile as sf
|
| 15 |
+
from scipy import signal
|
| 16 |
+
from typing import Tuple, Optional, List
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
import io
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class AudioConfig:
|
| 23 |
+
"""Audio processing configuration."""
|
| 24 |
+
sample_rate: int = 32000
|
| 25 |
+
duration: float = 5.0
|
| 26 |
+
n_fft: int = 1024
|
| 27 |
+
hop_length: int = 320
|
| 28 |
+
n_mels: int = 128
|
| 29 |
+
fmin: int = 50
|
| 30 |
+
fmax: int = 14000
|
| 31 |
+
normalize: bool = True
|
| 32 |
+
noise_reduction: bool = True
|
| 33 |
+
noise_reduction_strength: float = 0.3
|
| 34 |
+
min_amplitude_db: float = -60
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class AudioPreprocessor:
|
| 38 |
+
"""
|
| 39 |
+
Robust audio preprocessor for bird sound analysis.
|
| 40 |
+
|
| 41 |
+
Designed to handle:
|
| 42 |
+
- Feeble/distant bird calls
|
| 43 |
+
- Noisy urban/natural environments
|
| 44 |
+
- Multiple overlapping bird sounds
|
| 45 |
+
- Various audio formats and quality levels
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, config: Optional[AudioConfig] = None):
|
| 49 |
+
self.config = config or AudioConfig()
|
| 50 |
+
|
| 51 |
+
def load_audio(
|
| 52 |
+
self,
|
| 53 |
+
source: str | bytes | np.ndarray,
|
| 54 |
+
target_sr: Optional[int] = None
|
| 55 |
+
) -> Tuple[np.ndarray, int]:
|
| 56 |
+
"""
|
| 57 |
+
Load audio from file path, bytes, or numpy array.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
source: File path, raw bytes, or numpy array
|
| 61 |
+
target_sr: Target sample rate (uses config if None)
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
Tuple of (audio_waveform, sample_rate)
|
| 65 |
+
"""
|
| 66 |
+
target_sr = target_sr or self.config.sample_rate
|
| 67 |
+
|
| 68 |
+
if isinstance(source, np.ndarray):
|
| 69 |
+
# Already a numpy array
|
| 70 |
+
audio = source
|
| 71 |
+
sr = target_sr
|
| 72 |
+
elif isinstance(source, bytes):
|
| 73 |
+
# Load from bytes
|
| 74 |
+
audio, sr = sf.read(io.BytesIO(source))
|
| 75 |
+
else:
|
| 76 |
+
# Load from file path
|
| 77 |
+
audio, sr = librosa.load(source, sr=target_sr, mono=True)
|
| 78 |
+
return audio, sr
|
| 79 |
+
|
| 80 |
+
# Convert to mono if stereo
|
| 81 |
+
if len(audio.shape) > 1:
|
| 82 |
+
audio = np.mean(audio, axis=1)
|
| 83 |
+
|
| 84 |
+
# Resample if needed
|
| 85 |
+
if sr != target_sr:
|
| 86 |
+
audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
|
| 87 |
+
sr = target_sr
|
| 88 |
+
|
| 89 |
+
return audio.astype(np.float32), sr
|
| 90 |
+
|
| 91 |
+
def normalize_audio(self, audio: np.ndarray) -> np.ndarray:
|
| 92 |
+
"""
|
| 93 |
+
Normalize audio amplitude.
|
| 94 |
+
Handles feeble recordings by boosting low amplitude signals.
|
| 95 |
+
"""
|
| 96 |
+
if len(audio) == 0:
|
| 97 |
+
return audio
|
| 98 |
+
|
| 99 |
+
# Peak normalization
|
| 100 |
+
max_val = np.max(np.abs(audio))
|
| 101 |
+
if max_val > 0:
|
| 102 |
+
audio = audio / max_val
|
| 103 |
+
|
| 104 |
+
# Boost feeble audio (adaptive gain)
|
| 105 |
+
rms = np.sqrt(np.mean(audio ** 2))
|
| 106 |
+
if rms < 0.1: # Feeble recording detected
|
| 107 |
+
target_rms = 0.2
|
| 108 |
+
gain = target_rms / (rms + 1e-8)
|
| 109 |
+
gain = min(gain, 10.0) # Limit gain to avoid noise amplification
|
| 110 |
+
audio = audio * gain
|
| 111 |
+
|
| 112 |
+
return np.clip(audio, -1.0, 1.0)
|
| 113 |
+
|
| 114 |
+
def reduce_noise(
|
| 115 |
+
self,
|
| 116 |
+
audio: np.ndarray,
|
| 117 |
+
sr: int,
|
| 118 |
+
strength: Optional[float] = None
|
| 119 |
+
) -> np.ndarray:
|
| 120 |
+
"""
|
| 121 |
+
Apply spectral noise reduction.
|
| 122 |
+
|
| 123 |
+
Uses spectral gating to reduce background noise while
|
| 124 |
+
preserving bird call frequencies.
|
| 125 |
+
"""
|
| 126 |
+
strength = strength or self.config.noise_reduction_strength
|
| 127 |
+
|
| 128 |
+
if len(audio) < sr * 0.1: # Too short
|
| 129 |
+
return audio
|
| 130 |
+
|
| 131 |
+
# Compute STFT
|
| 132 |
+
stft = librosa.stft(audio, n_fft=self.config.n_fft, hop_length=self.config.hop_length)
|
| 133 |
+
magnitude = np.abs(stft)
|
| 134 |
+
phase = np.angle(stft)
|
| 135 |
+
|
| 136 |
+
# Estimate noise floor from quietest frames
|
| 137 |
+
frame_energy = np.sum(magnitude ** 2, axis=0)
|
| 138 |
+
noise_frames = frame_energy < np.percentile(frame_energy, 20)
|
| 139 |
+
|
| 140 |
+
if np.sum(noise_frames) > 0:
|
| 141 |
+
noise_profile = np.mean(magnitude[:, noise_frames], axis=1, keepdims=True)
|
| 142 |
+
else:
|
| 143 |
+
noise_profile = np.min(magnitude, axis=1, keepdims=True)
|
| 144 |
+
|
| 145 |
+
# Spectral subtraction with oversubtraction factor
|
| 146 |
+
alpha = 1.0 + strength
|
| 147 |
+
magnitude_clean = magnitude - alpha * noise_profile
|
| 148 |
+
magnitude_clean = np.maximum(magnitude_clean, magnitude * 0.1) # Keep some residual
|
| 149 |
+
|
| 150 |
+
# Reconstruct
|
| 151 |
+
stft_clean = magnitude_clean * np.exp(1j * phase)
|
| 152 |
+
audio_clean = librosa.istft(stft_clean, hop_length=self.config.hop_length, length=len(audio))
|
| 153 |
+
|
| 154 |
+
return audio_clean.astype(np.float32)
|
| 155 |
+
|
| 156 |
+
def apply_bandpass(
|
| 157 |
+
self,
|
| 158 |
+
audio: np.ndarray,
|
| 159 |
+
sr: int,
|
| 160 |
+
low_freq: Optional[int] = None,
|
| 161 |
+
high_freq: Optional[int] = None
|
| 162 |
+
) -> np.ndarray:
|
| 163 |
+
"""
|
| 164 |
+
Apply bandpass filter to focus on bird vocalization frequencies.
|
| 165 |
+
Most bird calls are between 500Hz - 10kHz.
|
| 166 |
+
"""
|
| 167 |
+
low_freq = low_freq or self.config.fmin
|
| 168 |
+
high_freq = high_freq or min(self.config.fmax, sr // 2 - 100)
|
| 169 |
+
|
| 170 |
+
nyquist = sr / 2
|
| 171 |
+
low = low_freq / nyquist
|
| 172 |
+
high = high_freq / nyquist
|
| 173 |
+
|
| 174 |
+
# Butterworth bandpass filter
|
| 175 |
+
b, a = signal.butter(4, [low, high], btype='band')
|
| 176 |
+
audio_filtered = signal.filtfilt(b, a, audio)
|
| 177 |
+
|
| 178 |
+
return audio_filtered.astype(np.float32)
|
| 179 |
+
|
| 180 |
+
def compute_melspectrogram(
|
| 181 |
+
self,
|
| 182 |
+
audio: np.ndarray,
|
| 183 |
+
sr: int
|
| 184 |
+
) -> np.ndarray:
|
| 185 |
+
"""
|
| 186 |
+
Compute mel-spectrogram optimized for bird calls.
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
Mel-spectrogram with shape (n_mels, time_frames)
|
| 190 |
+
"""
|
| 191 |
+
mel_spec = librosa.feature.melspectrogram(
|
| 192 |
+
y=audio,
|
| 193 |
+
sr=sr,
|
| 194 |
+
n_fft=self.config.n_fft,
|
| 195 |
+
hop_length=self.config.hop_length,
|
| 196 |
+
n_mels=self.config.n_mels,
|
| 197 |
+
fmin=self.config.fmin,
|
| 198 |
+
fmax=min(self.config.fmax, sr // 2)
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# Convert to log scale (dB)
|
| 202 |
+
mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
|
| 203 |
+
|
| 204 |
+
# Normalize to [0, 1] range for neural network input
|
| 205 |
+
mel_spec_norm = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min() + 1e-8)
|
| 206 |
+
|
| 207 |
+
return mel_spec_norm.astype(np.float32)
|
| 208 |
+
|
| 209 |
+
def split_into_chunks(
|
| 210 |
+
self,
|
| 211 |
+
audio: np.ndarray,
|
| 212 |
+
sr: int,
|
| 213 |
+
overlap: float = 0.5
|
| 214 |
+
) -> List[np.ndarray]:
|
| 215 |
+
"""
|
| 216 |
+
Split long audio into overlapping chunks for processing.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
audio: Input audio waveform
|
| 220 |
+
sr: Sample rate
|
| 221 |
+
overlap: Overlap ratio between chunks (0.0 - 1.0)
|
| 222 |
+
|
| 223 |
+
Returns:
|
| 224 |
+
List of audio chunks
|
| 225 |
+
"""
|
| 226 |
+
chunk_samples = int(self.config.duration * sr)
|
| 227 |
+
hop_samples = int(chunk_samples * (1 - overlap))
|
| 228 |
+
|
| 229 |
+
if len(audio) <= chunk_samples:
|
| 230 |
+
# Pad short audio
|
| 231 |
+
if len(audio) < chunk_samples:
|
| 232 |
+
audio = np.pad(audio, (0, chunk_samples - len(audio)), mode='constant')
|
| 233 |
+
return [audio]
|
| 234 |
+
|
| 235 |
+
chunks = []
|
| 236 |
+
start = 0
|
| 237 |
+
while start < len(audio):
|
| 238 |
+
end = start + chunk_samples
|
| 239 |
+
chunk = audio[start:end]
|
| 240 |
+
|
| 241 |
+
# Pad last chunk if needed
|
| 242 |
+
if len(chunk) < chunk_samples:
|
| 243 |
+
chunk = np.pad(chunk, (0, chunk_samples - len(chunk)), mode='constant')
|
| 244 |
+
|
| 245 |
+
chunks.append(chunk)
|
| 246 |
+
start += hop_samples
|
| 247 |
+
|
| 248 |
+
return chunks
|
| 249 |
+
|
| 250 |
+
def process(
|
| 251 |
+
self,
|
| 252 |
+
source: str | bytes | np.ndarray,
|
| 253 |
+
return_waveform: bool = False
|
| 254 |
+
) -> dict:
|
| 255 |
+
"""
|
| 256 |
+
Full preprocessing pipeline.
|
| 257 |
+
|
| 258 |
+
Args:
|
| 259 |
+
source: Audio file path, bytes, or numpy array
|
| 260 |
+
return_waveform: Include processed waveform in output
|
| 261 |
+
|
| 262 |
+
Returns:
|
| 263 |
+
Dictionary with processed audio data:
|
| 264 |
+
- mel_specs: List of mel-spectrograms for each chunk
|
| 265 |
+
- waveforms: List of audio chunks (if return_waveform=True)
|
| 266 |
+
- duration: Total audio duration
|
| 267 |
+
- sample_rate: Sample rate
|
| 268 |
+
- num_chunks: Number of audio chunks
|
| 269 |
+
"""
|
| 270 |
+
# Load audio
|
| 271 |
+
audio, sr = self.load_audio(source)
|
| 272 |
+
original_duration = len(audio) / sr
|
| 273 |
+
|
| 274 |
+
# Apply bandpass filter
|
| 275 |
+
audio = self.apply_bandpass(audio, sr)
|
| 276 |
+
|
| 277 |
+
# Noise reduction (if enabled)
|
| 278 |
+
if self.config.noise_reduction:
|
| 279 |
+
audio = self.reduce_noise(audio, sr)
|
| 280 |
+
|
| 281 |
+
# Normalize
|
| 282 |
+
if self.config.normalize:
|
| 283 |
+
audio = self.normalize_audio(audio)
|
| 284 |
+
|
| 285 |
+
# Split into chunks
|
| 286 |
+
chunks = self.split_into_chunks(audio, sr)
|
| 287 |
+
|
| 288 |
+
# Compute mel-spectrograms
|
| 289 |
+
mel_specs = [self.compute_melspectrogram(chunk, sr) for chunk in chunks]
|
| 290 |
+
|
| 291 |
+
result = {
|
| 292 |
+
"mel_specs": mel_specs,
|
| 293 |
+
"duration": original_duration,
|
| 294 |
+
"sample_rate": sr,
|
| 295 |
+
"num_chunks": len(chunks),
|
| 296 |
+
"chunk_duration": self.config.duration
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
if return_waveform:
|
| 300 |
+
result["waveforms"] = chunks
|
| 301 |
+
|
| 302 |
+
return result
|
| 303 |
+
|
| 304 |
+
def get_audio_quality_assessment(self, audio: np.ndarray, sr: int) -> dict:
|
| 305 |
+
"""
|
| 306 |
+
Assess audio quality for diagnostic purposes.
|
| 307 |
+
|
| 308 |
+
Returns quality metrics useful for understanding
|
| 309 |
+
why recognition might succeed or fail.
|
| 310 |
+
"""
|
| 311 |
+
# RMS amplitude
|
| 312 |
+
rms = np.sqrt(np.mean(audio ** 2))
|
| 313 |
+
rms_db = 20 * np.log10(rms + 1e-8)
|
| 314 |
+
|
| 315 |
+
# Peak amplitude
|
| 316 |
+
peak = np.max(np.abs(audio))
|
| 317 |
+
peak_db = 20 * np.log10(peak + 1e-8)
|
| 318 |
+
|
| 319 |
+
# Signal-to-noise estimate (using spectral flatness)
|
| 320 |
+
mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr)
|
| 321 |
+
spectral_flatness = np.mean(librosa.feature.spectral_flatness(S=mel_spec))
|
| 322 |
+
estimated_snr = -10 * np.log10(spectral_flatness + 1e-8)
|
| 323 |
+
|
| 324 |
+
# Clipping detection
|
| 325 |
+
clipping_ratio = np.mean(np.abs(audio) > 0.99)
|
| 326 |
+
|
| 327 |
+
# Activity detection (voice activity equivalent for birds)
|
| 328 |
+
frame_energy = librosa.feature.rms(y=audio)[0]
|
| 329 |
+
activity_ratio = np.mean(frame_energy > np.percentile(frame_energy, 30))
|
| 330 |
+
|
| 331 |
+
quality_score = min(1.0, max(0.0,
|
| 332 |
+
0.3 * (1 - clipping_ratio) +
|
| 333 |
+
0.3 * min(1.0, estimated_snr / 20) +
|
| 334 |
+
0.2 * min(1.0, (rms_db + 40) / 30) +
|
| 335 |
+
0.2 * activity_ratio
|
| 336 |
+
))
|
| 337 |
+
|
| 338 |
+
return {
|
| 339 |
+
"rms_db": float(rms_db),
|
| 340 |
+
"peak_db": float(peak_db),
|
| 341 |
+
"estimated_snr_db": float(estimated_snr),
|
| 342 |
+
"clipping_ratio": float(clipping_ratio),
|
| 343 |
+
"activity_ratio": float(activity_ratio),
|
| 344 |
+
"quality_score": float(quality_score),
|
| 345 |
+
"quality_label": self._quality_label(quality_score)
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
def _quality_label(self, score: float) -> str:
|
| 349 |
+
if score >= 0.8:
|
| 350 |
+
return "excellent"
|
| 351 |
+
elif score >= 0.6:
|
| 352 |
+
return "good"
|
| 353 |
+
elif score >= 0.4:
|
| 354 |
+
return "fair"
|
| 355 |
+
elif score >= 0.2:
|
| 356 |
+
return "poor"
|
| 357 |
+
else:
|
| 358 |
+
return "very_poor"
|
| 359 |
+
|
audio/sam_audio.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SAM-Audio Integration for BirdSense.
|
| 3 |
+
|
| 4 |
+
Integrates Meta's SAM-Audio (Segment Anything in Audio) model for:
|
| 5 |
+
- Audio source separation
|
| 6 |
+
- Isolating bird calls from background noise
|
| 7 |
+
- Handling multi-bird chorus scenarios
|
| 8 |
+
- Improving recognition accuracy in challenging conditions
|
| 9 |
+
|
| 10 |
+
References:
|
| 11 |
+
- Paper: https://ai.meta.com/research/publications/sam-audio-segment-anything-in-audio/
|
| 12 |
+
- Model: https://huggingface.co/facebook/sam-audio-large
|
| 13 |
+
- Demo: https://ai.meta.com/samaudio/
|
| 14 |
+
|
| 15 |
+
SAM-Audio uses multimodal prompts (text, audio, point) to segment audio,
|
| 16 |
+
making it ideal for isolating specific bird calls.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
import numpy as np
|
| 22 |
+
from typing import Optional, List, Dict, Tuple, Any
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
import logging
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class SAMAudioConfig:
|
| 32 |
+
"""Configuration for SAM-Audio integration."""
|
| 33 |
+
model_name: str = "facebook/sam-audio-large"
|
| 34 |
+
device: str = "auto"
|
| 35 |
+
cache_dir: str = ".cache/sam_audio"
|
| 36 |
+
|
| 37 |
+
# Separation settings
|
| 38 |
+
num_sources: int = 4 # Max number of sources to separate
|
| 39 |
+
min_source_energy: float = 0.01 # Minimum energy threshold
|
| 40 |
+
|
| 41 |
+
# Bird-specific settings
|
| 42 |
+
bird_frequency_range: Tuple[int, int] = (500, 10000) # Hz
|
| 43 |
+
use_text_prompt: bool = True # Use text prompts like "bird call"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class SAMAudioProcessor:
|
| 47 |
+
"""
|
| 48 |
+
SAM-Audio processor for bird call isolation.
|
| 49 |
+
|
| 50 |
+
Uses Meta's SAM-Audio model to:
|
| 51 |
+
1. Separate overlapping audio sources
|
| 52 |
+
2. Isolate bird calls from background
|
| 53 |
+
3. Handle multi-bird recordings
|
| 54 |
+
4. Improve SNR for feeble recordings
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def __init__(self, config: Optional[SAMAudioConfig] = None):
|
| 58 |
+
self.config = config or SAMAudioConfig()
|
| 59 |
+
self.model = None
|
| 60 |
+
self.processor = None
|
| 61 |
+
self.device = None
|
| 62 |
+
self._model_loaded = False
|
| 63 |
+
|
| 64 |
+
def _setup_device(self):
|
| 65 |
+
"""Setup compute device."""
|
| 66 |
+
if self.config.device == "auto":
|
| 67 |
+
if torch.cuda.is_available():
|
| 68 |
+
self.device = torch.device("cuda")
|
| 69 |
+
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
| 70 |
+
self.device = torch.device("mps")
|
| 71 |
+
else:
|
| 72 |
+
self.device = torch.device("cpu")
|
| 73 |
+
else:
|
| 74 |
+
self.device = torch.device(self.config.device)
|
| 75 |
+
|
| 76 |
+
logger.info(f"SAM-Audio using device: {self.device}")
|
| 77 |
+
|
| 78 |
+
def load_model(self) -> bool:
|
| 79 |
+
"""
|
| 80 |
+
Load SAM-Audio model from HuggingFace.
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
True if model loaded successfully
|
| 84 |
+
"""
|
| 85 |
+
if self._model_loaded:
|
| 86 |
+
return True
|
| 87 |
+
|
| 88 |
+
self._setup_device()
|
| 89 |
+
|
| 90 |
+
try:
|
| 91 |
+
# Try to load from transformers
|
| 92 |
+
from transformers import AutoModel, AutoProcessor
|
| 93 |
+
|
| 94 |
+
logger.info(f"Loading SAM-Audio model: {self.config.model_name}")
|
| 95 |
+
|
| 96 |
+
cache_dir = Path(self.config.cache_dir)
|
| 97 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 98 |
+
|
| 99 |
+
self.processor = AutoProcessor.from_pretrained(
|
| 100 |
+
self.config.model_name,
|
| 101 |
+
cache_dir=str(cache_dir)
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
self.model = AutoModel.from_pretrained(
|
| 105 |
+
self.config.model_name,
|
| 106 |
+
cache_dir=str(cache_dir)
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
self.model.to(self.device)
|
| 110 |
+
self.model.eval()
|
| 111 |
+
|
| 112 |
+
self._model_loaded = True
|
| 113 |
+
logger.info("SAM-Audio model loaded successfully")
|
| 114 |
+
return True
|
| 115 |
+
|
| 116 |
+
except ImportError:
|
| 117 |
+
logger.warning("transformers library not available for SAM-Audio")
|
| 118 |
+
return False
|
| 119 |
+
except Exception as e:
|
| 120 |
+
logger.warning(f"Failed to load SAM-Audio: {e}")
|
| 121 |
+
logger.info("Falling back to spectral separation method")
|
| 122 |
+
return False
|
| 123 |
+
|
| 124 |
+
def is_available(self) -> bool:
|
| 125 |
+
"""Check if SAM-Audio is available."""
|
| 126 |
+
return self._model_loaded
|
| 127 |
+
|
| 128 |
+
@torch.no_grad()
|
| 129 |
+
def separate_sources(
|
| 130 |
+
self,
|
| 131 |
+
audio: np.ndarray,
|
| 132 |
+
sample_rate: int,
|
| 133 |
+
text_prompts: Optional[List[str]] = None
|
| 134 |
+
) -> List[Dict[str, Any]]:
|
| 135 |
+
"""
|
| 136 |
+
Separate audio into individual sources.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
audio: Input audio waveform
|
| 140 |
+
sample_rate: Sample rate
|
| 141 |
+
text_prompts: Optional text prompts like ["bird call", "wind"]
|
| 142 |
+
|
| 143 |
+
Returns:
|
| 144 |
+
List of separated sources with metadata
|
| 145 |
+
"""
|
| 146 |
+
if not self._model_loaded:
|
| 147 |
+
# Fallback to spectral separation
|
| 148 |
+
return self._spectral_separation(audio, sample_rate)
|
| 149 |
+
|
| 150 |
+
try:
|
| 151 |
+
# Prepare input for SAM-Audio
|
| 152 |
+
if text_prompts is None:
|
| 153 |
+
text_prompts = ["bird vocalization", "background noise"]
|
| 154 |
+
|
| 155 |
+
# Process through model
|
| 156 |
+
inputs = self.processor(
|
| 157 |
+
audio,
|
| 158 |
+
sampling_rate=sample_rate,
|
| 159 |
+
text=text_prompts,
|
| 160 |
+
return_tensors="pt"
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 164 |
+
|
| 165 |
+
outputs = self.model(**inputs)
|
| 166 |
+
|
| 167 |
+
# Extract separated sources
|
| 168 |
+
sources = []
|
| 169 |
+
for i, mask in enumerate(outputs.masks):
|
| 170 |
+
separated = audio * mask.cpu().numpy()
|
| 171 |
+
energy = np.mean(separated ** 2)
|
| 172 |
+
|
| 173 |
+
if energy > self.config.min_source_energy:
|
| 174 |
+
sources.append({
|
| 175 |
+
'audio': separated,
|
| 176 |
+
'energy': float(energy),
|
| 177 |
+
'label': text_prompts[i] if i < len(text_prompts) else f'source_{i}',
|
| 178 |
+
'mask': mask.cpu().numpy()
|
| 179 |
+
})
|
| 180 |
+
|
| 181 |
+
return sources
|
| 182 |
+
|
| 183 |
+
except Exception as e:
|
| 184 |
+
logger.warning(f"SAM-Audio separation failed: {e}")
|
| 185 |
+
return self._spectral_separation(audio, sample_rate)
|
| 186 |
+
|
| 187 |
+
def _spectral_separation(
|
| 188 |
+
self,
|
| 189 |
+
audio: np.ndarray,
|
| 190 |
+
sample_rate: int
|
| 191 |
+
) -> List[Dict[str, Any]]:
|
| 192 |
+
"""
|
| 193 |
+
Fallback spectral separation when SAM-Audio unavailable.
|
| 194 |
+
|
| 195 |
+
Uses spectral masking to separate bird frequency ranges
|
| 196 |
+
from background noise.
|
| 197 |
+
"""
|
| 198 |
+
import scipy.signal as signal
|
| 199 |
+
|
| 200 |
+
# Compute STFT
|
| 201 |
+
f, t, Zxx = signal.stft(audio, fs=sample_rate, nperseg=1024, noverlap=768)
|
| 202 |
+
magnitude = np.abs(Zxx)
|
| 203 |
+
phase = np.angle(Zxx)
|
| 204 |
+
|
| 205 |
+
# Create frequency masks
|
| 206 |
+
low_freq, high_freq = self.config.bird_frequency_range
|
| 207 |
+
|
| 208 |
+
# Bird frequency mask (500-10000 Hz)
|
| 209 |
+
bird_mask = (f >= low_freq) & (f <= high_freq)
|
| 210 |
+
bird_mask = bird_mask.astype(float).reshape(-1, 1)
|
| 211 |
+
|
| 212 |
+
# Apply soft masking
|
| 213 |
+
bird_magnitude = magnitude * bird_mask
|
| 214 |
+
background_magnitude = magnitude * (1 - bird_mask * 0.8)
|
| 215 |
+
|
| 216 |
+
# Reconstruct audio
|
| 217 |
+
bird_stft = bird_magnitude * np.exp(1j * phase)
|
| 218 |
+
_, bird_audio = signal.istft(bird_stft, fs=sample_rate, nperseg=1024, noverlap=768)
|
| 219 |
+
|
| 220 |
+
background_stft = background_magnitude * np.exp(1j * phase)
|
| 221 |
+
_, background_audio = signal.istft(background_stft, fs=sample_rate, nperseg=1024, noverlap=768)
|
| 222 |
+
|
| 223 |
+
# Ensure same length
|
| 224 |
+
min_len = min(len(audio), len(bird_audio), len(background_audio))
|
| 225 |
+
bird_audio = bird_audio[:min_len]
|
| 226 |
+
background_audio = background_audio[:min_len]
|
| 227 |
+
|
| 228 |
+
sources = [
|
| 229 |
+
{
|
| 230 |
+
'audio': bird_audio.astype(np.float32),
|
| 231 |
+
'energy': float(np.mean(bird_audio ** 2)),
|
| 232 |
+
'label': 'bird_frequencies',
|
| 233 |
+
'mask': bird_mask.flatten()
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
'audio': background_audio.astype(np.float32),
|
| 237 |
+
'energy': float(np.mean(background_audio ** 2)),
|
| 238 |
+
'label': 'background',
|
| 239 |
+
'mask': (1 - bird_mask).flatten()
|
| 240 |
+
}
|
| 241 |
+
]
|
| 242 |
+
|
| 243 |
+
return sources
|
| 244 |
+
|
| 245 |
+
def isolate_bird_call(
|
| 246 |
+
self,
|
| 247 |
+
audio: np.ndarray,
|
| 248 |
+
sample_rate: int
|
| 249 |
+
) -> Tuple[np.ndarray, float]:
|
| 250 |
+
"""
|
| 251 |
+
Isolate the primary bird call from audio.
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
audio: Input audio
|
| 255 |
+
sample_rate: Sample rate
|
| 256 |
+
|
| 257 |
+
Returns:
|
| 258 |
+
Tuple of (isolated_audio, quality_score)
|
| 259 |
+
"""
|
| 260 |
+
# Try SAM-Audio first
|
| 261 |
+
sources = self.separate_sources(
|
| 262 |
+
audio,
|
| 263 |
+
sample_rate,
|
| 264 |
+
text_prompts=["bird call", "bird song", "background noise", "wind"]
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# Find the bird source
|
| 268 |
+
bird_source = None
|
| 269 |
+
max_bird_energy = 0
|
| 270 |
+
|
| 271 |
+
for source in sources:
|
| 272 |
+
label = source['label'].lower()
|
| 273 |
+
if 'bird' in label and source['energy'] > max_bird_energy:
|
| 274 |
+
bird_source = source
|
| 275 |
+
max_bird_energy = source['energy']
|
| 276 |
+
|
| 277 |
+
if bird_source is None:
|
| 278 |
+
# No clear bird source found, return original with spectral enhancement
|
| 279 |
+
return self._enhance_bird_frequencies(audio, sample_rate)
|
| 280 |
+
|
| 281 |
+
# Calculate quality improvement
|
| 282 |
+
original_energy = np.mean(audio ** 2)
|
| 283 |
+
isolated_energy = bird_source['energy']
|
| 284 |
+
quality_score = min(1.0, isolated_energy / (original_energy + 1e-8))
|
| 285 |
+
|
| 286 |
+
return bird_source['audio'], quality_score
|
| 287 |
+
|
| 288 |
+
def _enhance_bird_frequencies(
|
| 289 |
+
self,
|
| 290 |
+
audio: np.ndarray,
|
| 291 |
+
sample_rate: int
|
| 292 |
+
) -> Tuple[np.ndarray, float]:
|
| 293 |
+
"""Enhance bird frequency range in audio."""
|
| 294 |
+
import scipy.signal as signal
|
| 295 |
+
|
| 296 |
+
low_freq, high_freq = self.config.bird_frequency_range
|
| 297 |
+
nyquist = sample_rate / 2
|
| 298 |
+
|
| 299 |
+
# Bandpass filter
|
| 300 |
+
low = low_freq / nyquist
|
| 301 |
+
high = min(high_freq / nyquist, 0.99)
|
| 302 |
+
|
| 303 |
+
b, a = signal.butter(4, [low, high], btype='band')
|
| 304 |
+
filtered = signal.filtfilt(b, a, audio)
|
| 305 |
+
|
| 306 |
+
# Mix with original (subtle enhancement)
|
| 307 |
+
enhanced = audio * 0.3 + filtered * 0.7
|
| 308 |
+
enhanced = enhanced / (np.max(np.abs(enhanced)) + 1e-8)
|
| 309 |
+
|
| 310 |
+
return enhanced.astype(np.float32), 0.7
|
| 311 |
+
|
| 312 |
+
def process_multi_bird(
|
| 313 |
+
self,
|
| 314 |
+
audio: np.ndarray,
|
| 315 |
+
sample_rate: int,
|
| 316 |
+
max_birds: int = 3
|
| 317 |
+
) -> List[Dict[str, Any]]:
|
| 318 |
+
"""
|
| 319 |
+
Process multi-bird recording to isolate individual birds.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
audio: Multi-bird recording
|
| 323 |
+
sample_rate: Sample rate
|
| 324 |
+
max_birds: Maximum number of birds to isolate
|
| 325 |
+
|
| 326 |
+
Returns:
|
| 327 |
+
List of isolated bird calls with metadata
|
| 328 |
+
"""
|
| 329 |
+
# Create prompts for multiple birds
|
| 330 |
+
text_prompts = [f"bird call {i+1}" for i in range(max_birds)]
|
| 331 |
+
text_prompts.append("background noise")
|
| 332 |
+
|
| 333 |
+
sources = self.separate_sources(audio, sample_rate, text_prompts)
|
| 334 |
+
|
| 335 |
+
# Filter to just bird sources
|
| 336 |
+
bird_calls = []
|
| 337 |
+
for source in sources:
|
| 338 |
+
if 'bird' in source['label'].lower() and source['energy'] > self.config.min_source_energy:
|
| 339 |
+
bird_calls.append({
|
| 340 |
+
'audio': source['audio'],
|
| 341 |
+
'energy': source['energy'],
|
| 342 |
+
'index': len(bird_calls)
|
| 343 |
+
})
|
| 344 |
+
|
| 345 |
+
# Sort by energy (loudest first)
|
| 346 |
+
bird_calls.sort(key=lambda x: x['energy'], reverse=True)
|
| 347 |
+
|
| 348 |
+
return bird_calls[:max_birds]
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class SAMAudioEnhancer:
|
| 352 |
+
"""
|
| 353 |
+
High-level interface for using SAM-Audio to improve BirdSense accuracy.
|
| 354 |
+
|
| 355 |
+
Provides automatic preprocessing to:
|
| 356 |
+
1. Improve SNR for feeble recordings
|
| 357 |
+
2. Handle noisy environments
|
| 358 |
+
3. Separate multi-bird choruses
|
| 359 |
+
"""
|
| 360 |
+
|
| 361 |
+
def __init__(self, config: Optional[SAMAudioConfig] = None):
|
| 362 |
+
self.processor = SAMAudioProcessor(config)
|
| 363 |
+
self._initialized = False
|
| 364 |
+
|
| 365 |
+
def initialize(self) -> bool:
|
| 366 |
+
"""Initialize SAM-Audio (loads model)."""
|
| 367 |
+
if not self._initialized:
|
| 368 |
+
self._initialized = self.processor.load_model()
|
| 369 |
+
return self._initialized
|
| 370 |
+
|
| 371 |
+
def enhance_audio(
|
| 372 |
+
self,
|
| 373 |
+
audio: np.ndarray,
|
| 374 |
+
sample_rate: int,
|
| 375 |
+
scenario: str = "auto"
|
| 376 |
+
) -> Tuple[np.ndarray, Dict[str, Any]]:
|
| 377 |
+
"""
|
| 378 |
+
Automatically enhance audio for better bird recognition.
|
| 379 |
+
|
| 380 |
+
Args:
|
| 381 |
+
audio: Input audio
|
| 382 |
+
sample_rate: Sample rate
|
| 383 |
+
scenario: One of 'auto', 'feeble', 'noisy', 'multi_bird'
|
| 384 |
+
|
| 385 |
+
Returns:
|
| 386 |
+
Tuple of (enhanced_audio, metadata)
|
| 387 |
+
"""
|
| 388 |
+
metadata = {
|
| 389 |
+
'original_rms': float(np.sqrt(np.mean(audio ** 2))),
|
| 390 |
+
'scenario': scenario,
|
| 391 |
+
'sam_audio_used': self.processor.is_available()
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
if scenario == "auto":
|
| 395 |
+
scenario = self._detect_scenario(audio, sample_rate)
|
| 396 |
+
metadata['detected_scenario'] = scenario
|
| 397 |
+
|
| 398 |
+
if scenario == "feeble":
|
| 399 |
+
enhanced, quality = self.processor.isolate_bird_call(audio, sample_rate)
|
| 400 |
+
# Boost amplitude
|
| 401 |
+
enhanced = enhanced * 2.0
|
| 402 |
+
enhanced = np.clip(enhanced, -1.0, 1.0)
|
| 403 |
+
metadata['enhancement'] = 'amplitude_boost'
|
| 404 |
+
|
| 405 |
+
elif scenario == "noisy":
|
| 406 |
+
enhanced, quality = self.processor.isolate_bird_call(audio, sample_rate)
|
| 407 |
+
metadata['enhancement'] = 'noise_removal'
|
| 408 |
+
|
| 409 |
+
elif scenario == "multi_bird":
|
| 410 |
+
birds = self.processor.process_multi_bird(audio, sample_rate)
|
| 411 |
+
if birds:
|
| 412 |
+
# Return loudest bird for primary classification
|
| 413 |
+
enhanced = birds[0]['audio']
|
| 414 |
+
metadata['num_birds_detected'] = len(birds)
|
| 415 |
+
metadata['enhancement'] = 'bird_separation'
|
| 416 |
+
else:
|
| 417 |
+
enhanced = audio
|
| 418 |
+
metadata['enhancement'] = 'none'
|
| 419 |
+
else:
|
| 420 |
+
enhanced = audio
|
| 421 |
+
metadata['enhancement'] = 'none'
|
| 422 |
+
|
| 423 |
+
metadata['enhanced_rms'] = float(np.sqrt(np.mean(enhanced ** 2)))
|
| 424 |
+
metadata['snr_improvement'] = metadata['enhanced_rms'] / (metadata['original_rms'] + 1e-8)
|
| 425 |
+
|
| 426 |
+
return enhanced.astype(np.float32), metadata
|
| 427 |
+
|
| 428 |
+
def _detect_scenario(
|
| 429 |
+
self,
|
| 430 |
+
audio: np.ndarray,
|
| 431 |
+
sample_rate: int
|
| 432 |
+
) -> str:
|
| 433 |
+
"""Automatically detect audio scenario."""
|
| 434 |
+
rms = np.sqrt(np.mean(audio ** 2))
|
| 435 |
+
|
| 436 |
+
# Check for feeble audio
|
| 437 |
+
if rms < 0.05:
|
| 438 |
+
return "feeble"
|
| 439 |
+
|
| 440 |
+
# Check for multi-source (high variance in spectral energy)
|
| 441 |
+
import scipy.signal as signal
|
| 442 |
+
f, t, Zxx = signal.stft(audio, fs=sample_rate, nperseg=512)
|
| 443 |
+
frame_energy = np.sum(np.abs(Zxx) ** 2, axis=0)
|
| 444 |
+
energy_variance = np.var(frame_energy) / (np.mean(frame_energy) ** 2 + 1e-8)
|
| 445 |
+
|
| 446 |
+
if energy_variance > 2.0:
|
| 447 |
+
return "multi_bird"
|
| 448 |
+
|
| 449 |
+
# Check SNR estimate
|
| 450 |
+
# High spectral flatness suggests noise
|
| 451 |
+
spectral_flatness = np.exp(np.mean(np.log(np.abs(Zxx) + 1e-8))) / (np.mean(np.abs(Zxx)) + 1e-8)
|
| 452 |
+
if spectral_flatness > 0.3:
|
| 453 |
+
return "noisy"
|
| 454 |
+
|
| 455 |
+
return "clear"
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
# Convenience function
|
| 459 |
+
def create_sam_audio_enhancer(
|
| 460 |
+
device: str = "auto",
|
| 461 |
+
load_model: bool = True
|
| 462 |
+
) -> SAMAudioEnhancer:
|
| 463 |
+
"""
|
| 464 |
+
Create SAM-Audio enhancer instance.
|
| 465 |
+
|
| 466 |
+
Args:
|
| 467 |
+
device: Compute device
|
| 468 |
+
load_model: Whether to load model immediately
|
| 469 |
+
|
| 470 |
+
Returns:
|
| 471 |
+
Configured SAMAudioEnhancer
|
| 472 |
+
"""
|
| 473 |
+
config = SAMAudioConfig(device=device)
|
| 474 |
+
enhancer = SAMAudioEnhancer(config)
|
| 475 |
+
|
| 476 |
+
if load_model:
|
| 477 |
+
enhancer.initialize()
|
| 478 |
+
|
| 479 |
+
return enhancer
|
| 480 |
+
|
data/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""BirdSense Data Module."""
|
| 2 |
+
|
| 3 |
+
from .species_db import IndiaSpeciesDatabase, SpeciesInfo
|
| 4 |
+
|
| 5 |
+
__all__ = ["IndiaSpeciesDatabase", "SpeciesInfo"]
|
| 6 |
+
|
data/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (318 Bytes). View file
|
|
|
data/__pycache__/species_db.cpython-314.pyc
ADDED
|
Binary file (18.6 kB). View file
|
|
|
data/species_db.py
ADDED
|
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
India Bird Species Database for BirdSense.
|
| 3 |
+
|
| 4 |
+
Contains information about Indian bird species including:
|
| 5 |
+
- Scientific and common names
|
| 6 |
+
- Habitat information
|
| 7 |
+
- Conservation status
|
| 8 |
+
- Geographic range
|
| 9 |
+
- Vocalization descriptions
|
| 10 |
+
|
| 11 |
+
Primary source: India Biodiversity Portal, eBird, IUCN
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from dataclasses import dataclass, field
|
| 15 |
+
from typing import List, Dict, Optional
|
| 16 |
+
import json
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class SpeciesInfo:
|
| 21 |
+
"""Information about a bird species."""
|
| 22 |
+
id: int
|
| 23 |
+
scientific_name: str
|
| 24 |
+
common_name: str
|
| 25 |
+
hindi_name: Optional[str] = None
|
| 26 |
+
family: str = ""
|
| 27 |
+
order: str = ""
|
| 28 |
+
|
| 29 |
+
# Status
|
| 30 |
+
conservation_status: str = "LC" # LC, NT, VU, EN, CR
|
| 31 |
+
endemic_to_india: bool = False
|
| 32 |
+
migratory_status: str = "Resident" # Resident, Winter Visitor, Summer Visitor, Passage Migrant
|
| 33 |
+
|
| 34 |
+
# Habitat
|
| 35 |
+
habitats: List[str] = field(default_factory=list)
|
| 36 |
+
elevation_min: int = 0 # meters
|
| 37 |
+
elevation_max: int = 5000
|
| 38 |
+
|
| 39 |
+
# Range
|
| 40 |
+
states: List[str] = field(default_factory=list)
|
| 41 |
+
range_description: str = ""
|
| 42 |
+
|
| 43 |
+
# Vocalization
|
| 44 |
+
call_description: str = ""
|
| 45 |
+
song_description: str = ""
|
| 46 |
+
call_frequency_range: tuple = (0, 10000) # Hz
|
| 47 |
+
|
| 48 |
+
# For model
|
| 49 |
+
class_index: int = 0
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class IndiaSpeciesDatabase:
|
| 53 |
+
"""
|
| 54 |
+
Database of Indian bird species.
|
| 55 |
+
|
| 56 |
+
Provides species information for:
|
| 57 |
+
- Model training (class labels)
|
| 58 |
+
- LLM reasoning (species context)
|
| 59 |
+
- Novelty detection (range checking)
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
def __init__(self):
|
| 63 |
+
self.species: Dict[int, SpeciesInfo] = {}
|
| 64 |
+
self.name_to_id: Dict[str, int] = {}
|
| 65 |
+
self._init_species()
|
| 66 |
+
|
| 67 |
+
def _init_species(self):
|
| 68 |
+
"""Initialize with common Indian bird species."""
|
| 69 |
+
# This is a representative sample - full database would have 1300+ species
|
| 70 |
+
species_data = [
|
| 71 |
+
# Cuckoos
|
| 72 |
+
SpeciesInfo(
|
| 73 |
+
id=0,
|
| 74 |
+
scientific_name="Cuculus micropterus",
|
| 75 |
+
common_name="Indian Cuckoo",
|
| 76 |
+
hindi_name="कोयल",
|
| 77 |
+
family="Cuculidae",
|
| 78 |
+
order="Cuculiformes",
|
| 79 |
+
conservation_status="LC",
|
| 80 |
+
migratory_status="Summer Visitor",
|
| 81 |
+
habitats=["Forest", "Woodland"],
|
| 82 |
+
elevation_min=0, elevation_max=3000,
|
| 83 |
+
states=["All India"],
|
| 84 |
+
call_description="Four-note whistle 'cross-word puzzle' or 'one more bottle'",
|
| 85 |
+
call_frequency_range=(1000, 3000),
|
| 86 |
+
class_index=0
|
| 87 |
+
),
|
| 88 |
+
SpeciesInfo(
|
| 89 |
+
id=1,
|
| 90 |
+
scientific_name="Eudynamys scolopaceus",
|
| 91 |
+
common_name="Asian Koel",
|
| 92 |
+
hindi_name="कोयल",
|
| 93 |
+
family="Cuculidae",
|
| 94 |
+
order="Cuculiformes",
|
| 95 |
+
conservation_status="LC",
|
| 96 |
+
migratory_status="Resident",
|
| 97 |
+
habitats=["Forest", "Urban", "Garden"],
|
| 98 |
+
elevation_min=0, elevation_max=1800,
|
| 99 |
+
states=["All India"],
|
| 100 |
+
call_description="Loud 'kuil-kuil-kuil' rising whistle, very distinctive",
|
| 101 |
+
call_frequency_range=(500, 4000),
|
| 102 |
+
class_index=1
|
| 103 |
+
),
|
| 104 |
+
|
| 105 |
+
# Robins and Thrushes
|
| 106 |
+
SpeciesInfo(
|
| 107 |
+
id=2,
|
| 108 |
+
scientific_name="Copsychus saularis",
|
| 109 |
+
common_name="Oriental Magpie-Robin",
|
| 110 |
+
hindi_name="दहियर",
|
| 111 |
+
family="Muscicapidae",
|
| 112 |
+
order="Passeriformes",
|
| 113 |
+
conservation_status="LC",
|
| 114 |
+
migratory_status="Resident",
|
| 115 |
+
habitats=["Garden", "Forest edge", "Urban"],
|
| 116 |
+
elevation_min=0, elevation_max=2000,
|
| 117 |
+
states=["All India"],
|
| 118 |
+
call_description="Rich varied song with whistles and mimicry",
|
| 119 |
+
call_frequency_range=(1500, 5000),
|
| 120 |
+
class_index=2
|
| 121 |
+
),
|
| 122 |
+
SpeciesInfo(
|
| 123 |
+
id=3,
|
| 124 |
+
scientific_name="Saxicoloides fulicatus",
|
| 125 |
+
common_name="Indian Robin",
|
| 126 |
+
hindi_name="काली चिड़ी",
|
| 127 |
+
family="Muscicapidae",
|
| 128 |
+
order="Passeriformes",
|
| 129 |
+
conservation_status="LC",
|
| 130 |
+
migratory_status="Resident",
|
| 131 |
+
endemic_to_india=True,
|
| 132 |
+
habitats=["Scrub", "Garden", "Rocky areas"],
|
| 133 |
+
elevation_min=0, elevation_max=1500,
|
| 134 |
+
states=["Peninsular India"],
|
| 135 |
+
call_description="Pleasant whistling song, alarm 'chip-chip'",
|
| 136 |
+
call_frequency_range=(2000, 6000),
|
| 137 |
+
class_index=3
|
| 138 |
+
),
|
| 139 |
+
|
| 140 |
+
# Kingfishers
|
| 141 |
+
SpeciesInfo(
|
| 142 |
+
id=4,
|
| 143 |
+
scientific_name="Alcedo atthis",
|
| 144 |
+
common_name="Common Kingfisher",
|
| 145 |
+
hindi_name="छोटा किलकिला",
|
| 146 |
+
family="Alcedinidae",
|
| 147 |
+
order="Coraciiformes",
|
| 148 |
+
conservation_status="LC",
|
| 149 |
+
migratory_status="Resident",
|
| 150 |
+
habitats=["Wetland", "River", "Stream"],
|
| 151 |
+
elevation_min=0, elevation_max=2000,
|
| 152 |
+
states=["All India"],
|
| 153 |
+
call_description="Sharp high-pitched 'chee' or 'kik-kik'",
|
| 154 |
+
call_frequency_range=(4000, 8000),
|
| 155 |
+
class_index=4
|
| 156 |
+
),
|
| 157 |
+
SpeciesInfo(
|
| 158 |
+
id=5,
|
| 159 |
+
scientific_name="Halcyon smyrnensis",
|
| 160 |
+
common_name="White-throated Kingfisher",
|
| 161 |
+
hindi_name="किलकिला",
|
| 162 |
+
family="Alcedinidae",
|
| 163 |
+
order="Coraciiformes",
|
| 164 |
+
conservation_status="LC",
|
| 165 |
+
migratory_status="Resident",
|
| 166 |
+
habitats=["Open country", "Wetland", "Garden"],
|
| 167 |
+
elevation_min=0, elevation_max=2000,
|
| 168 |
+
states=["All India"],
|
| 169 |
+
call_description="Loud laughing 'ki-ki-ki-ki' call",
|
| 170 |
+
call_frequency_range=(2000, 6000),
|
| 171 |
+
class_index=5
|
| 172 |
+
),
|
| 173 |
+
|
| 174 |
+
# Galliformes
|
| 175 |
+
SpeciesInfo(
|
| 176 |
+
id=6,
|
| 177 |
+
scientific_name="Pavo cristatus",
|
| 178 |
+
common_name="Indian Peafowl",
|
| 179 |
+
hindi_name="मोर",
|
| 180 |
+
family="Phasianidae",
|
| 181 |
+
order="Galliformes",
|
| 182 |
+
conservation_status="LC",
|
| 183 |
+
migratory_status="Resident",
|
| 184 |
+
endemic_to_india=True,
|
| 185 |
+
habitats=["Forest", "Scrub", "Cultivation"],
|
| 186 |
+
elevation_min=0, elevation_max=2000,
|
| 187 |
+
states=["All India"],
|
| 188 |
+
call_description="Loud 'may-awe' call, especially during monsoon",
|
| 189 |
+
call_frequency_range=(500, 2000),
|
| 190 |
+
class_index=6
|
| 191 |
+
),
|
| 192 |
+
SpeciesInfo(
|
| 193 |
+
id=7,
|
| 194 |
+
scientific_name="Gallus gallus",
|
| 195 |
+
common_name="Red Junglefowl",
|
| 196 |
+
hindi_name="जंगली मुर्गा",
|
| 197 |
+
family="Phasianidae",
|
| 198 |
+
order="Galliformes",
|
| 199 |
+
conservation_status="LC",
|
| 200 |
+
migratory_status="Resident",
|
| 201 |
+
habitats=["Forest", "Scrub"],
|
| 202 |
+
elevation_min=0, elevation_max=2000,
|
| 203 |
+
states=["All India except desert"],
|
| 204 |
+
call_description="Crowing like domestic rooster but shorter",
|
| 205 |
+
call_frequency_range=(500, 3000),
|
| 206 |
+
class_index=7
|
| 207 |
+
),
|
| 208 |
+
|
| 209 |
+
# Common Urban Birds
|
| 210 |
+
SpeciesInfo(
|
| 211 |
+
id=8,
|
| 212 |
+
scientific_name="Passer domesticus",
|
| 213 |
+
common_name="House Sparrow",
|
| 214 |
+
hindi_name="गौरैया",
|
| 215 |
+
family="Passeridae",
|
| 216 |
+
order="Passeriformes",
|
| 217 |
+
conservation_status="LC",
|
| 218 |
+
migratory_status="Resident",
|
| 219 |
+
habitats=["Urban", "Village", "Cultivation"],
|
| 220 |
+
elevation_min=0, elevation_max=4000,
|
| 221 |
+
states=["All India"],
|
| 222 |
+
call_description="Chirping 'chip-chip' and 'cheep' calls",
|
| 223 |
+
call_frequency_range=(2000, 6000),
|
| 224 |
+
class_index=8
|
| 225 |
+
),
|
| 226 |
+
SpeciesInfo(
|
| 227 |
+
id=9,
|
| 228 |
+
scientific_name="Acridotheres tristis",
|
| 229 |
+
common_name="Common Myna",
|
| 230 |
+
hindi_name="मैना",
|
| 231 |
+
family="Sturnidae",
|
| 232 |
+
order="Passeriformes",
|
| 233 |
+
conservation_status="LC",
|
| 234 |
+
migratory_status="Resident",
|
| 235 |
+
habitats=["Urban", "Open country", "Cultivation"],
|
| 236 |
+
elevation_min=0, elevation_max=3000,
|
| 237 |
+
states=["All India"],
|
| 238 |
+
call_description="Loud varied calls, harsh 'krrr', whistles",
|
| 239 |
+
call_frequency_range=(1000, 5000),
|
| 240 |
+
class_index=9
|
| 241 |
+
),
|
| 242 |
+
|
| 243 |
+
# Barbets
|
| 244 |
+
SpeciesInfo(
|
| 245 |
+
id=10,
|
| 246 |
+
scientific_name="Psilopogon haemacephalus",
|
| 247 |
+
common_name="Coppersmith Barbet",
|
| 248 |
+
hindi_name="छोटा बसंता",
|
| 249 |
+
family="Megalaimidae",
|
| 250 |
+
order="Piciformes",
|
| 251 |
+
conservation_status="LC",
|
| 252 |
+
migratory_status="Resident",
|
| 253 |
+
habitats=["Garden", "Forest", "Urban"],
|
| 254 |
+
elevation_min=0, elevation_max=1500,
|
| 255 |
+
states=["All India"],
|
| 256 |
+
call_description="Monotonous 'tuk-tuk-tuk' like hammer on metal",
|
| 257 |
+
call_frequency_range=(1500, 3000),
|
| 258 |
+
class_index=10
|
| 259 |
+
),
|
| 260 |
+
SpeciesInfo(
|
| 261 |
+
id=11,
|
| 262 |
+
scientific_name="Psilopogon zeylanicus",
|
| 263 |
+
common_name="Brown-headed Barbet",
|
| 264 |
+
hindi_name="बड़ा बसंता",
|
| 265 |
+
family="Megalaimidae",
|
| 266 |
+
order="Piciformes",
|
| 267 |
+
conservation_status="LC",
|
| 268 |
+
migratory_status="Resident",
|
| 269 |
+
habitats=["Forest", "Garden"],
|
| 270 |
+
elevation_min=0, elevation_max=2000,
|
| 271 |
+
states=["Peninsular India"],
|
| 272 |
+
call_description="Loud 'kutroo-kutroo' repeated",
|
| 273 |
+
call_frequency_range=(1000, 3000),
|
| 274 |
+
class_index=11
|
| 275 |
+
),
|
| 276 |
+
|
| 277 |
+
# Parakeets
|
| 278 |
+
SpeciesInfo(
|
| 279 |
+
id=12,
|
| 280 |
+
scientific_name="Psittacula krameri",
|
| 281 |
+
common_name="Rose-ringed Parakeet",
|
| 282 |
+
hindi_name="तोता",
|
| 283 |
+
family="Psittacidae",
|
| 284 |
+
order="Psittaciformes",
|
| 285 |
+
conservation_status="LC",
|
| 286 |
+
migratory_status="Resident",
|
| 287 |
+
habitats=["Urban", "Cultivation", "Forest"],
|
| 288 |
+
elevation_min=0, elevation_max=2000,
|
| 289 |
+
states=["All India"],
|
| 290 |
+
call_description="Loud screeching 'kee-ak' in flight",
|
| 291 |
+
call_frequency_range=(2000, 5000),
|
| 292 |
+
class_index=12
|
| 293 |
+
),
|
| 294 |
+
|
| 295 |
+
# Doves
|
| 296 |
+
SpeciesInfo(
|
| 297 |
+
id=13,
|
| 298 |
+
scientific_name="Streptopelia chinensis",
|
| 299 |
+
common_name="Spotted Dove",
|
| 300 |
+
hindi_name="चित्रोक फाख्ता",
|
| 301 |
+
family="Columbidae",
|
| 302 |
+
order="Columbiformes",
|
| 303 |
+
conservation_status="LC",
|
| 304 |
+
migratory_status="Resident",
|
| 305 |
+
habitats=["Garden", "Cultivation", "Forest edge"],
|
| 306 |
+
elevation_min=0, elevation_max=3000,
|
| 307 |
+
states=["All India"],
|
| 308 |
+
call_description="Soft cooing 'coo-coo-coo'",
|
| 309 |
+
call_frequency_range=(300, 1500),
|
| 310 |
+
class_index=13
|
| 311 |
+
),
|
| 312 |
+
SpeciesInfo(
|
| 313 |
+
id=14,
|
| 314 |
+
scientific_name="Streptopelia decaocto",
|
| 315 |
+
common_name="Eurasian Collared Dove",
|
| 316 |
+
hindi_name="धूसर फाख्ता",
|
| 317 |
+
family="Columbidae",
|
| 318 |
+
order="Columbiformes",
|
| 319 |
+
conservation_status="LC",
|
| 320 |
+
migratory_status="Resident",
|
| 321 |
+
habitats=["Urban", "Cultivation"],
|
| 322 |
+
elevation_min=0, elevation_max=2500,
|
| 323 |
+
states=["All India"],
|
| 324 |
+
call_description="Three-note 'coo-COO-coo' with emphasis on middle",
|
| 325 |
+
call_frequency_range=(400, 1200),
|
| 326 |
+
class_index=14
|
| 327 |
+
),
|
| 328 |
+
|
| 329 |
+
# Bulbuls
|
| 330 |
+
SpeciesInfo(
|
| 331 |
+
id=15,
|
| 332 |
+
scientific_name="Pycnonotus cafer",
|
| 333 |
+
common_name="Red-vented Bulbul",
|
| 334 |
+
hindi_name="बुलबुल",
|
| 335 |
+
family="Pycnonotidae",
|
| 336 |
+
order="Passeriformes",
|
| 337 |
+
conservation_status="LC",
|
| 338 |
+
migratory_status="Resident",
|
| 339 |
+
habitats=["Garden", "Scrub", "Forest edge"],
|
| 340 |
+
elevation_min=0, elevation_max=2500,
|
| 341 |
+
states=["All India"],
|
| 342 |
+
call_description="Cheerful 'be-care-ful' and chattering",
|
| 343 |
+
call_frequency_range=(1500, 5000),
|
| 344 |
+
class_index=15
|
| 345 |
+
),
|
| 346 |
+
SpeciesInfo(
|
| 347 |
+
id=16,
|
| 348 |
+
scientific_name="Pycnonotus jocosus",
|
| 349 |
+
common_name="Red-whiskered Bulbul",
|
| 350 |
+
hindi_name="सिपाही बुलबुल",
|
| 351 |
+
family="Pycnonotidae",
|
| 352 |
+
order="Passeriformes",
|
| 353 |
+
conservation_status="LC",
|
| 354 |
+
migratory_status="Resident",
|
| 355 |
+
habitats=["Garden", "Forest edge", "Hill forest"],
|
| 356 |
+
elevation_min=0, elevation_max=2500,
|
| 357 |
+
states=["Peninsular India", "Himalayan foothills"],
|
| 358 |
+
call_description="Pleasant whistles, 'kick-pettigrew'",
|
| 359 |
+
call_frequency_range=(2000, 6000),
|
| 360 |
+
class_index=16
|
| 361 |
+
),
|
| 362 |
+
|
| 363 |
+
# Sunbirds
|
| 364 |
+
SpeciesInfo(
|
| 365 |
+
id=17,
|
| 366 |
+
scientific_name="Cinnyris asiaticus",
|
| 367 |
+
common_name="Purple Sunbird",
|
| 368 |
+
hindi_name="शक्कर खोरा",
|
| 369 |
+
family="Nectariniidae",
|
| 370 |
+
order="Passeriformes",
|
| 371 |
+
conservation_status="LC",
|
| 372 |
+
migratory_status="Resident",
|
| 373 |
+
habitats=["Garden", "Scrub", "Forest edge"],
|
| 374 |
+
elevation_min=0, elevation_max=2500,
|
| 375 |
+
states=["All India"],
|
| 376 |
+
call_description="Sharp 'chwit' and fast trilling song",
|
| 377 |
+
call_frequency_range=(3000, 8000),
|
| 378 |
+
class_index=17
|
| 379 |
+
),
|
| 380 |
+
|
| 381 |
+
# Tailorbird
|
| 382 |
+
SpeciesInfo(
|
| 383 |
+
id=18,
|
| 384 |
+
scientific_name="Orthotomus sutorius",
|
| 385 |
+
common_name="Common Tailorbird",
|
| 386 |
+
hindi_name="दर्जी चिड़िया",
|
| 387 |
+
family="Cisticolidae",
|
| 388 |
+
order="Passeriformes",
|
| 389 |
+
conservation_status="LC",
|
| 390 |
+
migratory_status="Resident",
|
| 391 |
+
habitats=["Garden", "Scrub", "Forest undergrowth"],
|
| 392 |
+
elevation_min=0, elevation_max=2000,
|
| 393 |
+
states=["All India"],
|
| 394 |
+
call_description="Loud 'towit-towit-towit' repeated",
|
| 395 |
+
call_frequency_range=(3000, 6000),
|
| 396 |
+
class_index=18
|
| 397 |
+
),
|
| 398 |
+
|
| 399 |
+
# Owls
|
| 400 |
+
SpeciesInfo(
|
| 401 |
+
id=19,
|
| 402 |
+
scientific_name="Athene brama",
|
| 403 |
+
common_name="Spotted Owlet",
|
| 404 |
+
hindi_name="खूसट",
|
| 405 |
+
family="Strigidae",
|
| 406 |
+
order="Strigiformes",
|
| 407 |
+
conservation_status="LC",
|
| 408 |
+
migratory_status="Resident",
|
| 409 |
+
habitats=["Open country", "Cultivation", "Urban"],
|
| 410 |
+
elevation_min=0, elevation_max=1500,
|
| 411 |
+
states=["All India except dense forest"],
|
| 412 |
+
call_description="Harsh chattering 'chirurr-chirurr'",
|
| 413 |
+
call_frequency_range=(1000, 4000),
|
| 414 |
+
class_index=19
|
| 415 |
+
),
|
| 416 |
+
|
| 417 |
+
# Adding more diverse species for robust testing
|
| 418 |
+
SpeciesInfo(
|
| 419 |
+
id=20,
|
| 420 |
+
scientific_name="Corvus splendens",
|
| 421 |
+
common_name="House Crow",
|
| 422 |
+
hindi_name="कौआ",
|
| 423 |
+
family="Corvidae",
|
| 424 |
+
order="Passeriformes",
|
| 425 |
+
conservation_status="LC",
|
| 426 |
+
migratory_status="Resident",
|
| 427 |
+
habitats=["Urban", "Village"],
|
| 428 |
+
elevation_min=0, elevation_max=2000,
|
| 429 |
+
states=["All India"],
|
| 430 |
+
call_description="Harsh 'kaa-kaa' cawing",
|
| 431 |
+
call_frequency_range=(800, 2500),
|
| 432 |
+
class_index=20
|
| 433 |
+
),
|
| 434 |
+
SpeciesInfo(
|
| 435 |
+
id=21,
|
| 436 |
+
scientific_name="Dicrurus macrocercus",
|
| 437 |
+
common_name="Black Drongo",
|
| 438 |
+
hindi_name="कोतवाल",
|
| 439 |
+
family="Dicruridae",
|
| 440 |
+
order="Passeriformes",
|
| 441 |
+
conservation_status="LC",
|
| 442 |
+
migratory_status="Resident",
|
| 443 |
+
habitats=["Open country", "Cultivation"],
|
| 444 |
+
elevation_min=0, elevation_max=2000,
|
| 445 |
+
states=["All India"],
|
| 446 |
+
call_description="Varied metallic calls and mimicry",
|
| 447 |
+
call_frequency_range=(2000, 6000),
|
| 448 |
+
class_index=21
|
| 449 |
+
),
|
| 450 |
+
SpeciesInfo(
|
| 451 |
+
id=22,
|
| 452 |
+
scientific_name="Oriolus kundoo",
|
| 453 |
+
common_name="Indian Golden Oriole",
|
| 454 |
+
hindi_name="पीलक",
|
| 455 |
+
family="Oriolidae",
|
| 456 |
+
order="Passeriformes",
|
| 457 |
+
conservation_status="LC",
|
| 458 |
+
migratory_status="Summer Visitor",
|
| 459 |
+
habitats=["Forest", "Garden", "Mango groves"],
|
| 460 |
+
elevation_min=0, elevation_max=2500,
|
| 461 |
+
states=["All India"],
|
| 462 |
+
call_description="Fluty 'pee-lo' whistle",
|
| 463 |
+
call_frequency_range=(1500, 4000),
|
| 464 |
+
class_index=22
|
| 465 |
+
),
|
| 466 |
+
SpeciesInfo(
|
| 467 |
+
id=23,
|
| 468 |
+
scientific_name="Upupa epops",
|
| 469 |
+
common_name="Common Hoopoe",
|
| 470 |
+
hindi_name="हुदहुद",
|
| 471 |
+
family="Upupidae",
|
| 472 |
+
order="Bucerotiformes",
|
| 473 |
+
conservation_status="LC",
|
| 474 |
+
migratory_status="Resident",
|
| 475 |
+
habitats=["Open country", "Cultivation", "Garden"],
|
| 476 |
+
elevation_min=0, elevation_max=3000,
|
| 477 |
+
states=["All India"],
|
| 478 |
+
call_description="Soft 'hoo-po-po' or 'oop-oop-oop'",
|
| 479 |
+
call_frequency_range=(500, 2000),
|
| 480 |
+
class_index=23
|
| 481 |
+
),
|
| 482 |
+
SpeciesInfo(
|
| 483 |
+
id=24,
|
| 484 |
+
scientific_name="Merops orientalis",
|
| 485 |
+
common_name="Green Bee-eater",
|
| 486 |
+
hindi_name="हरियल पतरंगा",
|
| 487 |
+
family="Meropidae",
|
| 488 |
+
order="Coraciiformes",
|
| 489 |
+
conservation_status="LC",
|
| 490 |
+
migratory_status="Resident",
|
| 491 |
+
habitats=["Open country", "Cultivation"],
|
| 492 |
+
elevation_min=0, elevation_max=2000,
|
| 493 |
+
states=["All India"],
|
| 494 |
+
call_description="Soft trilling 'tree-tree-tree'",
|
| 495 |
+
call_frequency_range=(3000, 7000),
|
| 496 |
+
class_index=24
|
| 497 |
+
),
|
| 498 |
+
]
|
| 499 |
+
|
| 500 |
+
for species in species_data:
|
| 501 |
+
self.species[species.id] = species
|
| 502 |
+
self.name_to_id[species.common_name.lower()] = species.id
|
| 503 |
+
self.name_to_id[species.scientific_name.lower()] = species.id
|
| 504 |
+
if species.hindi_name:
|
| 505 |
+
self.name_to_id[species.hindi_name] = species.id
|
| 506 |
+
|
| 507 |
+
def get_species(self, species_id: int) -> Optional[SpeciesInfo]:
|
| 508 |
+
"""Get species by ID."""
|
| 509 |
+
return self.species.get(species_id)
|
| 510 |
+
|
| 511 |
+
def get_by_name(self, name: str) -> Optional[SpeciesInfo]:
|
| 512 |
+
"""Get species by common or scientific name."""
|
| 513 |
+
species_id = self.name_to_id.get(name.lower())
|
| 514 |
+
if species_id is not None:
|
| 515 |
+
return self.species.get(species_id)
|
| 516 |
+
return None
|
| 517 |
+
|
| 518 |
+
def get_all_species(self) -> List[SpeciesInfo]:
|
| 519 |
+
"""Get all species."""
|
| 520 |
+
return list(self.species.values())
|
| 521 |
+
|
| 522 |
+
def get_species_names(self) -> List[str]:
|
| 523 |
+
"""Get list of all common names in order of class index."""
|
| 524 |
+
sorted_species = sorted(self.species.values(), key=lambda s: s.class_index)
|
| 525 |
+
return [s.common_name for s in sorted_species]
|
| 526 |
+
|
| 527 |
+
def get_num_classes(self) -> int:
|
| 528 |
+
"""Get number of species classes."""
|
| 529 |
+
return len(self.species)
|
| 530 |
+
|
| 531 |
+
def get_endemic_species(self) -> List[SpeciesInfo]:
|
| 532 |
+
"""Get species endemic to India."""
|
| 533 |
+
return [s for s in self.species.values() if s.endemic_to_india]
|
| 534 |
+
|
| 535 |
+
def get_conservation_priority(self, status: str = "VU") -> List[SpeciesInfo]:
|
| 536 |
+
"""Get species with conservation status at or above specified level."""
|
| 537 |
+
priority_order = {"LC": 0, "NT": 1, "VU": 2, "EN": 3, "CR": 4}
|
| 538 |
+
threshold = priority_order.get(status, 2)
|
| 539 |
+
return [s for s in self.species.values()
|
| 540 |
+
if priority_order.get(s.conservation_status, 0) >= threshold]
|
| 541 |
+
|
| 542 |
+
def get_species_for_llm_context(self, species_id: int) -> str:
|
| 543 |
+
"""Get formatted species information for LLM reasoning."""
|
| 544 |
+
species = self.get_species(species_id)
|
| 545 |
+
if not species:
|
| 546 |
+
return "Species not found."
|
| 547 |
+
|
| 548 |
+
return f"""
|
| 549 |
+
Species: {species.common_name} ({species.scientific_name})
|
| 550 |
+
Hindi Name: {species.hindi_name or 'N/A'}
|
| 551 |
+
Family: {species.family}
|
| 552 |
+
Conservation Status: {species.conservation_status}
|
| 553 |
+
Migratory Status: {species.migratory_status}
|
| 554 |
+
Endemic to India: {'Yes' if species.endemic_to_india else 'No'}
|
| 555 |
+
Habitats: {', '.join(species.habitats)}
|
| 556 |
+
Elevation Range: {species.elevation_min}m - {species.elevation_max}m
|
| 557 |
+
Distribution: {', '.join(species.states)}
|
| 558 |
+
Call Description: {species.call_description}
|
| 559 |
+
"""
|
| 560 |
+
|
| 561 |
+
def search_by_habitat(self, habitat: str) -> List[SpeciesInfo]:
|
| 562 |
+
"""Find species by habitat type."""
|
| 563 |
+
habitat_lower = habitat.lower()
|
| 564 |
+
return [s for s in self.species.values()
|
| 565 |
+
if any(habitat_lower in h.lower() for h in s.habitats)]
|
| 566 |
+
|
| 567 |
+
def to_json(self) -> str:
|
| 568 |
+
"""Export database to JSON."""
|
| 569 |
+
data = {s.id: {
|
| 570 |
+
"scientific_name": s.scientific_name,
|
| 571 |
+
"common_name": s.common_name,
|
| 572 |
+
"hindi_name": s.hindi_name,
|
| 573 |
+
"family": s.family,
|
| 574 |
+
"conservation_status": s.conservation_status,
|
| 575 |
+
"endemic_to_india": s.endemic_to_india,
|
| 576 |
+
"migratory_status": s.migratory_status,
|
| 577 |
+
"habitats": s.habitats,
|
| 578 |
+
"call_description": s.call_description,
|
| 579 |
+
"class_index": s.class_index
|
| 580 |
+
} for s in self.species.values()}
|
| 581 |
+
return json.dumps(data, indent=2)
|
| 582 |
+
|
llm/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""BirdSense LLM Module."""
|
| 2 |
+
|
| 3 |
+
from .ollama_client import OllamaClient
|
| 4 |
+
from .reasoning import BirdReasoningEngine
|
| 5 |
+
|
| 6 |
+
__all__ = ["OllamaClient", "BirdReasoningEngine"]
|
| 7 |
+
|
llm/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (344 Bytes). View file
|
|
|
llm/__pycache__/ollama_client.cpython-314.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|
llm/__pycache__/reasoning.cpython-314.pyc
ADDED
|
Binary file (20.4 kB). View file
|
|
|
llm/__pycache__/zero_shot_identifier.cpython-314.pyc
ADDED
|
Binary file (22.1 kB). View file
|
|
|
llm/ollama_client.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Ollama Client for BirdSense.
|
| 3 |
+
|
| 4 |
+
Provides interface to local LLM models via Ollama for:
|
| 5 |
+
- Species reasoning and verification
|
| 6 |
+
- Description matching
|
| 7 |
+
- Natural language queries about birds
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import httpx
|
| 11 |
+
import json
|
| 12 |
+
from typing import Optional, Dict, Any, List, AsyncGenerator
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
import asyncio
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class OllamaConfig:
|
| 19 |
+
"""Configuration for Ollama client."""
|
| 20 |
+
base_url: str = "http://localhost:11434"
|
| 21 |
+
model: str = "phi3:mini" # Lightweight model for edge deployment
|
| 22 |
+
temperature: float = 0.3
|
| 23 |
+
max_tokens: int = 512
|
| 24 |
+
timeout: int = 30
|
| 25 |
+
stream: bool = False
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class OllamaClient:
|
| 29 |
+
"""
|
| 30 |
+
Async client for Ollama API.
|
| 31 |
+
|
| 32 |
+
Supports:
|
| 33 |
+
- Text generation
|
| 34 |
+
- Streaming responses
|
| 35 |
+
- Model listing and management
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, config: Optional[OllamaConfig] = None):
|
| 39 |
+
self.config = config or OllamaConfig()
|
| 40 |
+
self._client: Optional[httpx.AsyncClient] = None
|
| 41 |
+
|
| 42 |
+
async def __aenter__(self):
|
| 43 |
+
self._client = httpx.AsyncClient(
|
| 44 |
+
base_url=self.config.base_url,
|
| 45 |
+
timeout=httpx.Timeout(self.config.timeout)
|
| 46 |
+
)
|
| 47 |
+
return self
|
| 48 |
+
|
| 49 |
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
| 50 |
+
if self._client:
|
| 51 |
+
await self._client.aclose()
|
| 52 |
+
|
| 53 |
+
@property
|
| 54 |
+
def client(self) -> httpx.AsyncClient:
|
| 55 |
+
if self._client is None:
|
| 56 |
+
self._client = httpx.AsyncClient(
|
| 57 |
+
base_url=self.config.base_url,
|
| 58 |
+
timeout=httpx.Timeout(self.config.timeout)
|
| 59 |
+
)
|
| 60 |
+
return self._client
|
| 61 |
+
|
| 62 |
+
async def generate(
|
| 63 |
+
self,
|
| 64 |
+
prompt: str,
|
| 65 |
+
system_prompt: Optional[str] = None,
|
| 66 |
+
temperature: Optional[float] = None,
|
| 67 |
+
max_tokens: Optional[int] = None,
|
| 68 |
+
model: Optional[str] = None
|
| 69 |
+
) -> str:
|
| 70 |
+
"""
|
| 71 |
+
Generate text completion.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
prompt: User prompt
|
| 75 |
+
system_prompt: System instruction
|
| 76 |
+
temperature: Sampling temperature (default from config)
|
| 77 |
+
max_tokens: Max tokens to generate
|
| 78 |
+
model: Model to use (default from config)
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
Generated text response
|
| 82 |
+
"""
|
| 83 |
+
payload = {
|
| 84 |
+
"model": model or self.config.model,
|
| 85 |
+
"prompt": prompt,
|
| 86 |
+
"stream": False,
|
| 87 |
+
"options": {
|
| 88 |
+
"temperature": temperature or self.config.temperature,
|
| 89 |
+
"num_predict": max_tokens or self.config.max_tokens
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
if system_prompt:
|
| 94 |
+
payload["system"] = system_prompt
|
| 95 |
+
|
| 96 |
+
try:
|
| 97 |
+
response = await self.client.post("/api/generate", json=payload)
|
| 98 |
+
response.raise_for_status()
|
| 99 |
+
result = response.json()
|
| 100 |
+
return result.get("response", "")
|
| 101 |
+
except httpx.HTTPError as e:
|
| 102 |
+
raise ConnectionError(f"Failed to connect to Ollama: {e}")
|
| 103 |
+
|
| 104 |
+
async def generate_stream(
|
| 105 |
+
self,
|
| 106 |
+
prompt: str,
|
| 107 |
+
system_prompt: Optional[str] = None,
|
| 108 |
+
model: Optional[str] = None
|
| 109 |
+
) -> AsyncGenerator[str, None]:
|
| 110 |
+
"""
|
| 111 |
+
Stream text generation.
|
| 112 |
+
|
| 113 |
+
Yields:
|
| 114 |
+
Chunks of generated text
|
| 115 |
+
"""
|
| 116 |
+
payload = {
|
| 117 |
+
"model": model or self.config.model,
|
| 118 |
+
"prompt": prompt,
|
| 119 |
+
"stream": True,
|
| 120 |
+
"options": {
|
| 121 |
+
"temperature": self.config.temperature,
|
| 122 |
+
"num_predict": self.config.max_tokens
|
| 123 |
+
}
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
if system_prompt:
|
| 127 |
+
payload["system"] = system_prompt
|
| 128 |
+
|
| 129 |
+
async with self.client.stream("POST", "/api/generate", json=payload) as response:
|
| 130 |
+
async for line in response.aiter_lines():
|
| 131 |
+
if line:
|
| 132 |
+
data = json.loads(line)
|
| 133 |
+
if "response" in data:
|
| 134 |
+
yield data["response"]
|
| 135 |
+
if data.get("done", False):
|
| 136 |
+
break
|
| 137 |
+
|
| 138 |
+
async def chat(
|
| 139 |
+
self,
|
| 140 |
+
messages: List[Dict[str, str]],
|
| 141 |
+
model: Optional[str] = None
|
| 142 |
+
) -> str:
|
| 143 |
+
"""
|
| 144 |
+
Chat completion with message history.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
messages: List of {"role": "user/assistant/system", "content": "..."}
|
| 148 |
+
model: Model to use
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
Assistant response
|
| 152 |
+
"""
|
| 153 |
+
payload = {
|
| 154 |
+
"model": model or self.config.model,
|
| 155 |
+
"messages": messages,
|
| 156 |
+
"stream": False,
|
| 157 |
+
"options": {
|
| 158 |
+
"temperature": self.config.temperature,
|
| 159 |
+
"num_predict": self.config.max_tokens
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
try:
|
| 164 |
+
response = await self.client.post("/api/chat", json=payload)
|
| 165 |
+
response.raise_for_status()
|
| 166 |
+
result = response.json()
|
| 167 |
+
return result.get("message", {}).get("content", "")
|
| 168 |
+
except httpx.HTTPError as e:
|
| 169 |
+
raise ConnectionError(f"Failed to connect to Ollama: {e}")
|
| 170 |
+
|
| 171 |
+
async def list_models(self) -> List[Dict[str, Any]]:
|
| 172 |
+
"""List available models."""
|
| 173 |
+
try:
|
| 174 |
+
response = await self.client.get("/api/tags")
|
| 175 |
+
response.raise_for_status()
|
| 176 |
+
return response.json().get("models", [])
|
| 177 |
+
except httpx.HTTPError as e:
|
| 178 |
+
raise ConnectionError(f"Failed to list models: {e}")
|
| 179 |
+
|
| 180 |
+
async def is_model_available(self, model: Optional[str] = None) -> bool:
|
| 181 |
+
"""Check if specified model is available."""
|
| 182 |
+
model = model or self.config.model
|
| 183 |
+
try:
|
| 184 |
+
models = await self.list_models()
|
| 185 |
+
return any(m.get("name", "").startswith(model.split(":")[0]) for m in models)
|
| 186 |
+
except Exception:
|
| 187 |
+
return False
|
| 188 |
+
|
| 189 |
+
async def health_check(self) -> bool:
|
| 190 |
+
"""Check if Ollama server is running."""
|
| 191 |
+
try:
|
| 192 |
+
response = await self.client.get("/api/tags")
|
| 193 |
+
return response.status_code == 200
|
| 194 |
+
except Exception:
|
| 195 |
+
return False
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class SyncOllamaClient:
|
| 199 |
+
"""
|
| 200 |
+
Synchronous wrapper for OllamaClient.
|
| 201 |
+
|
| 202 |
+
Convenience class for non-async code paths.
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
def __init__(self, config: Optional[OllamaConfig] = None):
|
| 206 |
+
self.config = config or OllamaConfig()
|
| 207 |
+
self._async_client = OllamaClient(config)
|
| 208 |
+
|
| 209 |
+
def _run(self, coro):
|
| 210 |
+
"""Run async coroutine synchronously."""
|
| 211 |
+
try:
|
| 212 |
+
loop = asyncio.get_event_loop()
|
| 213 |
+
if loop.is_running():
|
| 214 |
+
# If we're in an async context, use nest_asyncio pattern
|
| 215 |
+
import nest_asyncio
|
| 216 |
+
nest_asyncio.apply()
|
| 217 |
+
return loop.run_until_complete(coro)
|
| 218 |
+
else:
|
| 219 |
+
return loop.run_until_complete(coro)
|
| 220 |
+
except RuntimeError:
|
| 221 |
+
# No event loop exists
|
| 222 |
+
return asyncio.run(coro)
|
| 223 |
+
|
| 224 |
+
def generate(
|
| 225 |
+
self,
|
| 226 |
+
prompt: str,
|
| 227 |
+
system_prompt: Optional[str] = None,
|
| 228 |
+
temperature: Optional[float] = None,
|
| 229 |
+
max_tokens: Optional[int] = None,
|
| 230 |
+
model: Optional[str] = None
|
| 231 |
+
) -> str:
|
| 232 |
+
"""Generate text completion synchronously."""
|
| 233 |
+
return self._run(
|
| 234 |
+
self._async_client.generate(
|
| 235 |
+
prompt, system_prompt, temperature, max_tokens, model
|
| 236 |
+
)
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
def chat(
|
| 240 |
+
self,
|
| 241 |
+
messages: List[Dict[str, str]],
|
| 242 |
+
model: Optional[str] = None
|
| 243 |
+
) -> str:
|
| 244 |
+
"""Chat completion synchronously."""
|
| 245 |
+
return self._run(self._async_client.chat(messages, model))
|
| 246 |
+
|
| 247 |
+
def health_check(self) -> bool:
|
| 248 |
+
"""Check Ollama health synchronously."""
|
| 249 |
+
return self._run(self._async_client.health_check())
|
| 250 |
+
|
| 251 |
+
def is_model_available(self, model: Optional[str] = None) -> bool:
|
| 252 |
+
"""Check model availability synchronously."""
|
| 253 |
+
return self._run(self._async_client.is_model_available(model))
|
| 254 |
+
|
llm/reasoning.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Bird Reasoning Engine for BirdSense.
|
| 3 |
+
|
| 4 |
+
Uses LLM to enhance bird species identification through:
|
| 5 |
+
- Multi-evidence reasoning (audio, visual, description)
|
| 6 |
+
- Habitat and range validation
|
| 7 |
+
- Confidence calibration
|
| 8 |
+
- Natural language explanation generation
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from typing import Optional, Dict, List, Tuple
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
import json
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from .ollama_client import OllamaClient, OllamaConfig, SyncOllamaClient
|
| 17 |
+
from ..data.species_db import IndiaSpeciesDatabase, SpeciesInfo
|
| 18 |
+
except ImportError:
|
| 19 |
+
from llm.ollama_client import OllamaClient, OllamaConfig, SyncOllamaClient
|
| 20 |
+
from data.species_db import IndiaSpeciesDatabase, SpeciesInfo
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class ReasoningContext:
|
| 25 |
+
"""Context for species reasoning."""
|
| 26 |
+
# Audio analysis results
|
| 27 |
+
audio_predictions: List[Tuple[int, float]] = None # [(species_id, confidence), ...]
|
| 28 |
+
audio_quality: str = "unknown"
|
| 29 |
+
|
| 30 |
+
# Location context
|
| 31 |
+
latitude: Optional[float] = None
|
| 32 |
+
longitude: Optional[float] = None
|
| 33 |
+
location_name: Optional[str] = None
|
| 34 |
+
|
| 35 |
+
# Temporal context
|
| 36 |
+
month: Optional[int] = None
|
| 37 |
+
time_of_day: Optional[str] = None # morning, afternoon, evening, night
|
| 38 |
+
|
| 39 |
+
# Habitat context
|
| 40 |
+
habitat: Optional[str] = None
|
| 41 |
+
elevation: Optional[int] = None
|
| 42 |
+
|
| 43 |
+
# User description (if any)
|
| 44 |
+
user_description: Optional[str] = None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
|
| 48 |
+
class ReasoningResult:
|
| 49 |
+
"""Result of species reasoning."""
|
| 50 |
+
species_id: int
|
| 51 |
+
species_name: str
|
| 52 |
+
confidence: float
|
| 53 |
+
reasoning: str
|
| 54 |
+
alternative_species: List[Tuple[str, float]]
|
| 55 |
+
novelty_flag: bool
|
| 56 |
+
novelty_explanation: Optional[str]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
SYSTEM_PROMPT = """You are an expert ornithologist specializing in Indian birds. Your role is to:
|
| 60 |
+
1. Analyze bird identification evidence from audio, visual, and contextual clues
|
| 61 |
+
2. Consider habitat, range, season, and time of day to validate identifications
|
| 62 |
+
3. Flag unusual or out-of-range sightings that could be scientifically significant
|
| 63 |
+
4. Provide clear, educational explanations
|
| 64 |
+
|
| 65 |
+
When analyzing bird identifications:
|
| 66 |
+
- Consider the probability of the species being present at the given location and time
|
| 67 |
+
- Note if the species is commonly confused with similar species
|
| 68 |
+
- Be aware of seasonal migration patterns
|
| 69 |
+
- Flag any sightings that would be unusual or noteworthy
|
| 70 |
+
|
| 71 |
+
Respond in a structured format with your reasoning and final assessment."""
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class BirdReasoningEngine:
|
| 75 |
+
"""
|
| 76 |
+
LLM-powered reasoning engine for bird identification.
|
| 77 |
+
|
| 78 |
+
Combines audio classifier predictions with contextual information
|
| 79 |
+
to produce calibrated, explainable species identifications.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
def __init__(
|
| 83 |
+
self,
|
| 84 |
+
ollama_config: Optional[OllamaConfig] = None,
|
| 85 |
+
species_db: Optional[IndiaSpeciesDatabase] = None
|
| 86 |
+
):
|
| 87 |
+
self.ollama_config = ollama_config or OllamaConfig()
|
| 88 |
+
self.species_db = species_db or IndiaSpeciesDatabase()
|
| 89 |
+
self.sync_client = SyncOllamaClient(self.ollama_config)
|
| 90 |
+
|
| 91 |
+
def _build_reasoning_prompt(
|
| 92 |
+
self,
|
| 93 |
+
context: ReasoningContext
|
| 94 |
+
) -> str:
|
| 95 |
+
"""Build prompt for species reasoning."""
|
| 96 |
+
prompt_parts = []
|
| 97 |
+
|
| 98 |
+
# Audio predictions
|
| 99 |
+
if context.audio_predictions:
|
| 100 |
+
prompt_parts.append("## Audio Analysis Results")
|
| 101 |
+
for species_id, confidence in context.audio_predictions[:5]:
|
| 102 |
+
species = self.species_db.get_species(species_id)
|
| 103 |
+
if species:
|
| 104 |
+
prompt_parts.append(
|
| 105 |
+
f"- {species.common_name} ({species.scientific_name}): "
|
| 106 |
+
f"{confidence:.1%} confidence"
|
| 107 |
+
)
|
| 108 |
+
prompt_parts.append(f" Call: {species.call_description}")
|
| 109 |
+
prompt_parts.append(f"Audio Quality: {context.audio_quality}")
|
| 110 |
+
prompt_parts.append("")
|
| 111 |
+
|
| 112 |
+
# Location context
|
| 113 |
+
if context.location_name or (context.latitude and context.longitude):
|
| 114 |
+
prompt_parts.append("## Location")
|
| 115 |
+
if context.location_name:
|
| 116 |
+
prompt_parts.append(f"- Location: {context.location_name}")
|
| 117 |
+
if context.latitude and context.longitude:
|
| 118 |
+
prompt_parts.append(f"- Coordinates: {context.latitude:.4f}°N, {context.longitude:.4f}°E")
|
| 119 |
+
if context.elevation:
|
| 120 |
+
prompt_parts.append(f"- Elevation: {context.elevation}m")
|
| 121 |
+
prompt_parts.append("")
|
| 122 |
+
|
| 123 |
+
# Temporal context
|
| 124 |
+
if context.month or context.time_of_day:
|
| 125 |
+
prompt_parts.append("## Time")
|
| 126 |
+
if context.month:
|
| 127 |
+
months = ["January", "February", "March", "April", "May", "June",
|
| 128 |
+
"July", "August", "September", "October", "November", "December"]
|
| 129 |
+
prompt_parts.append(f"- Month: {months[context.month - 1]}")
|
| 130 |
+
if context.time_of_day:
|
| 131 |
+
prompt_parts.append(f"- Time of Day: {context.time_of_day}")
|
| 132 |
+
prompt_parts.append("")
|
| 133 |
+
|
| 134 |
+
# Habitat
|
| 135 |
+
if context.habitat:
|
| 136 |
+
prompt_parts.append(f"## Habitat: {context.habitat}")
|
| 137 |
+
prompt_parts.append("")
|
| 138 |
+
|
| 139 |
+
# User description
|
| 140 |
+
if context.user_description:
|
| 141 |
+
prompt_parts.append("## Observer Description")
|
| 142 |
+
prompt_parts.append(context.user_description)
|
| 143 |
+
prompt_parts.append("")
|
| 144 |
+
|
| 145 |
+
prompt_parts.append("""## Task
|
| 146 |
+
Based on the above evidence, provide:
|
| 147 |
+
1. Your assessment of the most likely species
|
| 148 |
+
2. Confidence level (high/medium/low) with reasoning
|
| 149 |
+
3. Alternative species to consider
|
| 150 |
+
4. Whether this sighting is unusual or noteworthy for research
|
| 151 |
+
5. Any identifying features that would help confirm the identification
|
| 152 |
+
|
| 153 |
+
Format your response as:
|
| 154 |
+
ASSESSMENT: [Species name]
|
| 155 |
+
CONFIDENCE: [high/medium/low]
|
| 156 |
+
REASONING: [Your detailed reasoning]
|
| 157 |
+
ALTERNATIVES: [List of alternative species with brief notes]
|
| 158 |
+
NOTABLE: [yes/no] - [Explanation if yes]
|
| 159 |
+
""")
|
| 160 |
+
|
| 161 |
+
return "\n".join(prompt_parts)
|
| 162 |
+
|
| 163 |
+
def reason(
|
| 164 |
+
self,
|
| 165 |
+
context: ReasoningContext
|
| 166 |
+
) -> ReasoningResult:
|
| 167 |
+
"""
|
| 168 |
+
Perform species reasoning using LLM.
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
context: Reasoning context with all available evidence
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
ReasoningResult with final species assessment
|
| 175 |
+
"""
|
| 176 |
+
prompt = self._build_reasoning_prompt(context)
|
| 177 |
+
|
| 178 |
+
try:
|
| 179 |
+
response = self.sync_client.generate(
|
| 180 |
+
prompt=prompt,
|
| 181 |
+
system_prompt=SYSTEM_PROMPT
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# Parse response
|
| 185 |
+
return self._parse_response(response, context)
|
| 186 |
+
|
| 187 |
+
except Exception as e:
|
| 188 |
+
# Fallback to audio-only prediction if LLM fails
|
| 189 |
+
if context.audio_predictions:
|
| 190 |
+
top_pred = context.audio_predictions[0]
|
| 191 |
+
species = self.species_db.get_species(top_pred[0])
|
| 192 |
+
return ReasoningResult(
|
| 193 |
+
species_id=top_pred[0],
|
| 194 |
+
species_name=species.common_name if species else "Unknown",
|
| 195 |
+
confidence=top_pred[1],
|
| 196 |
+
reasoning=f"LLM reasoning unavailable ({str(e)}). Using audio prediction only.",
|
| 197 |
+
alternative_species=[],
|
| 198 |
+
novelty_flag=False,
|
| 199 |
+
novelty_explanation=None
|
| 200 |
+
)
|
| 201 |
+
raise
|
| 202 |
+
|
| 203 |
+
def _parse_response(
|
| 204 |
+
self,
|
| 205 |
+
response: str,
|
| 206 |
+
context: ReasoningContext
|
| 207 |
+
) -> ReasoningResult:
|
| 208 |
+
"""Parse LLM response into structured result."""
|
| 209 |
+
lines = response.strip().split('\n')
|
| 210 |
+
|
| 211 |
+
assessment = ""
|
| 212 |
+
confidence = 0.5
|
| 213 |
+
reasoning = ""
|
| 214 |
+
alternatives = []
|
| 215 |
+
notable = False
|
| 216 |
+
notable_explanation = None
|
| 217 |
+
|
| 218 |
+
current_section = None
|
| 219 |
+
|
| 220 |
+
for line in lines:
|
| 221 |
+
line = line.strip()
|
| 222 |
+
|
| 223 |
+
if line.startswith("ASSESSMENT:"):
|
| 224 |
+
assessment = line.split(":", 1)[1].strip()
|
| 225 |
+
current_section = "assessment"
|
| 226 |
+
elif line.startswith("CONFIDENCE:"):
|
| 227 |
+
conf_text = line.split(":", 1)[1].strip().lower()
|
| 228 |
+
if "high" in conf_text:
|
| 229 |
+
confidence = 0.85
|
| 230 |
+
elif "medium" in conf_text:
|
| 231 |
+
confidence = 0.6
|
| 232 |
+
elif "low" in conf_text:
|
| 233 |
+
confidence = 0.35
|
| 234 |
+
current_section = "confidence"
|
| 235 |
+
elif line.startswith("REASONING:"):
|
| 236 |
+
reasoning = line.split(":", 1)[1].strip()
|
| 237 |
+
current_section = "reasoning"
|
| 238 |
+
elif line.startswith("ALTERNATIVES:"):
|
| 239 |
+
alt_text = line.split(":", 1)[1].strip()
|
| 240 |
+
if alt_text:
|
| 241 |
+
alternatives = [(a.strip(), 0.0) for a in alt_text.split(",")]
|
| 242 |
+
current_section = "alternatives"
|
| 243 |
+
elif line.startswith("NOTABLE:"):
|
| 244 |
+
notable_text = line.split(":", 1)[1].strip().lower()
|
| 245 |
+
notable = "yes" in notable_text.split("-")[0]
|
| 246 |
+
if notable and "-" in notable_text:
|
| 247 |
+
notable_explanation = notable_text.split("-", 1)[1].strip()
|
| 248 |
+
current_section = "notable"
|
| 249 |
+
elif current_section == "reasoning" and line:
|
| 250 |
+
reasoning += " " + line
|
| 251 |
+
elif current_section == "alternatives" and line and line.startswith("-"):
|
| 252 |
+
alternatives.append((line[1:].strip(), 0.0))
|
| 253 |
+
|
| 254 |
+
# Find species ID
|
| 255 |
+
species_id = -1
|
| 256 |
+
species = self.species_db.get_by_name(assessment)
|
| 257 |
+
if species:
|
| 258 |
+
species_id = species.id
|
| 259 |
+
elif context.audio_predictions:
|
| 260 |
+
species_id = context.audio_predictions[0][0]
|
| 261 |
+
species = self.species_db.get_species(species_id)
|
| 262 |
+
if species:
|
| 263 |
+
assessment = species.common_name
|
| 264 |
+
|
| 265 |
+
return ReasoningResult(
|
| 266 |
+
species_id=species_id,
|
| 267 |
+
species_name=assessment,
|
| 268 |
+
confidence=confidence,
|
| 269 |
+
reasoning=reasoning,
|
| 270 |
+
alternative_species=alternatives,
|
| 271 |
+
novelty_flag=notable,
|
| 272 |
+
novelty_explanation=notable_explanation
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
async def reason_async(
|
| 276 |
+
self,
|
| 277 |
+
context: ReasoningContext
|
| 278 |
+
) -> ReasoningResult:
|
| 279 |
+
"""Async version of reason()."""
|
| 280 |
+
prompt = self._build_reasoning_prompt(context)
|
| 281 |
+
|
| 282 |
+
async with OllamaClient(self.ollama_config) as client:
|
| 283 |
+
response = await client.generate(
|
| 284 |
+
prompt=prompt,
|
| 285 |
+
system_prompt=SYSTEM_PROMPT
|
| 286 |
+
)
|
| 287 |
+
return self._parse_response(response, context)
|
| 288 |
+
|
| 289 |
+
def generate_description(
|
| 290 |
+
self,
|
| 291 |
+
species_id: int,
|
| 292 |
+
include_calls: bool = True,
|
| 293 |
+
include_habitat: bool = True
|
| 294 |
+
) -> str:
|
| 295 |
+
"""
|
| 296 |
+
Generate natural language description of a species.
|
| 297 |
+
|
| 298 |
+
Useful for educational purposes and matching user descriptions.
|
| 299 |
+
"""
|
| 300 |
+
species = self.species_db.get_species(species_id)
|
| 301 |
+
if not species:
|
| 302 |
+
return "Species not found."
|
| 303 |
+
|
| 304 |
+
prompt = f"""Generate a brief, informative description of the {species.common_name}
|
| 305 |
+
({species.scientific_name}) for birdwatchers in India.
|
| 306 |
+
|
| 307 |
+
Species information:
|
| 308 |
+
{self.species_db.get_species_for_llm_context(species_id)}
|
| 309 |
+
|
| 310 |
+
Include:
|
| 311 |
+
- Key identifying features
|
| 312 |
+
{"- Distinctive calls and songs" if include_calls else ""}
|
| 313 |
+
{"- Typical habitat and where to find it" if include_habitat else ""}
|
| 314 |
+
- Interesting facts
|
| 315 |
+
|
| 316 |
+
Keep it concise (2-3 paragraphs)."""
|
| 317 |
+
|
| 318 |
+
try:
|
| 319 |
+
return self.sync_client.generate(prompt=prompt)
|
| 320 |
+
except Exception as e:
|
| 321 |
+
# Fallback to database info
|
| 322 |
+
return self.species_db.get_species_for_llm_context(species_id)
|
| 323 |
+
|
| 324 |
+
def match_description(
|
| 325 |
+
self,
|
| 326 |
+
user_description: str,
|
| 327 |
+
candidates: Optional[List[int]] = None
|
| 328 |
+
) -> List[Tuple[int, float, str]]:
|
| 329 |
+
"""
|
| 330 |
+
Match user description to species.
|
| 331 |
+
|
| 332 |
+
Args:
|
| 333 |
+
user_description: User's description of the bird
|
| 334 |
+
candidates: Optional list of species IDs to consider
|
| 335 |
+
|
| 336 |
+
Returns:
|
| 337 |
+
List of (species_id, match_score, explanation)
|
| 338 |
+
"""
|
| 339 |
+
if candidates is None:
|
| 340 |
+
candidates = list(self.species_db.species.keys())
|
| 341 |
+
|
| 342 |
+
# Build context for matching
|
| 343 |
+
species_info = []
|
| 344 |
+
for species_id in candidates[:20]: # Limit for efficiency
|
| 345 |
+
species = self.species_db.get_species(species_id)
|
| 346 |
+
if species:
|
| 347 |
+
species_info.append(f"- {species.common_name}: {species.call_description}")
|
| 348 |
+
|
| 349 |
+
prompt = f"""Match this bird description to the most likely species:
|
| 350 |
+
|
| 351 |
+
User Description: "{user_description}"
|
| 352 |
+
|
| 353 |
+
Candidate Species:
|
| 354 |
+
{chr(10).join(species_info)}
|
| 355 |
+
|
| 356 |
+
List the top 3 matches with confidence (0-100%) and brief explanation:
|
| 357 |
+
Format: [Species Name] - [confidence]% - [reason]"""
|
| 358 |
+
|
| 359 |
+
try:
|
| 360 |
+
response = self.sync_client.generate(prompt=prompt)
|
| 361 |
+
|
| 362 |
+
# Parse matches from response
|
| 363 |
+
matches = []
|
| 364 |
+
for line in response.split('\n'):
|
| 365 |
+
if '-' in line and '%' in line:
|
| 366 |
+
parts = line.split('-')
|
| 367 |
+
if len(parts) >= 2:
|
| 368 |
+
name = parts[0].strip().lstrip('0123456789. ')
|
| 369 |
+
species = self.species_db.get_by_name(name)
|
| 370 |
+
if species:
|
| 371 |
+
# Extract confidence
|
| 372 |
+
conf_part = parts[1] if len(parts) > 1 else ""
|
| 373 |
+
try:
|
| 374 |
+
conf = float(''.join(c for c in conf_part if c.isdigit())) / 100
|
| 375 |
+
except ValueError:
|
| 376 |
+
conf = 0.5
|
| 377 |
+
explanation = parts[2].strip() if len(parts) > 2 else ""
|
| 378 |
+
matches.append((species.id, min(1.0, conf), explanation))
|
| 379 |
+
|
| 380 |
+
return matches
|
| 381 |
+
|
| 382 |
+
except Exception:
|
| 383 |
+
return []
|
| 384 |
+
|
| 385 |
+
def check_ollama_status(self) -> Dict[str, any]:
|
| 386 |
+
"""Check Ollama server and model status."""
|
| 387 |
+
try:
|
| 388 |
+
is_healthy = self.sync_client.health_check()
|
| 389 |
+
is_model_available = self.sync_client.is_model_available()
|
| 390 |
+
|
| 391 |
+
return {
|
| 392 |
+
"server_running": is_healthy,
|
| 393 |
+
"model_available": is_model_available,
|
| 394 |
+
"model_name": self.ollama_config.model,
|
| 395 |
+
"status": "ready" if (is_healthy and is_model_available) else "not_ready"
|
| 396 |
+
}
|
| 397 |
+
except Exception as e:
|
| 398 |
+
return {
|
| 399 |
+
"server_running": False,
|
| 400 |
+
"model_available": False,
|
| 401 |
+
"model_name": self.ollama_config.model,
|
| 402 |
+
"status": "error",
|
| 403 |
+
"error": str(e)
|
| 404 |
+
}
|
| 405 |
+
|
llm/zero_shot_identifier.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Zero-Shot Bird Identification using LLM.
|
| 3 |
+
|
| 4 |
+
This is the CORE innovation: Instead of training on every bird,
|
| 5 |
+
we use the LLM's knowledge to identify ANY bird from audio features.
|
| 6 |
+
|
| 7 |
+
The LLM has learned about thousands of bird species from its training data,
|
| 8 |
+
including their calls, habitats, and behaviors.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import logging
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
from .ollama_client import OllamaClient, OllamaConfig
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class AudioFeatures:
|
| 24 |
+
"""Extracted audio features for LLM analysis."""
|
| 25 |
+
duration: float
|
| 26 |
+
dominant_frequency_hz: float
|
| 27 |
+
frequency_range: Tuple[float, float]
|
| 28 |
+
spectral_centroid: float
|
| 29 |
+
spectral_bandwidth: float
|
| 30 |
+
tempo_bpm: float
|
| 31 |
+
num_syllables: int
|
| 32 |
+
syllable_rate: float # syllables per second
|
| 33 |
+
is_melodic: bool
|
| 34 |
+
is_repetitive: bool
|
| 35 |
+
amplitude_pattern: str # "constant", "rising", "falling", "varied"
|
| 36 |
+
estimated_snr_db: float
|
| 37 |
+
quality_score: float
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
|
| 41 |
+
class ZeroShotResult:
|
| 42 |
+
"""Result from zero-shot identification."""
|
| 43 |
+
species_name: str
|
| 44 |
+
scientific_name: str
|
| 45 |
+
confidence: float # 0.0 to 1.0
|
| 46 |
+
confidence_label: str # "high", "medium", "low"
|
| 47 |
+
reasoning: str
|
| 48 |
+
key_features_matched: List[str]
|
| 49 |
+
alternative_species: List[Dict[str, Any]]
|
| 50 |
+
is_indian_bird: bool
|
| 51 |
+
is_unusual_sighting: bool
|
| 52 |
+
unusual_reason: Optional[str]
|
| 53 |
+
call_description: str
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class ZeroShotBirdIdentifier:
|
| 57 |
+
"""
|
| 58 |
+
Zero-shot bird identification using LLM.
|
| 59 |
+
|
| 60 |
+
This approach:
|
| 61 |
+
1. Extracts audio features (frequency, pattern, duration)
|
| 62 |
+
2. Sends features to LLM with expert prompt
|
| 63 |
+
3. LLM identifies bird from its knowledge base
|
| 64 |
+
4. Returns species with confidence and reasoning
|
| 65 |
+
|
| 66 |
+
Benefits:
|
| 67 |
+
- No training required
|
| 68 |
+
- Can identify ANY of 10,000+ bird species
|
| 69 |
+
- Works for non-Indian birds too (with novelty flag)
|
| 70 |
+
- Explainable results
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(self, ollama_config: Optional[OllamaConfig] = None):
|
| 74 |
+
self.ollama = OllamaClient(ollama_config or OllamaConfig(model="qwen2.5:3b"))
|
| 75 |
+
self.is_ready = False
|
| 76 |
+
|
| 77 |
+
def initialize(self) -> bool:
|
| 78 |
+
"""Check if LLM is available."""
|
| 79 |
+
try:
|
| 80 |
+
status = self.ollama.check_status()
|
| 81 |
+
self.is_ready = status.get("status") == "ready"
|
| 82 |
+
return self.is_ready
|
| 83 |
+
except:
|
| 84 |
+
return False
|
| 85 |
+
|
| 86 |
+
def extract_features(
|
| 87 |
+
self,
|
| 88 |
+
audio: np.ndarray,
|
| 89 |
+
sample_rate: int = 32000,
|
| 90 |
+
mel_spec: Optional[np.ndarray] = None
|
| 91 |
+
) -> AudioFeatures:
|
| 92 |
+
"""Extract audio features for LLM analysis."""
|
| 93 |
+
import scipy.signal as signal
|
| 94 |
+
|
| 95 |
+
duration = len(audio) / sample_rate
|
| 96 |
+
|
| 97 |
+
# Frequency analysis
|
| 98 |
+
freqs, psd = signal.welch(audio, sample_rate, nperseg=2048)
|
| 99 |
+
|
| 100 |
+
# Dominant frequency
|
| 101 |
+
dominant_idx = np.argmax(psd)
|
| 102 |
+
dominant_freq = freqs[dominant_idx]
|
| 103 |
+
|
| 104 |
+
# Frequency range (where 90% of energy is)
|
| 105 |
+
cumsum = np.cumsum(psd) / np.sum(psd)
|
| 106 |
+
freq_low = freqs[np.searchsorted(cumsum, 0.05)]
|
| 107 |
+
freq_high = freqs[np.searchsorted(cumsum, 0.95)]
|
| 108 |
+
|
| 109 |
+
# Spectral centroid
|
| 110 |
+
spectral_centroid = np.sum(freqs * psd) / (np.sum(psd) + 1e-10)
|
| 111 |
+
|
| 112 |
+
# Spectral bandwidth
|
| 113 |
+
spectral_bandwidth = np.sqrt(np.sum(((freqs - spectral_centroid) ** 2) * psd) / (np.sum(psd) + 1e-10))
|
| 114 |
+
|
| 115 |
+
# Amplitude envelope analysis
|
| 116 |
+
envelope = np.abs(signal.hilbert(audio))
|
| 117 |
+
envelope_smooth = signal.medfilt(envelope, 1001)
|
| 118 |
+
|
| 119 |
+
# Detect syllables (peaks in envelope)
|
| 120 |
+
peaks, _ = signal.find_peaks(envelope_smooth, height=0.1 * np.max(envelope_smooth), distance=sample_rate // 10)
|
| 121 |
+
num_syllables = len(peaks)
|
| 122 |
+
syllable_rate = num_syllables / duration if duration > 0 else 0
|
| 123 |
+
|
| 124 |
+
# Amplitude pattern
|
| 125 |
+
if len(envelope_smooth) > 100:
|
| 126 |
+
start_amp = np.mean(envelope_smooth[:len(envelope_smooth)//4])
|
| 127 |
+
end_amp = np.mean(envelope_smooth[-len(envelope_smooth)//4:])
|
| 128 |
+
amp_var = np.std(envelope_smooth) / (np.mean(envelope_smooth) + 1e-10)
|
| 129 |
+
|
| 130 |
+
if amp_var > 0.5:
|
| 131 |
+
amp_pattern = "varied"
|
| 132 |
+
elif end_amp > start_amp * 1.3:
|
| 133 |
+
amp_pattern = "rising"
|
| 134 |
+
elif end_amp < start_amp * 0.7:
|
| 135 |
+
amp_pattern = "falling"
|
| 136 |
+
else:
|
| 137 |
+
amp_pattern = "constant"
|
| 138 |
+
else:
|
| 139 |
+
amp_pattern = "constant"
|
| 140 |
+
|
| 141 |
+
# Melodic detection (frequency variation)
|
| 142 |
+
if len(audio) > sample_rate:
|
| 143 |
+
chunks = np.array_split(audio, 10)
|
| 144 |
+
chunk_freqs = []
|
| 145 |
+
for chunk in chunks:
|
| 146 |
+
if len(chunk) > 512:
|
| 147 |
+
f, p = signal.welch(chunk, sample_rate, nperseg=512)
|
| 148 |
+
chunk_freqs.append(f[np.argmax(p)])
|
| 149 |
+
freq_variation = np.std(chunk_freqs) / (np.mean(chunk_freqs) + 1e-10)
|
| 150 |
+
is_melodic = freq_variation > 0.1
|
| 151 |
+
else:
|
| 152 |
+
is_melodic = False
|
| 153 |
+
|
| 154 |
+
# Repetitiveness detection
|
| 155 |
+
if num_syllables >= 3:
|
| 156 |
+
if syllable_rate > 1.5 and syllable_rate < 10: # Regular pattern
|
| 157 |
+
is_repetitive = True
|
| 158 |
+
else:
|
| 159 |
+
is_repetitive = False
|
| 160 |
+
else:
|
| 161 |
+
is_repetitive = num_syllables >= 2
|
| 162 |
+
|
| 163 |
+
# SNR estimation
|
| 164 |
+
noise_floor = np.percentile(np.abs(audio), 10)
|
| 165 |
+
signal_peak = np.percentile(np.abs(audio), 95)
|
| 166 |
+
snr_db = 20 * np.log10((signal_peak + 1e-10) / (noise_floor + 1e-10))
|
| 167 |
+
|
| 168 |
+
# Quality score
|
| 169 |
+
quality_score = min(1.0, max(0.0, (snr_db - 5) / 25))
|
| 170 |
+
|
| 171 |
+
# Tempo (for rhythmic calls)
|
| 172 |
+
if num_syllables >= 2:
|
| 173 |
+
tempo_bpm = syllable_rate * 60
|
| 174 |
+
else:
|
| 175 |
+
tempo_bpm = 0
|
| 176 |
+
|
| 177 |
+
return AudioFeatures(
|
| 178 |
+
duration=duration,
|
| 179 |
+
dominant_frequency_hz=float(dominant_freq),
|
| 180 |
+
frequency_range=(float(freq_low), float(freq_high)),
|
| 181 |
+
spectral_centroid=float(spectral_centroid),
|
| 182 |
+
spectral_bandwidth=float(spectral_bandwidth),
|
| 183 |
+
tempo_bpm=float(tempo_bpm),
|
| 184 |
+
num_syllables=num_syllables,
|
| 185 |
+
syllable_rate=float(syllable_rate),
|
| 186 |
+
is_melodic=is_melodic,
|
| 187 |
+
is_repetitive=is_repetitive,
|
| 188 |
+
amplitude_pattern=amp_pattern,
|
| 189 |
+
estimated_snr_db=float(snr_db),
|
| 190 |
+
quality_score=float(quality_score)
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
def identify(
|
| 194 |
+
self,
|
| 195 |
+
features: AudioFeatures,
|
| 196 |
+
location: Optional[str] = None,
|
| 197 |
+
month: Optional[int] = None,
|
| 198 |
+
user_description: Optional[str] = None
|
| 199 |
+
) -> ZeroShotResult:
|
| 200 |
+
"""
|
| 201 |
+
Identify bird species using zero-shot LLM inference.
|
| 202 |
+
|
| 203 |
+
This is the NOVEL approach - using LLM's knowledge to identify
|
| 204 |
+
any bird without needing to train on that specific species.
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
# Build expert prompt
|
| 208 |
+
prompt = self._build_identification_prompt(features, location, month, user_description)
|
| 209 |
+
|
| 210 |
+
# Call LLM (synchronously using asyncio)
|
| 211 |
+
try:
|
| 212 |
+
import asyncio
|
| 213 |
+
|
| 214 |
+
async def _generate():
|
| 215 |
+
return await self.ollama.generate(
|
| 216 |
+
prompt,
|
| 217 |
+
system_prompt=self._get_expert_system_prompt(),
|
| 218 |
+
temperature=0.3, # Lower for more deterministic
|
| 219 |
+
max_tokens=1000
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
# Run async in sync context
|
| 223 |
+
try:
|
| 224 |
+
loop = asyncio.get_event_loop()
|
| 225 |
+
if loop.is_running():
|
| 226 |
+
# Use nest_asyncio for nested event loops
|
| 227 |
+
import nest_asyncio
|
| 228 |
+
nest_asyncio.apply()
|
| 229 |
+
response = loop.run_until_complete(_generate())
|
| 230 |
+
except RuntimeError:
|
| 231 |
+
# No event loop running
|
| 232 |
+
response = asyncio.run(_generate())
|
| 233 |
+
|
| 234 |
+
# Parse response
|
| 235 |
+
return self._parse_identification_response(response, features)
|
| 236 |
+
|
| 237 |
+
except Exception as e:
|
| 238 |
+
logger.error(f"LLM identification failed: {e}")
|
| 239 |
+
return self._fallback_result(features)
|
| 240 |
+
|
| 241 |
+
def _get_expert_system_prompt(self) -> str:
|
| 242 |
+
"""Expert ornithologist system prompt."""
|
| 243 |
+
return """You are an expert ornithologist with deep knowledge of bird vocalizations worldwide.
|
| 244 |
+
You can identify birds by their calls based on frequency, pattern, duration, and context.
|
| 245 |
+
|
| 246 |
+
Your expertise includes:
|
| 247 |
+
- 10,000+ bird species globally
|
| 248 |
+
- Detailed knowledge of Indian birds (1,300+ species)
|
| 249 |
+
- Ability to distinguish similar-sounding species
|
| 250 |
+
- Understanding of seasonal and geographic variations
|
| 251 |
+
|
| 252 |
+
When identifying birds:
|
| 253 |
+
1. Consider the audio characteristics carefully
|
| 254 |
+
2. Match against known bird call patterns
|
| 255 |
+
3. Account for regional variations
|
| 256 |
+
4. Flag unusual or rare sightings
|
| 257 |
+
5. Provide confidence based on how well features match
|
| 258 |
+
|
| 259 |
+
Always respond in the exact JSON format requested."""
|
| 260 |
+
|
| 261 |
+
def _build_identification_prompt(
|
| 262 |
+
self,
|
| 263 |
+
features: AudioFeatures,
|
| 264 |
+
location: Optional[str],
|
| 265 |
+
month: Optional[int],
|
| 266 |
+
user_description: Optional[str]
|
| 267 |
+
) -> str:
|
| 268 |
+
"""Build identification prompt from audio features."""
|
| 269 |
+
|
| 270 |
+
# Describe frequency in bird call terms
|
| 271 |
+
freq_desc = self._describe_frequency(features.dominant_frequency_hz)
|
| 272 |
+
|
| 273 |
+
# Season
|
| 274 |
+
season = self._get_season(month) if month else "unknown"
|
| 275 |
+
|
| 276 |
+
prompt = f"""Identify this bird based on its call characteristics:
|
| 277 |
+
|
| 278 |
+
## Audio Features
|
| 279 |
+
- **Duration**: {features.duration:.1f} seconds
|
| 280 |
+
- **Dominant Frequency**: {features.dominant_frequency_hz:.0f} Hz ({freq_desc})
|
| 281 |
+
- **Frequency Range**: {features.frequency_range[0]:.0f} - {features.frequency_range[1]:.0f} Hz
|
| 282 |
+
- **Call Pattern**: {"Melodic/varied" if features.is_melodic else "Monotone"}, {"Repetitive" if features.is_repetitive else "Non-repetitive"}
|
| 283 |
+
- **Syllables**: {features.num_syllables} syllables at {features.syllable_rate:.1f}/second
|
| 284 |
+
- **Rhythm**: {features.tempo_bpm:.0f} BPM (beats per minute)
|
| 285 |
+
- **Amplitude**: {features.amplitude_pattern} pattern
|
| 286 |
+
|
| 287 |
+
## Context
|
| 288 |
+
- **Location**: {location or "India (unspecified)"}
|
| 289 |
+
- **Season**: {season}
|
| 290 |
+
- **Recording Quality**: {self._quality_label(features.quality_score)} (SNR: {features.estimated_snr_db:.0f}dB)
|
| 291 |
+
"""
|
| 292 |
+
|
| 293 |
+
if user_description:
|
| 294 |
+
prompt += f"- **Observer Notes**: {user_description}\n"
|
| 295 |
+
|
| 296 |
+
prompt += """
|
| 297 |
+
## Task
|
| 298 |
+
Based on these audio features, identify the most likely bird species.
|
| 299 |
+
|
| 300 |
+
Respond in this exact JSON format:
|
| 301 |
+
{
|
| 302 |
+
"species_name": "Common Name",
|
| 303 |
+
"scientific_name": "Genus species",
|
| 304 |
+
"confidence": 0.85,
|
| 305 |
+
"reasoning": "Detailed explanation of why this species matches...",
|
| 306 |
+
"key_features_matched": ["feature1", "feature2"],
|
| 307 |
+
"alternatives": [
|
| 308 |
+
{"name": "Alternative 1", "scientific": "Genus species", "confidence": 0.1},
|
| 309 |
+
{"name": "Alternative 2", "scientific": "Genus species", "confidence": 0.05}
|
| 310 |
+
],
|
| 311 |
+
"is_indian_bird": true,
|
| 312 |
+
"is_unusual": false,
|
| 313 |
+
"unusual_reason": null,
|
| 314 |
+
"typical_call": "Description of what this bird typically sounds like"
|
| 315 |
+
}"""
|
| 316 |
+
|
| 317 |
+
return prompt
|
| 318 |
+
|
| 319 |
+
def _describe_frequency(self, freq: float) -> str:
|
| 320 |
+
"""Describe frequency in bird call terms."""
|
| 321 |
+
if freq < 500:
|
| 322 |
+
return "very low (large bird or booming call)"
|
| 323 |
+
elif freq < 1000:
|
| 324 |
+
return "low (owl, dove, or large bird)"
|
| 325 |
+
elif freq < 2000:
|
| 326 |
+
return "low-medium (cuckoo, crow, or medium bird)"
|
| 327 |
+
elif freq < 4000:
|
| 328 |
+
return "medium (most songbirds)"
|
| 329 |
+
elif freq < 6000:
|
| 330 |
+
return "medium-high (warbler, sunbird)"
|
| 331 |
+
elif freq < 8000:
|
| 332 |
+
return "high (small passerine)"
|
| 333 |
+
else:
|
| 334 |
+
return "very high (insect-like or whistle)"
|
| 335 |
+
|
| 336 |
+
def _get_season(self, month: int) -> str:
|
| 337 |
+
"""Get Indian season from month."""
|
| 338 |
+
if month in [12, 1, 2]:
|
| 339 |
+
return "winter (Dec-Feb) - winter migrants present"
|
| 340 |
+
elif month in [3, 4, 5]:
|
| 341 |
+
return "summer/pre-monsoon (Mar-May) - breeding season"
|
| 342 |
+
elif month in [6, 7, 8, 9]:
|
| 343 |
+
return "monsoon (Jun-Sep)"
|
| 344 |
+
else:
|
| 345 |
+
return "post-monsoon (Oct-Nov) - migration period"
|
| 346 |
+
|
| 347 |
+
def _quality_label(self, score: float) -> str:
|
| 348 |
+
"""Convert quality score to label."""
|
| 349 |
+
if score > 0.8:
|
| 350 |
+
return "excellent"
|
| 351 |
+
elif score > 0.6:
|
| 352 |
+
return "good"
|
| 353 |
+
elif score > 0.4:
|
| 354 |
+
return "fair"
|
| 355 |
+
else:
|
| 356 |
+
return "poor"
|
| 357 |
+
|
| 358 |
+
def _parse_identification_response(
|
| 359 |
+
self,
|
| 360 |
+
response: str,
|
| 361 |
+
features: AudioFeatures
|
| 362 |
+
) -> ZeroShotResult:
|
| 363 |
+
"""Parse LLM response into structured result."""
|
| 364 |
+
try:
|
| 365 |
+
# Try to extract JSON from response
|
| 366 |
+
json_start = response.find('{')
|
| 367 |
+
json_end = response.rfind('}') + 1
|
| 368 |
+
|
| 369 |
+
if json_start >= 0 and json_end > json_start:
|
| 370 |
+
json_str = response[json_start:json_end]
|
| 371 |
+
data = json.loads(json_str)
|
| 372 |
+
|
| 373 |
+
confidence = float(data.get('confidence', 0.5))
|
| 374 |
+
|
| 375 |
+
return ZeroShotResult(
|
| 376 |
+
species_name=data.get('species_name', 'Unknown'),
|
| 377 |
+
scientific_name=data.get('scientific_name', ''),
|
| 378 |
+
confidence=confidence,
|
| 379 |
+
confidence_label=self._confidence_label(confidence),
|
| 380 |
+
reasoning=data.get('reasoning', ''),
|
| 381 |
+
key_features_matched=data.get('key_features_matched', []),
|
| 382 |
+
alternative_species=data.get('alternatives', []),
|
| 383 |
+
is_indian_bird=data.get('is_indian_bird', True),
|
| 384 |
+
is_unusual_sighting=data.get('is_unusual', False),
|
| 385 |
+
unusual_reason=data.get('unusual_reason'),
|
| 386 |
+
call_description=data.get('typical_call', '')
|
| 387 |
+
)
|
| 388 |
+
except json.JSONDecodeError as e:
|
| 389 |
+
logger.warning(f"Failed to parse LLM JSON: {e}")
|
| 390 |
+
|
| 391 |
+
# Fallback: try to extract species name from text
|
| 392 |
+
return self._fallback_result(features, response)
|
| 393 |
+
|
| 394 |
+
def _confidence_label(self, confidence: float) -> str:
|
| 395 |
+
"""Convert confidence to label."""
|
| 396 |
+
if confidence >= 0.8:
|
| 397 |
+
return "high"
|
| 398 |
+
elif confidence >= 0.6:
|
| 399 |
+
return "medium"
|
| 400 |
+
else:
|
| 401 |
+
return "low"
|
| 402 |
+
|
| 403 |
+
def _fallback_result(
|
| 404 |
+
self,
|
| 405 |
+
features: AudioFeatures,
|
| 406 |
+
llm_response: str = ""
|
| 407 |
+
) -> ZeroShotResult:
|
| 408 |
+
"""Generate fallback result when LLM parsing fails."""
|
| 409 |
+
|
| 410 |
+
# Try to guess based on frequency
|
| 411 |
+
if features.dominant_frequency_hz < 1000:
|
| 412 |
+
if features.is_repetitive:
|
| 413 |
+
species = "Spotted Owlet"
|
| 414 |
+
scientific = "Athene brama"
|
| 415 |
+
else:
|
| 416 |
+
species = "Indian Cuckoo"
|
| 417 |
+
scientific = "Cuculus micropterus"
|
| 418 |
+
elif features.dominant_frequency_hz < 3000:
|
| 419 |
+
if features.is_melodic:
|
| 420 |
+
species = "Oriental Magpie-Robin"
|
| 421 |
+
scientific = "Copsychus saularis"
|
| 422 |
+
else:
|
| 423 |
+
species = "Asian Koel"
|
| 424 |
+
scientific = "Eudynamys scolopaceus"
|
| 425 |
+
else:
|
| 426 |
+
if features.syllable_rate > 3:
|
| 427 |
+
species = "Coppersmith Barbet"
|
| 428 |
+
scientific = "Psilopogon haemacephalus"
|
| 429 |
+
else:
|
| 430 |
+
species = "Common Tailorbird"
|
| 431 |
+
scientific = "Orthotomus sutorius"
|
| 432 |
+
|
| 433 |
+
return ZeroShotResult(
|
| 434 |
+
species_name=species,
|
| 435 |
+
scientific_name=scientific,
|
| 436 |
+
confidence=0.4,
|
| 437 |
+
confidence_label="low",
|
| 438 |
+
reasoning="Identification based on audio frequency and pattern analysis. LLM analysis unavailable.",
|
| 439 |
+
key_features_matched=["frequency range", "call pattern"],
|
| 440 |
+
alternative_species=[],
|
| 441 |
+
is_indian_bird=True,
|
| 442 |
+
is_unusual_sighting=False,
|
| 443 |
+
unusual_reason=None,
|
| 444 |
+
call_description=""
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
# Global instance for quick access
|
| 449 |
+
_identifier: Optional[ZeroShotBirdIdentifier] = None
|
| 450 |
+
|
| 451 |
+
def get_zero_shot_identifier() -> ZeroShotBirdIdentifier:
|
| 452 |
+
"""Get or create global zero-shot identifier."""
|
| 453 |
+
global _identifier
|
| 454 |
+
if _identifier is None:
|
| 455 |
+
_identifier = ZeroShotBirdIdentifier()
|
| 456 |
+
return _identifier
|
| 457 |
+
|
models/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""BirdSense Models Module."""
|
| 2 |
+
|
| 3 |
+
from .audio_classifier import BirdAudioClassifier
|
| 4 |
+
from .novelty_detector import NoveltyDetector
|
| 5 |
+
|
| 6 |
+
__all__ = ["BirdAudioClassifier", "NoveltyDetector"]
|
| 7 |
+
|
models/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (363 Bytes). View file
|
|
|
models/__pycache__/audio_classifier.cpython-314.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
models/__pycache__/novelty_detector.cpython-314.pyc
ADDED
|
Binary file (15.8 kB). View file
|
|
|
models/audio_classifier.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Bird Audio Classifier for BirdSense.
|
| 3 |
+
|
| 4 |
+
Complete classification pipeline from audio to species prediction.
|
| 5 |
+
Combines the audio encoder with a classification head and
|
| 6 |
+
optional LLM reasoning for enhanced accuracy.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from typing import Optional, List, Dict, Tuple
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from ..audio.encoder import AudioEncoder
|
| 17 |
+
except ImportError:
|
| 18 |
+
from audio.encoder import AudioEncoder
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ClassificationHead(nn.Module):
|
| 22 |
+
"""
|
| 23 |
+
Classification head with dropout and layer norm.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
input_dim: int,
|
| 29 |
+
num_classes: int,
|
| 30 |
+
hidden_dims: List[int] = [256, 128],
|
| 31 |
+
dropout: float = 0.3
|
| 32 |
+
):
|
| 33 |
+
super().__init__()
|
| 34 |
+
|
| 35 |
+
layers = []
|
| 36 |
+
in_dim = input_dim
|
| 37 |
+
|
| 38 |
+
for h_dim in hidden_dims:
|
| 39 |
+
layers.extend([
|
| 40 |
+
nn.Linear(in_dim, h_dim),
|
| 41 |
+
nn.LayerNorm(h_dim),
|
| 42 |
+
nn.GELU(),
|
| 43 |
+
nn.Dropout(dropout)
|
| 44 |
+
])
|
| 45 |
+
in_dim = h_dim
|
| 46 |
+
|
| 47 |
+
layers.append(nn.Linear(in_dim, num_classes))
|
| 48 |
+
|
| 49 |
+
self.classifier = nn.Sequential(*layers)
|
| 50 |
+
|
| 51 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 52 |
+
return self.classifier(x)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class BirdAudioClassifier(nn.Module):
|
| 56 |
+
"""
|
| 57 |
+
Complete bird audio classification model.
|
| 58 |
+
|
| 59 |
+
Combines:
|
| 60 |
+
- Audio encoder (CNN or Transformer)
|
| 61 |
+
- Classification head
|
| 62 |
+
- Uncertainty estimation
|
| 63 |
+
|
| 64 |
+
Designed for robust bird species identification from audio.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
num_classes: int = 250,
|
| 70 |
+
encoder_architecture: str = 'cnn',
|
| 71 |
+
n_mels: int = 128,
|
| 72 |
+
embedding_dim: int = 384,
|
| 73 |
+
hidden_dims: List[int] = [256, 128],
|
| 74 |
+
dropout: float = 0.3,
|
| 75 |
+
pretrained_encoder: bool = False
|
| 76 |
+
):
|
| 77 |
+
super().__init__()
|
| 78 |
+
|
| 79 |
+
self.num_classes = num_classes
|
| 80 |
+
self.embedding_dim = embedding_dim
|
| 81 |
+
|
| 82 |
+
# Audio encoder
|
| 83 |
+
self.encoder = AudioEncoder(
|
| 84 |
+
architecture=encoder_architecture,
|
| 85 |
+
n_mels=n_mels,
|
| 86 |
+
embedding_dim=embedding_dim,
|
| 87 |
+
pretrained=pretrained_encoder
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# Classification head
|
| 91 |
+
self.classifier = ClassificationHead(
|
| 92 |
+
input_dim=embedding_dim,
|
| 93 |
+
num_classes=num_classes,
|
| 94 |
+
hidden_dims=hidden_dims,
|
| 95 |
+
dropout=dropout
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# Temperature for calibrated probabilities
|
| 99 |
+
self.temperature = nn.Parameter(torch.ones(1))
|
| 100 |
+
|
| 101 |
+
def forward(
|
| 102 |
+
self,
|
| 103 |
+
x: torch.Tensor,
|
| 104 |
+
return_embeddings: bool = False
|
| 105 |
+
) -> Dict[str, torch.Tensor]:
|
| 106 |
+
"""
|
| 107 |
+
Forward pass.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
x: Mel-spectrogram (batch, n_mels, n_frames)
|
| 111 |
+
return_embeddings: Whether to return intermediate embeddings
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
Dictionary with:
|
| 115 |
+
- logits: Raw classification scores
|
| 116 |
+
- probabilities: Softmax probabilities
|
| 117 |
+
- embeddings: (optional) Audio embeddings
|
| 118 |
+
"""
|
| 119 |
+
# Extract embeddings
|
| 120 |
+
embeddings = self.encoder(x)
|
| 121 |
+
|
| 122 |
+
# Classify
|
| 123 |
+
logits = self.classifier(embeddings)
|
| 124 |
+
|
| 125 |
+
# Temperature-scaled probabilities
|
| 126 |
+
probabilities = F.softmax(logits / self.temperature, dim=-1)
|
| 127 |
+
|
| 128 |
+
output = {
|
| 129 |
+
"logits": logits,
|
| 130 |
+
"probabilities": probabilities
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
if return_embeddings:
|
| 134 |
+
output["embeddings"] = embeddings
|
| 135 |
+
|
| 136 |
+
return output
|
| 137 |
+
|
| 138 |
+
def predict(
|
| 139 |
+
self,
|
| 140 |
+
x: torch.Tensor,
|
| 141 |
+
top_k: int = 5
|
| 142 |
+
) -> Dict[str, torch.Tensor]:
|
| 143 |
+
"""
|
| 144 |
+
Get top-k predictions with confidence scores.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
x: Mel-spectrogram input
|
| 148 |
+
top_k: Number of top predictions to return
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
Dictionary with:
|
| 152 |
+
- top_indices: Indices of top-k classes
|
| 153 |
+
- top_probabilities: Probabilities of top-k classes
|
| 154 |
+
- max_confidence: Confidence of top prediction
|
| 155 |
+
- uncertainty: Entropy-based uncertainty
|
| 156 |
+
"""
|
| 157 |
+
with torch.no_grad():
|
| 158 |
+
output = self.forward(x, return_embeddings=True)
|
| 159 |
+
probs = output["probabilities"]
|
| 160 |
+
|
| 161 |
+
# Top-k predictions
|
| 162 |
+
top_probs, top_indices = torch.topk(probs, k=min(top_k, probs.size(-1)), dim=-1)
|
| 163 |
+
|
| 164 |
+
# Uncertainty (entropy)
|
| 165 |
+
entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=-1)
|
| 166 |
+
max_entropy = np.log(self.num_classes)
|
| 167 |
+
uncertainty = entropy / max_entropy # Normalized [0, 1]
|
| 168 |
+
|
| 169 |
+
return {
|
| 170 |
+
"top_indices": top_indices,
|
| 171 |
+
"top_probabilities": top_probs,
|
| 172 |
+
"max_confidence": top_probs[:, 0],
|
| 173 |
+
"uncertainty": uncertainty,
|
| 174 |
+
"embeddings": output["embeddings"]
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
def get_embedding(self, x: torch.Tensor) -> torch.Tensor:
|
| 178 |
+
"""Extract audio embeddings without classification."""
|
| 179 |
+
with torch.no_grad():
|
| 180 |
+
return self.encoder(x)
|
| 181 |
+
|
| 182 |
+
def calibrate_temperature(
|
| 183 |
+
self,
|
| 184 |
+
val_loader,
|
| 185 |
+
device: str = 'cpu'
|
| 186 |
+
):
|
| 187 |
+
"""
|
| 188 |
+
Calibrate temperature using validation set.
|
| 189 |
+
Uses temperature scaling for better probability calibration.
|
| 190 |
+
"""
|
| 191 |
+
self.eval()
|
| 192 |
+
logits_list = []
|
| 193 |
+
labels_list = []
|
| 194 |
+
|
| 195 |
+
with torch.no_grad():
|
| 196 |
+
for x, y in val_loader:
|
| 197 |
+
x = x.to(device)
|
| 198 |
+
output = self.forward(x)
|
| 199 |
+
logits_list.append(output["logits"].cpu())
|
| 200 |
+
labels_list.append(y)
|
| 201 |
+
|
| 202 |
+
logits = torch.cat(logits_list, dim=0)
|
| 203 |
+
labels = torch.cat(labels_list, dim=0)
|
| 204 |
+
|
| 205 |
+
# Find optimal temperature
|
| 206 |
+
best_temp = 1.0
|
| 207 |
+
best_nll = float('inf')
|
| 208 |
+
|
| 209 |
+
for temp in np.linspace(0.5, 3.0, 50):
|
| 210 |
+
scaled_logits = logits / temp
|
| 211 |
+
nll = F.cross_entropy(scaled_logits, labels)
|
| 212 |
+
if nll < best_nll:
|
| 213 |
+
best_nll = nll
|
| 214 |
+
best_temp = temp
|
| 215 |
+
|
| 216 |
+
self.temperature.data = torch.tensor([best_temp])
|
| 217 |
+
print(f"Calibrated temperature: {best_temp:.3f}")
|
| 218 |
+
|
| 219 |
+
def count_parameters(self) -> Dict[str, int]:
|
| 220 |
+
"""Count parameters in each component."""
|
| 221 |
+
encoder_params = sum(p.numel() for p in self.encoder.parameters())
|
| 222 |
+
classifier_params = sum(p.numel() for p in self.classifier.parameters())
|
| 223 |
+
total_params = sum(p.numel() for p in self.parameters())
|
| 224 |
+
|
| 225 |
+
return {
|
| 226 |
+
"encoder": encoder_params,
|
| 227 |
+
"classifier": classifier_params,
|
| 228 |
+
"total": total_params,
|
| 229 |
+
"total_mb": total_params * 4 / (1024 * 1024) # Assuming float32
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
def export_onnx(self, path: str, n_mels: int = 128, n_frames: int = 500):
|
| 233 |
+
"""Export model to ONNX format for mobile deployment."""
|
| 234 |
+
dummy_input = torch.randn(1, n_mels, n_frames)
|
| 235 |
+
|
| 236 |
+
torch.onnx.export(
|
| 237 |
+
self,
|
| 238 |
+
dummy_input,
|
| 239 |
+
path,
|
| 240 |
+
input_names=['mel_spectrogram'],
|
| 241 |
+
output_names=['logits', 'probabilities'],
|
| 242 |
+
dynamic_axes={
|
| 243 |
+
'mel_spectrogram': {0: 'batch', 2: 'frames'},
|
| 244 |
+
'logits': {0: 'batch'},
|
| 245 |
+
'probabilities': {0: 'batch'}
|
| 246 |
+
},
|
| 247 |
+
opset_version=14
|
| 248 |
+
)
|
| 249 |
+
print(f"Exported ONNX model to {path}")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class EnsembleBirdClassifier(nn.Module):
|
| 253 |
+
"""
|
| 254 |
+
Ensemble of multiple classifiers for robust predictions.
|
| 255 |
+
|
| 256 |
+
Uses multiple architectures and combines predictions for
|
| 257 |
+
improved accuracy and calibration.
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
def __init__(
|
| 261 |
+
self,
|
| 262 |
+
num_classes: int = 250,
|
| 263 |
+
n_mels: int = 128,
|
| 264 |
+
embedding_dim: int = 384
|
| 265 |
+
):
|
| 266 |
+
super().__init__()
|
| 267 |
+
|
| 268 |
+
# Ensemble of different architectures
|
| 269 |
+
self.classifiers = nn.ModuleList([
|
| 270 |
+
BirdAudioClassifier(
|
| 271 |
+
num_classes=num_classes,
|
| 272 |
+
encoder_architecture='cnn',
|
| 273 |
+
n_mels=n_mels,
|
| 274 |
+
embedding_dim=embedding_dim
|
| 275 |
+
),
|
| 276 |
+
# Can add more architectures here
|
| 277 |
+
])
|
| 278 |
+
|
| 279 |
+
# Learnable ensemble weights
|
| 280 |
+
self.ensemble_weights = nn.Parameter(torch.ones(len(self.classifiers)))
|
| 281 |
+
|
| 282 |
+
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
|
| 283 |
+
"""
|
| 284 |
+
Ensemble forward pass with weighted averaging.
|
| 285 |
+
"""
|
| 286 |
+
all_logits = []
|
| 287 |
+
all_embeddings = []
|
| 288 |
+
|
| 289 |
+
for classifier in self.classifiers:
|
| 290 |
+
output = classifier(x, return_embeddings=True)
|
| 291 |
+
all_logits.append(output["logits"])
|
| 292 |
+
all_embeddings.append(output["embeddings"])
|
| 293 |
+
|
| 294 |
+
# Weighted average
|
| 295 |
+
weights = F.softmax(self.ensemble_weights, dim=0)
|
| 296 |
+
logits_stack = torch.stack(all_logits, dim=0) # (n_models, batch, classes)
|
| 297 |
+
ensemble_logits = torch.sum(weights.view(-1, 1, 1) * logits_stack, dim=0)
|
| 298 |
+
|
| 299 |
+
probabilities = F.softmax(ensemble_logits, dim=-1)
|
| 300 |
+
|
| 301 |
+
return {
|
| 302 |
+
"logits": ensemble_logits,
|
| 303 |
+
"probabilities": probabilities,
|
| 304 |
+
"embeddings": torch.mean(torch.stack(all_embeddings), dim=0),
|
| 305 |
+
"individual_logits": all_logits
|
| 306 |
+
}
|
| 307 |
+
|
models/novelty_detector.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Novelty Detection for BirdSense.
|
| 3 |
+
|
| 4 |
+
Detects out-of-distribution samples that might represent:
|
| 5 |
+
- New species not in training data
|
| 6 |
+
- Species outside their normal range
|
| 7 |
+
- Unusual vocalizations
|
| 8 |
+
- Recording artifacts or non-bird sounds
|
| 9 |
+
|
| 10 |
+
Uses embedding-space distance metrics for detection.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import numpy as np
|
| 16 |
+
from typing import Optional, Dict, Tuple, List
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
import json
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class NoveltyResult:
|
| 23 |
+
"""Result of novelty detection."""
|
| 24 |
+
is_novel: bool
|
| 25 |
+
novelty_score: float # 0 = typical, 1 = very novel
|
| 26 |
+
nearest_class: int
|
| 27 |
+
nearest_distance: float
|
| 28 |
+
confidence: float
|
| 29 |
+
explanation: str
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class NoveltyDetector:
|
| 33 |
+
"""
|
| 34 |
+
Detects novel/out-of-distribution bird sounds.
|
| 35 |
+
|
| 36 |
+
Uses Mahalanobis distance in embedding space to identify
|
| 37 |
+
samples that don't match known species patterns.
|
| 38 |
+
|
| 39 |
+
Key features:
|
| 40 |
+
- Per-class covariance modeling
|
| 41 |
+
- Adaptive thresholding
|
| 42 |
+
- Geospatial prior integration (optional)
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
embedding_dim: int = 384,
|
| 48 |
+
num_classes: int = 250,
|
| 49 |
+
threshold: float = 0.85
|
| 50 |
+
):
|
| 51 |
+
self.embedding_dim = embedding_dim
|
| 52 |
+
self.num_classes = num_classes
|
| 53 |
+
self.threshold = threshold
|
| 54 |
+
|
| 55 |
+
# Per-class statistics
|
| 56 |
+
self.class_means: Optional[torch.Tensor] = None # (num_classes, embedding_dim)
|
| 57 |
+
self.class_covariances: Optional[torch.Tensor] = None # (num_classes, embedding_dim, embedding_dim)
|
| 58 |
+
self.global_covariance: Optional[torch.Tensor] = None
|
| 59 |
+
self.is_fitted = False
|
| 60 |
+
|
| 61 |
+
# For Mahalanobis distance
|
| 62 |
+
self.precision_matrix: Optional[torch.Tensor] = None
|
| 63 |
+
|
| 64 |
+
def fit(
|
| 65 |
+
self,
|
| 66 |
+
embeddings: torch.Tensor,
|
| 67 |
+
labels: torch.Tensor,
|
| 68 |
+
regularization: float = 1e-5
|
| 69 |
+
):
|
| 70 |
+
"""
|
| 71 |
+
Fit the novelty detector on training embeddings.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
embeddings: Training embeddings (n_samples, embedding_dim)
|
| 75 |
+
labels: Class labels (n_samples,)
|
| 76 |
+
regularization: Regularization for covariance estimation
|
| 77 |
+
"""
|
| 78 |
+
embeddings = embeddings.cpu()
|
| 79 |
+
labels = labels.cpu()
|
| 80 |
+
|
| 81 |
+
n_classes = labels.max().item() + 1
|
| 82 |
+
|
| 83 |
+
# Compute per-class means
|
| 84 |
+
class_means = torch.zeros(n_classes, self.embedding_dim)
|
| 85 |
+
class_counts = torch.zeros(n_classes)
|
| 86 |
+
|
| 87 |
+
for emb, label in zip(embeddings, labels):
|
| 88 |
+
class_means[label] += emb
|
| 89 |
+
class_counts[label] += 1
|
| 90 |
+
|
| 91 |
+
# Avoid division by zero
|
| 92 |
+
class_counts = torch.clamp(class_counts, min=1)
|
| 93 |
+
class_means = class_means / class_counts.unsqueeze(1)
|
| 94 |
+
|
| 95 |
+
# Compute tied covariance (shared across classes for stability)
|
| 96 |
+
centered = embeddings - class_means[labels]
|
| 97 |
+
global_cov = (centered.T @ centered) / len(embeddings)
|
| 98 |
+
|
| 99 |
+
# Add regularization
|
| 100 |
+
global_cov += torch.eye(self.embedding_dim) * regularization
|
| 101 |
+
|
| 102 |
+
# Compute precision matrix (inverse covariance)
|
| 103 |
+
self.precision_matrix = torch.linalg.inv(global_cov)
|
| 104 |
+
self.class_means = class_means
|
| 105 |
+
self.global_covariance = global_cov
|
| 106 |
+
self.num_classes = n_classes
|
| 107 |
+
self.is_fitted = True
|
| 108 |
+
|
| 109 |
+
def mahalanobis_distance(
|
| 110 |
+
self,
|
| 111 |
+
embeddings: torch.Tensor,
|
| 112 |
+
class_idx: Optional[int] = None
|
| 113 |
+
) -> torch.Tensor:
|
| 114 |
+
"""
|
| 115 |
+
Compute Mahalanobis distance to class mean(s).
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
embeddings: Query embeddings (batch, embedding_dim)
|
| 119 |
+
class_idx: If specified, distance to specific class; otherwise min over all
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
Distances (batch,) or (batch, num_classes)
|
| 123 |
+
"""
|
| 124 |
+
if not self.is_fitted:
|
| 125 |
+
raise RuntimeError("Novelty detector not fitted. Call fit() first.")
|
| 126 |
+
|
| 127 |
+
embeddings = embeddings.cpu()
|
| 128 |
+
|
| 129 |
+
if class_idx is not None:
|
| 130 |
+
# Distance to specific class
|
| 131 |
+
diff = embeddings - self.class_means[class_idx]
|
| 132 |
+
dist = torch.sqrt(torch.sum(diff @ self.precision_matrix * diff, dim=-1))
|
| 133 |
+
return dist
|
| 134 |
+
else:
|
| 135 |
+
# Distance to all classes
|
| 136 |
+
distances = []
|
| 137 |
+
for c in range(self.num_classes):
|
| 138 |
+
diff = embeddings - self.class_means[c]
|
| 139 |
+
dist = torch.sqrt(torch.sum(diff @ self.precision_matrix * diff, dim=-1))
|
| 140 |
+
distances.append(dist)
|
| 141 |
+
return torch.stack(distances, dim=-1) # (batch, num_classes)
|
| 142 |
+
|
| 143 |
+
def detect(
|
| 144 |
+
self,
|
| 145 |
+
embeddings: torch.Tensor,
|
| 146 |
+
predicted_class: Optional[torch.Tensor] = None,
|
| 147 |
+
species_names: Optional[List[str]] = None
|
| 148 |
+
) -> List[NoveltyResult]:
|
| 149 |
+
"""
|
| 150 |
+
Detect novelty in embeddings.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
embeddings: Query embeddings (batch, embedding_dim)
|
| 154 |
+
predicted_class: Predicted class indices (batch,)
|
| 155 |
+
species_names: Optional species name mapping
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
List of NoveltyResult for each sample
|
| 159 |
+
"""
|
| 160 |
+
if not self.is_fitted:
|
| 161 |
+
raise RuntimeError("Novelty detector not fitted. Call fit() first.")
|
| 162 |
+
|
| 163 |
+
# Compute distances to all classes
|
| 164 |
+
all_distances = self.mahalanobis_distance(embeddings) # (batch, num_classes)
|
| 165 |
+
|
| 166 |
+
# Find minimum distance and corresponding class
|
| 167 |
+
min_distances, nearest_classes = torch.min(all_distances, dim=-1)
|
| 168 |
+
|
| 169 |
+
# Normalize to [0, 1] novelty score
|
| 170 |
+
# Using sigmoid with empirically tuned scaling
|
| 171 |
+
novelty_scores = torch.sigmoid((min_distances - 3.0) / 1.0)
|
| 172 |
+
|
| 173 |
+
results = []
|
| 174 |
+
for i in range(len(embeddings)):
|
| 175 |
+
is_novel = novelty_scores[i].item() > self.threshold
|
| 176 |
+
nearest = nearest_classes[i].item()
|
| 177 |
+
|
| 178 |
+
if predicted_class is not None:
|
| 179 |
+
pred = predicted_class[i].item()
|
| 180 |
+
pred_distance = all_distances[i, pred].item()
|
| 181 |
+
else:
|
| 182 |
+
pred = nearest
|
| 183 |
+
pred_distance = min_distances[i].item()
|
| 184 |
+
|
| 185 |
+
# Generate explanation
|
| 186 |
+
if is_novel:
|
| 187 |
+
explanation = f"Sample appears novel (score: {novelty_scores[i]:.3f}). "
|
| 188 |
+
explanation += f"Nearest known species: {species_names[nearest] if species_names else f'Class {nearest}'} "
|
| 189 |
+
explanation += f"(distance: {min_distances[i]:.2f})"
|
| 190 |
+
else:
|
| 191 |
+
explanation = f"Sample matches known patterns (score: {novelty_scores[i]:.3f})"
|
| 192 |
+
|
| 193 |
+
results.append(NoveltyResult(
|
| 194 |
+
is_novel=is_novel,
|
| 195 |
+
novelty_score=float(novelty_scores[i]),
|
| 196 |
+
nearest_class=nearest,
|
| 197 |
+
nearest_distance=float(min_distances[i]),
|
| 198 |
+
confidence=float(1 - novelty_scores[i]),
|
| 199 |
+
explanation=explanation
|
| 200 |
+
))
|
| 201 |
+
|
| 202 |
+
return results
|
| 203 |
+
|
| 204 |
+
def save(self, path: str):
|
| 205 |
+
"""Save fitted detector to file."""
|
| 206 |
+
if not self.is_fitted:
|
| 207 |
+
raise RuntimeError("Detector not fitted.")
|
| 208 |
+
|
| 209 |
+
state = {
|
| 210 |
+
"embedding_dim": self.embedding_dim,
|
| 211 |
+
"num_classes": self.num_classes,
|
| 212 |
+
"threshold": self.threshold,
|
| 213 |
+
"class_means": self.class_means.numpy().tolist(),
|
| 214 |
+
"precision_matrix": self.precision_matrix.numpy().tolist(),
|
| 215 |
+
"global_covariance": self.global_covariance.numpy().tolist()
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
with open(path, 'w') as f:
|
| 219 |
+
json.dump(state, f)
|
| 220 |
+
|
| 221 |
+
def load(self, path: str):
|
| 222 |
+
"""Load fitted detector from file."""
|
| 223 |
+
with open(path, 'r') as f:
|
| 224 |
+
state = json.load(f)
|
| 225 |
+
|
| 226 |
+
self.embedding_dim = state["embedding_dim"]
|
| 227 |
+
self.num_classes = state["num_classes"]
|
| 228 |
+
self.threshold = state["threshold"]
|
| 229 |
+
self.class_means = torch.tensor(state["class_means"])
|
| 230 |
+
self.precision_matrix = torch.tensor(state["precision_matrix"])
|
| 231 |
+
self.global_covariance = torch.tensor(state["global_covariance"])
|
| 232 |
+
self.is_fitted = True
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class GeospatialNoveltyDetector(NoveltyDetector):
|
| 236 |
+
"""
|
| 237 |
+
Extended novelty detector with geospatial priors.
|
| 238 |
+
|
| 239 |
+
Considers species range maps to flag:
|
| 240 |
+
- Species identified outside their known range
|
| 241 |
+
- Unexpected seasonal occurrences
|
| 242 |
+
"""
|
| 243 |
+
|
| 244 |
+
def __init__(
|
| 245 |
+
self,
|
| 246 |
+
embedding_dim: int = 384,
|
| 247 |
+
num_classes: int = 250,
|
| 248 |
+
threshold: float = 0.85,
|
| 249 |
+
range_data_path: Optional[str] = None
|
| 250 |
+
):
|
| 251 |
+
super().__init__(embedding_dim, num_classes, threshold)
|
| 252 |
+
|
| 253 |
+
self.range_data: Dict[int, Dict] = {} # class_id -> range info
|
| 254 |
+
if range_data_path:
|
| 255 |
+
self._load_range_data(range_data_path)
|
| 256 |
+
|
| 257 |
+
def _load_range_data(self, path: str):
|
| 258 |
+
"""Load species range data."""
|
| 259 |
+
with open(path, 'r') as f:
|
| 260 |
+
self.range_data = json.load(f)
|
| 261 |
+
|
| 262 |
+
def check_range_novelty(
|
| 263 |
+
self,
|
| 264 |
+
class_idx: int,
|
| 265 |
+
latitude: float,
|
| 266 |
+
longitude: float,
|
| 267 |
+
month: Optional[int] = None
|
| 268 |
+
) -> Tuple[bool, str]:
|
| 269 |
+
"""
|
| 270 |
+
Check if species occurrence is novel given location.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
class_idx: Predicted species index
|
| 274 |
+
latitude: Recording latitude
|
| 275 |
+
longitude: Recording longitude
|
| 276 |
+
month: Optional month for seasonal check
|
| 277 |
+
|
| 278 |
+
Returns:
|
| 279 |
+
Tuple of (is_range_novel, explanation)
|
| 280 |
+
"""
|
| 281 |
+
if class_idx not in self.range_data:
|
| 282 |
+
return False, "No range data available"
|
| 283 |
+
|
| 284 |
+
range_info = self.range_data[class_idx]
|
| 285 |
+
|
| 286 |
+
# Simple bounding box check (can be enhanced with actual range polygons)
|
| 287 |
+
lat_min = range_info.get("lat_min", -90)
|
| 288 |
+
lat_max = range_info.get("lat_max", 90)
|
| 289 |
+
lon_min = range_info.get("lon_min", -180)
|
| 290 |
+
lon_max = range_info.get("lon_max", 180)
|
| 291 |
+
|
| 292 |
+
in_range = (lat_min <= latitude <= lat_max and
|
| 293 |
+
lon_min <= longitude <= lon_max)
|
| 294 |
+
|
| 295 |
+
if not in_range:
|
| 296 |
+
return True, f"Species rarely found at this location ({latitude:.2f}, {longitude:.2f})"
|
| 297 |
+
|
| 298 |
+
# Seasonal check
|
| 299 |
+
if month and "seasonal_months" in range_info:
|
| 300 |
+
if month not in range_info["seasonal_months"]:
|
| 301 |
+
return True, f"Species unusual for month {month}"
|
| 302 |
+
|
| 303 |
+
return False, "Within expected range"
|
| 304 |
+
|
| 305 |
+
def detect_with_location(
|
| 306 |
+
self,
|
| 307 |
+
embeddings: torch.Tensor,
|
| 308 |
+
predicted_class: torch.Tensor,
|
| 309 |
+
latitude: float,
|
| 310 |
+
longitude: float,
|
| 311 |
+
month: Optional[int] = None,
|
| 312 |
+
species_names: Optional[List[str]] = None
|
| 313 |
+
) -> List[NoveltyResult]:
|
| 314 |
+
"""
|
| 315 |
+
Detect novelty considering both embeddings and location.
|
| 316 |
+
"""
|
| 317 |
+
# Get embedding-based results
|
| 318 |
+
results = self.detect(embeddings, predicted_class, species_names)
|
| 319 |
+
|
| 320 |
+
# Enhance with geospatial information
|
| 321 |
+
for i, result in enumerate(results):
|
| 322 |
+
pred_class = predicted_class[i].item()
|
| 323 |
+
is_range_novel, range_explanation = self.check_range_novelty(
|
| 324 |
+
pred_class, latitude, longitude, month
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
if is_range_novel:
|
| 328 |
+
# Boost novelty score for out-of-range detections
|
| 329 |
+
result.novelty_score = min(1.0, result.novelty_score + 0.3)
|
| 330 |
+
result.is_novel = result.novelty_score > self.threshold
|
| 331 |
+
result.explanation += f" | RANGE ALERT: {range_explanation}"
|
| 332 |
+
|
| 333 |
+
return results
|
| 334 |
+
|
requirements.txt
CHANGED
|
@@ -1,4 +1,8 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# BirdSense Pro - HuggingFace Space
|
| 2 |
+
# Minimal requirements for reliable deployment
|
| 3 |
+
|
| 4 |
+
gradio==4.19.0
|
| 5 |
+
numpy>=1.21.0
|
| 6 |
+
scipy>=1.7.0
|
| 7 |
+
requests>=2.28.0
|
| 8 |
+
Pillow>=9.0.0
|