#!/usr/bin/env python3
"""
Gradio app for TimesNet-Gen: generate seismic samples from a latent bank.

Based on generate_samples_git.py (working GitHub version).
No plot files are written to disk: the app saves an NPZ archive and renders
in-memory preview images for the Gradio interface only.

Model files are automatically downloaded from Hugging Face Hub.
"""

import os
from datetime import datetime
from io import BytesIO

import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from huggingface_hub import hf_hub_download

# Hugging Face model repository
HF_REPO_ID = "Barisylmz/TimesNet-Gen-Models"  # Your HF model repo
CHECKPOINT_FILENAME = "timesnet_pointcloud_phase1_final.pth"
LATENT_BANK_FILENAME = "latent_bank_station_cond.npz"  # Actual filename on HF Hub
HF_CACHE_DIR = "./hf_cache"  # Where hf_hub_download stores fetched files


def _fetch_if_missing(filename, label):
    """Return a usable path for *filename*, downloading it from the Hub if absent.

    Args:
        filename: File name, both relative to the working directory and
            inside the ``HF_REPO_ID`` repository.
        label: Human-readable name used in the "downloaded" log line.

    Returns:
        ``filename`` itself when it already exists locally, otherwise the
        path returned by ``hf_hub_download``.
    """
    if os.path.exists(filename):
        return filename

    print(f"[INFO] Downloading {filename}...")
    local_path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
        cache_dir=HF_CACHE_DIR,
    )
    print(f"[INFO] ✅ {label} downloaded to {local_path}")
    return local_path


def download_model_files():
    """
    Download model files from Hugging Face Hub if not present locally.

    Returns:
        checkpoint_path: Path to the checkpoint (local file or downloaded copy)
        latent_bank_path: Path to the latent bank (local file or downloaded copy)

    Raises:
        Exception: Whatever ``hf_hub_download`` raised; it is logged and
            re-raised unchanged.
    """
    print("[INFO] Checking for model files...")

    # Fast path: both files already sit next to the app.
    if os.path.exists(CHECKPOINT_FILENAME) and os.path.exists(LATENT_BANK_FILENAME):
        print("[INFO] Model files found locally")
        return CHECKPOINT_FILENAME, LATENT_BANK_FILENAME

    print(f"[INFO] Downloading model files from Hugging Face Hub: {HF_REPO_ID}")

    try:
        checkpoint_path = _fetch_if_missing(CHECKPOINT_FILENAME, "Checkpoint")
        latent_bank_path = _fetch_if_missing(LATENT_BANK_FILENAME, "Latent bank")
    except Exception as e:
        print(f"[ERROR] Failed to download model files: {e}")
        print(f"[INFO] Please ensure files exist at: https://huggingface.co/{HF_REPO_ID}")
        raise

    return checkpoint_path, latent_bank_path


class SimpleArgs:
    """Configuration for generation (matching GitHub version)."""

    def __init__(self):
        # Model architecture
        self.seq_len = 6000
        self.d_model = 128
        self.d_ff = 256
        self.e_layers = 2
        self.d_layers = 2
        self.num_kernels = 6
        self.top_k = 2
        self.dropout = 0.1
        self.latent_dim = 256

        # System
        self.use_gpu = torch.cuda.is_available()
        self.seed = 0  # NOTE(review): not read anywhere in this file

        # Point-cloud generation
        self.pcgen_k = 5  # number of latent vectors averaged per generated sample
        self.pcgen_jitter_std = 0.0  # NOTE(review): not read anywhere in this file


def load_model(checkpoint_path, args):
    """Load pre-trained TimesNet-PointCloud model (matching GitHub version).

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load``. It may
            be either a raw state dict or a dict holding one under
            ``'model_state_dict'``.
        args: Object exposing the architecture attributes of ``SimpleArgs`` and
            a ``use_gpu`` flag.

    Returns:
        The model in eval mode, moved to CUDA when ``args.use_gpu`` is true.
    """
    # Imported lazily so the module can be imported without the model package.
    from TimesNet_PointCloud import TimesNetPointCloud

    # Create model config (NO num_stations - GitHub version doesn't use it)
    class ModelConfig:
        def __init__(self, args):
            self.seq_len = args.seq_len
            self.pred_len = 0
            self.enc_in = 3
            self.c_out = 3
            self.d_model = args.d_model
            self.d_ff = args.d_ff
            self.num_kernels = args.num_kernels
            self.top_k = args.top_k
            self.e_layers = args.e_layers
            self.d_layers = args.d_layers
            self.dropout = args.dropout
            self.embed = 'timeF'
            self.freq = 'h'
            self.latent_dim = args.latent_dim

    config = ModelConfig(args)
    model = TimesNetPointCloud(config)

    # Load on CPU first; the model is moved to the GPU afterwards if requested.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)

    model.eval()
    if args.use_gpu:
        model = model.cuda()

    print(f"[INFO] Model loaded successfully from {checkpoint_path}")
    return model


def _to_batch_tensor(array, use_gpu):
    """Convert a NumPy array to a float32 tensor with a leading batch axis of 1."""
    tensor = torch.from_numpy(array).float().unsqueeze(0)
    return tensor.cuda() if use_gpu else tensor


def generate_samples_from_latent_bank(model, latent_bank_path, station_id, num_samples, args):
    """
    Generate samples directly from pre-computed latent bank (matching GitHub version).

    For every sample, ``k = min(args.pcgen_k, N)`` bank entries are drawn with
    replacement, their latents / means / stdevs are averaged, and the result is
    decoded with ``model.project_features_for_reconstruction``.

    Args:
        model: TimesNet model exposing ``project_features_for_reconstruction``
        latent_bank_path: Path to the latent bank NPZ file; it must contain
            ``latents_<id>``, ``means_<id>`` and ``stdev_<id>`` arrays
        station_id: Station ID (e.g., '0205')
        num_samples: Number of samples to generate
        args: Model arguments (``pcgen_k`` and ``use_gpu`` are read)

    Returns:
        generated_signals: Array stacking one decoded signal per sample
            (expected to be (num_samples, 3, seq_len) - TODO confirm against
            the decoder's output shape)
        real_names_used: List of lists naming the latent indices mixed into
            each sample

        ``(None, None)`` is returned when the bank cannot be loaded, the
        station is missing or incomplete, or the station has no latents.
    """
    print(f"[INFO] Loading latent bank from {latent_bank_path}...")

    try:
        latent_data = np.load(latent_bank_path)
    except Exception as e:
        print(f"[ERROR] Could not load latent bank: {e}")
        return None, None

    latents_key = f'latents_{station_id}'
    means_key = f'means_{station_id}'
    stdev_key = f'stdev_{station_id}'

    # The archive keeps a file handle open; close it on every exit path.
    with latent_data:
        stored_keys = latent_data.files

        if latents_key not in stored_keys:
            print(f"[ERROR] Station {station_id} not found in latent bank!")
            available = [
                key.replace('latents_', '')
                for key in stored_keys
                if key.startswith('latents_')
            ]
            print(f"Available stations: {available}")
            return None, None

        missing = [key for key in (means_key, stdev_key) if key not in stored_keys]
        if missing:
            print(f"[ERROR] Latent bank is missing entries for station {station_id}: {missing}")
            return None, None

        # Indexing an NPZ archive reads the arrays fully into memory, so they
        # remain valid after the archive is closed.
        # NOTE(review): the original comments gave all three arrays the shape
        # (N_samples, seq_len, d_model) - confirm against the bank's producer.
        latents = latent_data[latents_key]
        means = latent_data[means_key]
        stdevs = latent_data[stdev_key]

    if len(latents) == 0:
        # Averaging zero vectors would silently produce NaN signals.
        print(f"[ERROR] No latent vectors stored for station {station_id}")
        return None, None

    print(f"[INFO] Loaded {len(latents)} latent vectors for station {station_id}")
    print(f"[INFO] Generating {num_samples} samples via bootstrap aggregation...")

    generated_signals = []
    real_names_used = []

    # Loop-invariant: how many latent vectors are mixed per sample.
    k = min(args.pcgen_k, len(latents))

    model.eval()
    with torch.no_grad():
        for i in range(num_samples):
            # Bootstrap: randomly select k latent vectors with replacement
            selected_indices = np.random.choice(len(latents), size=k, replace=True)

            # Mix latent features (average over the k selected entries)
            mixed_features = np.mean(latents[selected_indices], axis=0)
            mixed_means = np.mean(means[selected_indices], axis=0)
            mixed_stdevs = np.mean(stdevs[selected_indices], axis=0)

            # Convert to torch tensors with a batch axis of 1
            mixed_features_torch = _to_batch_tensor(mixed_features, args.use_gpu)
            means_b = _to_batch_tensor(mixed_means, args.use_gpu)
            stdev_b = _to_batch_tensor(mixed_stdevs, args.use_gpu)

            # Decode
            xg = model.project_features_for_reconstruction(mixed_features_torch, means_b, stdev_b)

            # Store transposed; presumably (6000, 3) -> (3, 6000) - TODO confirm
            generated_np = xg.squeeze(0).cpu().numpy().T
            generated_signals.append(generated_np)

            # Track which latent indices were used
            real_names_used.append([f"latent_{idx}" for idx in selected_indices])

            if (i + 1) % 10 == 0:
                print(f"  Generated {i + 1}/{num_samples} samples...")

    return np.array(generated_signals), real_names_used


def save_generated_samples(generated_signals, real_names, station_id, output_dir):
    """Save generated samples to NPZ file (NO PLOTTING).

    Args:
        generated_signals: Array of generated signals; stored under both
            ``generated_signals`` and ``signals_generated``.
        real_names: Per-sample lists of latent names used for mixing.
        station_id: Station ID; stored under both ``station_id`` and ``station``.
        output_dir: Directory to write into (created if it does not exist).

    Returns:
        Path of the written ``station_<id>_generated_timeseries.npz`` file.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Save timeseries NPZ
    output_path = os.path.join(output_dir, f'station_{station_id}_generated_timeseries.npz')
    np.savez_compressed(
        output_path,
        generated_signals=generated_signals,
        signals_generated=generated_signals,  # Alias for compatibility
        real_names=real_names,
        station_id=station_id,
        station=station_id,  # Alias for compatibility
    )

    print(f"[INFO] Saved {len(generated_signals)} generated samples to {output_path}")
    return output_path


def plot_signal_for_display(signal):
    """
    Plot a single 3-component seismic signal for Gradio display.

    Args:
        signal: 2-D array whose first three rows are plotted as the East,
            North and Vertical components (e.g. shape (3, 6000)); the time
            axis assumes 100 Hz sampling

    Returns:
        PIL Image holding the rendered PNG
    """
    fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)

    try:
        component_names = ['East', 'North', 'Vertical']
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c']

        t = np.arange(signal.shape[1]) / 100.0  # 100 Hz sampling

        for idx, (ax, name, color) in enumerate(zip(axes, component_names, colors)):
            ax.plot(t, signal[idx], color=color, linewidth=0.5, alpha=0.8)
            ax.set_ylabel(f'{name}\n(cm/s²)', fontsize=10)
            ax.grid(True, alpha=0.3)
            ax.set_xlim(0, 60)

        axes[-1].set_xlabel('Time (s)', fontsize=11)
        fig.suptitle('Generated Seismic Signal (3 Components)', fontsize=13, fontweight='bold')

        # Operate on `fig` explicitly rather than on pyplot's "current figure",
        # which may be a different figure when several requests run at once.
        fig.tight_layout()

        # Convert to PIL Image
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting failed.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)


# ===========================
# Gradio Interface Functions
# ===========================

def generate_samples_interface(station_id, num_samples, progress=gr.Progress()):
    """
    Generate samples for Gradio interface.

    Args:
        station_id: Station ID (e.g., '0205')
        num_samples: Number of samples to generate (converted to int)
        progress: Gradio progress tracker

    Returns:
        status_message: Generation status (or an error description)
        npz_path: Path to saved NPZ file, or None on failure
        sample_plot: Preview plot of first generated sample, or None on failure
    """
    try:
        # Convert num_samples to int (Gradio might pass float)
        num_samples = int(num_samples)
        print(f"[DEBUG] Requested num_samples: {num_samples} (type: {type(num_samples)})")

        # With zero samples there would be nothing to preview further down.
        if num_samples < 1:
            return "❌ Error: Number of samples must be at least 1", None, None

        progress(0, desc="Checking model files...")

        # Download model files from HF Hub if needed
        try:
            checkpoint_path, latent_bank_path = download_model_files()
        except Exception as e:
            error_msg = f"❌ Error downloading model files:\n{str(e)}\n\n"
            error_msg += "Please ensure model files are uploaded to:\n"
            error_msg += f"https://huggingface.co/{HF_REPO_ID}"
            return error_msg, None, None

        output_dir = 'generated_outputs'

        progress(0.1, desc="Loading model...")

        # Load model
        args = SimpleArgs()
        model = load_model(checkpoint_path, args)

        progress(0.3, desc=f"Generating {num_samples} samples...")

        # Generate samples
        generated_signals, real_names = generate_samples_from_latent_bank(
            model, latent_bank_path, station_id, num_samples, args
        )

        if generated_signals is None:
            return f"❌ Error: Failed to generate samples for station {station_id}", None, None

        progress(0.8, desc="Saving NPZ file...")

        # Save NPZ (NO PLOTTING)
        npz_path = save_generated_samples(generated_signals, real_names, station_id, output_dir)

        progress(0.95, desc="Creating preview plot...")

        # Create preview plot for first sample
        sample_plot = plot_signal_for_display(generated_signals[0])

        progress(1.0, desc="Done!")

        # Verify actual number of samples generated
        actual_count = len(generated_signals)
        status_msg = f"✅ Successfully generated {actual_count} samples for station {station_id}!\n"
        status_msg += f"📊 Requested: {num_samples} samples\n"
        status_msg += f"📁 Saved to: {npz_path}\n"
        status_msg += f"📈 Preview of first generated sample shown below."

        return status_msg, npz_path, sample_plot

    except Exception as e:
        # Top-level UI boundary: report the failure in the status box.
        import traceback
        error_msg = f"❌ Error during generation:\n{str(e)}\n\n{traceback.format_exc()}"
        return error_msg, None, None


def load_and_display_npz(npz_file, sample_idx):
    """
    Load NPZ file and display a specific sample.

    Args:
        npz_file: Path to NPZ file, or an object whose ``name`` attribute is
            that path (presumably what some Gradio versions pass - verify)
        sample_idx: Index of sample to display (0-based); converted to int

    Returns:
        status_message: Load status (or an error description)
        sample_plot: Plot of selected sample, or None on failure
    """
    try:
        if npz_file is None:
            return "⚠️ No NPZ file provided", None

        # Slider values may arrive as floats, which NumPy rejects as indices.
        sample_idx = int(sample_idx)

        npz_path = getattr(npz_file, "name", npz_file)

        # Load NPZ; the context manager closes the underlying file handle.
        with np.load(npz_path) as data:
            generated_signals = data['generated_signals']

        if sample_idx < 0 or sample_idx >= len(generated_signals):
            return f"⚠️ Sample index {sample_idx} out of range (0-{len(generated_signals)-1})", None

        # Plot selected sample
        sample_plot = plot_signal_for_display(generated_signals[sample_idx])

        status_msg = f"✅ Loaded NPZ with {len(generated_signals)} samples\n"
        status_msg += f"📊 Displaying sample #{sample_idx}"

        return status_msg, sample_plot

    except Exception as e:
        # Top-level UI boundary: report the failure in the status box.
        import traceback
        error_msg = f"❌ Error loading NPZ:\n{str(e)}\n\n{traceback.format_exc()}"
        return error_msg, None


# ===========================
# Gradio App
# ===========================

def create_demo():
    """Create Gradio interface.

    Builds a two-tab Blocks app: one tab generates samples for a chosen
    station, the other displays a sample from an uploaded NPZ file.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="TimesNet-Gen Demo: Station-Specific Seismic Generator") as demo:
        gr.Markdown("""
        # 🌍 TimesNet-Gen Demo: Station-Specific Seismic Sample Generator
        For more detailed information: https://arxiv.org/abs/2512.04694

        Generate realistic synthetic seismic signals with station-specific characteristics.

        **Instructions:**
        1. Select a station ID (5 fine-tuned stations available)
        2. Choose number of samples to generate
        3. Click "Generate Samples" and wait
        4. Preview generated samples or download NPZ file
        """)

        with gr.Tab("Generate Samples"):
            with gr.Row():
                with gr.Column(scale=1):
                    station_dropdown = gr.Dropdown(
                        choices=['0205', '1716', '2020', '3130', '4628'],
                        value='0205',
                        label="Station ID",
                        info="Select target station"
                    )

                    num_samples_slider = gr.Slider(
                        minimum=1,
                        maximum=200,
                        value=50,
                        step=1,
                        label="Number of Samples",
                        info="How many samples to generate"
                    )

                    generate_btn = gr.Button("🚀 Generate Samples", variant="primary")

                with gr.Column(scale=2):
                    status_text = gr.Textbox(
                        label="Status",
                        lines=5,
                        interactive=False
                    )

                    npz_file_output = gr.File(
                        label="Generated NPZ File",
                        interactive=False
                    )

            gr.Markdown("### Preview (First Generated Sample)")
            preview_plot = gr.Image(label="Sample Preview")

            generate_btn.click(
                fn=generate_samples_interface,
                inputs=[station_dropdown, num_samples_slider],
                outputs=[status_text, npz_file_output, preview_plot]
            )

        with gr.Tab("View Saved Samples"):
            gr.Markdown("### Load and view samples from saved NPZ file")

            with gr.Row():
                with gr.Column(scale=1):
                    npz_upload = gr.File(
                        label="Upload NPZ File",
                        file_types=['.npz']
                    )

                    # Upper bound matches the 200-sample cap of the generator tab.
                    sample_idx_slider = gr.Slider(
                        minimum=0,
                        maximum=199,
                        value=0,
                        step=1,
                        label="Sample Index",
                        info="Which sample to display (0-based)"
                    )

                    load_btn = gr.Button("📊 Load and Display", variant="secondary")

                with gr.Column(scale=2):
                    load_status = gr.Textbox(
                        label="Status",
                        lines=3,
                        interactive=False
                    )

            display_plot = gr.Image(label="Sample Display")

            load_btn.click(
                fn=load_and_display_npz,
                inputs=[npz_upload, sample_idx_slider],
                outputs=[load_status, display_plot]
            )

        gr.Markdown("""
        ---
        **Model:** TimesNet-PointCloud

        **Method:** Bootstrap aggregation from latent bank

        **Stations:** 5 fine-tuned Turkish strong-motion stations

        **Output:** 3-component acceleration signals (E, N, Z) @ 100 Hz
        """)

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860)