Barisylmz commited on
Commit
2535d2b
Β·
verified Β·
1 Parent(s): 9992e6b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +436 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio app for TimesNet-Gen: Generate seismic samples from latent bank.
4
+ Based on generate_samples_git.py (working GitHub version).
5
+
6
+ NO PLOTTING - only NPZ generation and display in Gradio interface.
7
+ """
8
+ import os
9
+ import gradio as gr
10
+ import torch
11
+ import numpy as np
12
+ from datetime import datetime
13
+ import matplotlib.pyplot as plt
14
+ from io import BytesIO
15
+ from PIL import Image
16
+
17
+
18
+ class SimpleArgs:
19
+ """Configuration for generation (matching GitHub version)."""
20
+ def __init__(self):
21
+ # Model architecture
22
+ self.seq_len = 6000
23
+ self.d_model = 128
24
+ self.d_ff = 256
25
+ self.e_layers = 2
26
+ self.d_layers = 2
27
+ self.num_kernels = 6
28
+ self.top_k = 2
29
+ self.dropout = 0.1
30
+ self.latent_dim = 256
31
+
32
+ # System
33
+ self.use_gpu = torch.cuda.is_available()
34
+ self.seed = 0
35
+
36
+ # Point-cloud generation
37
+ self.pcgen_k = 5
38
+ self.pcgen_jitter_std = 0.0
39
+
40
+
41
+ def load_model(checkpoint_path, args):
42
+ """Load pre-trained TimesNet-PointCloud model (matching GitHub version)."""
43
+ from TimesNet_PointCloud import TimesNetPointCloud
44
+
45
+ # Create model config (NO num_stations - GitHub version doesn't use it)
46
+ class ModelConfig:
47
+ def __init__(self, args):
48
+ self.seq_len = args.seq_len
49
+ self.pred_len = 0
50
+ self.enc_in = 3
51
+ self.c_out = 3
52
+ self.d_model = args.d_model
53
+ self.d_ff = args.d_ff
54
+ self.num_kernels = args.num_kernels
55
+ self.top_k = args.top_k
56
+ self.e_layers = args.e_layers
57
+ self.d_layers = args.d_layers
58
+ self.dropout = args.dropout
59
+ self.embed = 'timeF'
60
+ self.freq = 'h'
61
+ self.latent_dim = args.latent_dim
62
+
63
+ config = ModelConfig(args)
64
+ model = TimesNetPointCloud(config)
65
+
66
+ # Load checkpoint
67
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
68
+ if 'model_state_dict' in checkpoint:
69
+ model.load_state_dict(checkpoint['model_state_dict'])
70
+ else:
71
+ model.load_state_dict(checkpoint)
72
+
73
+ model.eval()
74
+ if args.use_gpu:
75
+ model = model.cuda()
76
+
77
+ print(f"[INFO] Model loaded successfully from {checkpoint_path}")
78
+ return model
79
+
80
+
81
+ def generate_samples_from_latent_bank(model, latent_bank_path, station_id, num_samples, args):
82
+ """
83
+ Generate samples directly from pre-computed latent bank (matching GitHub version).
84
+
85
+ Args:
86
+ model: TimesNet model
87
+ latent_bank_path: Path to latent_bank_phase1.npz
88
+ station_id: Station ID (e.g., '0205')
89
+ num_samples: Number of samples to generate
90
+ args: Model arguments
91
+
92
+ Returns:
93
+ generated_signals: (num_samples, 3, seq_len) array
94
+ real_names_used: List of lists indicating which latent vectors were used
95
+ """
96
+ print(f"[INFO] Loading latent bank from {latent_bank_path}...")
97
+
98
+ try:
99
+ latent_data = np.load(latent_bank_path)
100
+ except Exception as e:
101
+ print(f"[ERROR] Could not load latent bank: {e}")
102
+ return None, None
103
+
104
+ # Load latent vectors for this station
105
+ latents_key = f'latents_{station_id}'
106
+ means_key = f'means_{station_id}'
107
+ stdev_key = f'stdev_{station_id}'
108
+
109
+ if latents_key not in latent_data:
110
+ print(f"[ERROR] Station {station_id} not found in latent bank!")
111
+ available = [k.replace('latents_', '') for k in latent_data.keys() if k.startswith('latents_')]
112
+ print(f"Available stations: {available}")
113
+ return None, None
114
+
115
+ latents = latent_data[latents_key] # (N_samples, seq_len, d_model)
116
+ means = latent_data[means_key] # (N_samples, seq_len, d_model)
117
+ stdevs = latent_data[stdev_key] # (N_samples, seq_len, d_model)
118
+
119
+ print(f"[INFO] Loaded {len(latents)} latent vectors for station {station_id}")
120
+ print(f"[INFO] Generating {num_samples} samples via bootstrap aggregation...")
121
+
122
+ generated_signals = []
123
+ real_names_used = []
124
+
125
+ model.eval()
126
+ with torch.no_grad():
127
+ for i in range(num_samples):
128
+ # Bootstrap: randomly select k latent vectors with replacement
129
+ k = min(args.pcgen_k, len(latents))
130
+ selected_indices = np.random.choice(len(latents), size=k, replace=True)
131
+
132
+ # Mix latent features (average)
133
+ selected_latents = latents[selected_indices] # (k, seq_len, d_model)
134
+ selected_means = means[selected_indices] # (k, seq_len, d_model)
135
+ selected_stdevs = stdevs[selected_indices] # (k, seq_len, d_model)
136
+
137
+ mixed_features = np.mean(selected_latents, axis=0) # (seq_len, d_model)
138
+ mixed_means = np.mean(selected_means, axis=0) # (seq_len, d_model)
139
+ mixed_stdevs = np.mean(selected_stdevs, axis=0) # (seq_len, d_model)
140
+
141
+ # Convert to torch tensors
142
+ mixed_features_torch = torch.from_numpy(mixed_features).float().unsqueeze(0) # (1, seq_len, d_model)
143
+ means_b = torch.from_numpy(mixed_means).float().unsqueeze(0) # (1, seq_len, d_model)
144
+ stdev_b = torch.from_numpy(mixed_stdevs).float().unsqueeze(0) # (1, seq_len, d_model)
145
+
146
+ if args.use_gpu:
147
+ mixed_features_torch = mixed_features_torch.cuda()
148
+ means_b = means_b.cuda()
149
+ stdev_b = stdev_b.cuda()
150
+
151
+ # Decode
152
+ xg = model.project_features_for_reconstruction(mixed_features_torch, means_b, stdev_b)
153
+
154
+ # Store - transpose to (3, 6000)
155
+ generated_np = xg.squeeze(0).cpu().numpy().T # (6000, 3) β†’ (3, 6000)
156
+ generated_signals.append(generated_np)
157
+
158
+ # Track which latent indices were used
159
+ real_names_used.append([f"latent_{idx}" for idx in selected_indices])
160
+
161
+ if (i + 1) % 10 == 0:
162
+ print(f" Generated {i + 1}/{num_samples} samples...")
163
+
164
+ return np.array(generated_signals), real_names_used
165
+
166
+
167
+ def save_generated_samples(generated_signals, real_names, station_id, output_dir):
168
+ """Save generated samples to NPZ file (NO PLOTTING)."""
169
+ os.makedirs(output_dir, exist_ok=True)
170
+
171
+ # Save timeseries NPZ
172
+ output_path = os.path.join(output_dir, f'station_{station_id}_generated_timeseries.npz')
173
+ np.savez_compressed(
174
+ output_path,
175
+ generated_signals=generated_signals,
176
+ signals_generated=generated_signals, # Alias for compatibility
177
+ real_names=real_names,
178
+ station_id=station_id,
179
+ station=station_id, # Alias for compatibility
180
+ )
181
+
182
+ print(f"[INFO] Saved {len(generated_signals)} generated samples to {output_path}")
183
+ return output_path
184
+
185
+
186
+ def plot_signal_for_display(signal):
187
+ """
188
+ Plot a single 3-component seismic signal for Gradio display.
189
+
190
+ Args:
191
+ signal: (3, 6000) array [E, N, Z]
192
+
193
+ Returns:
194
+ PIL Image
195
+ """
196
+ fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)
197
+
198
+ component_names = ['East', 'North', 'Vertical']
199
+ colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
200
+
201
+ t = np.arange(signal.shape[1]) / 100.0 # 100 Hz sampling
202
+
203
+ for idx, (ax, name, color) in enumerate(zip(axes, component_names, colors)):
204
+ ax.plot(t, signal[idx], color=color, linewidth=0.5, alpha=0.8)
205
+ ax.set_ylabel(f'{name}\n(cm/sΒ²)', fontsize=10)
206
+ ax.grid(True, alpha=0.3)
207
+ ax.set_xlim(0, 60)
208
+
209
+ axes[-1].set_xlabel('Time (s)', fontsize=11)
210
+ fig.suptitle('Generated Seismic Signal (3 Components)', fontsize=13, fontweight='bold')
211
+ plt.tight_layout()
212
+
213
+ # Convert to PIL Image
214
+ buf = BytesIO()
215
+ plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
216
+ plt.close(fig)
217
+ buf.seek(0)
218
+ return Image.open(buf)
219
+
220
+
221
+ # ===========================
222
+ # Gradio Interface Functions
223
+ # ===========================
224
+
225
+ def generate_samples_interface(station_id, num_samples, progress=gr.Progress()):
226
+ """
227
+ Generate samples for Gradio interface.
228
+
229
+ Args:
230
+ station_id: Station ID (e.g., '0205')
231
+ num_samples: Number of samples to generate
232
+ progress: Gradio progress tracker
233
+
234
+ Returns:
235
+ status_message: Generation status
236
+ npz_path: Path to saved NPZ file
237
+ sample_plot: Preview plot of first generated sample
238
+ """
239
+ try:
240
+ progress(0, desc="Initializing...")
241
+
242
+ # Paths
243
+ checkpoint_path = 'timesnet_pointcloud_phase1_final.pth'
244
+ latent_bank_path = 'latent_bank_phase1.npz'
245
+ output_dir = 'generated_outputs'
246
+
247
+ # Check files exist
248
+ if not os.path.exists(checkpoint_path):
249
+ return f"❌ Error: Checkpoint not found at {checkpoint_path}", None, None
250
+ if not os.path.exists(latent_bank_path):
251
+ return f"❌ Error: Latent bank not found at {latent_bank_path}", None, None
252
+
253
+ progress(0.1, desc="Loading model...")
254
+
255
+ # Load model
256
+ args = SimpleArgs()
257
+ model = load_model(checkpoint_path, args)
258
+
259
+ progress(0.3, desc=f"Generating {num_samples} samples...")
260
+
261
+ # Generate samples
262
+ generated_signals, real_names = generate_samples_from_latent_bank(
263
+ model, latent_bank_path, station_id, num_samples, args
264
+ )
265
+
266
+ if generated_signals is None:
267
+ return f"❌ Error: Failed to generate samples for station {station_id}", None, None
268
+
269
+ progress(0.8, desc="Saving NPZ file...")
270
+
271
+ # Save NPZ (NO PLOTTING)
272
+ npz_path = save_generated_samples(generated_signals, real_names, station_id, output_dir)
273
+
274
+ progress(0.95, desc="Creating preview plot...")
275
+
276
+ # Create preview plot for first sample
277
+ sample_plot = plot_signal_for_display(generated_signals[0])
278
+
279
+ progress(1.0, desc="Done!")
280
+
281
+ status_msg = f"βœ… Successfully generated {num_samples} samples for station {station_id}!\n"
282
+ status_msg += f"πŸ“ Saved to: {npz_path}\n"
283
+ status_msg += f"πŸ“Š Preview of first generated sample shown below."
284
+
285
+ return status_msg, npz_path, sample_plot
286
+
287
+ except Exception as e:
288
+ import traceback
289
+ error_msg = f"❌ Error during generation:\n{str(e)}\n\n{traceback.format_exc()}"
290
+ return error_msg, None, None
291
+
292
+
293
+ def load_and_display_npz(npz_file, sample_idx):
294
+ """
295
+ Load NPZ file and display a specific sample.
296
+
297
+ Args:
298
+ npz_file: Path to NPZ file
299
+ sample_idx: Index of sample to display (0-based)
300
+
301
+ Returns:
302
+ status_message: Load status
303
+ sample_plot: Plot of selected sample
304
+ """
305
+ try:
306
+ if npz_file is None:
307
+ return "⚠️ No NPZ file provided", None
308
+
309
+ # Load NPZ
310
+ data = np.load(npz_file)
311
+ generated_signals = data['generated_signals']
312
+
313
+ if sample_idx < 0 or sample_idx >= len(generated_signals):
314
+ return f"⚠️ Sample index {sample_idx} out of range (0-{len(generated_signals)-1})", None
315
+
316
+ # Plot selected sample
317
+ sample_plot = plot_signal_for_display(generated_signals[sample_idx])
318
+
319
+ status_msg = f"βœ… Loaded NPZ with {len(generated_signals)} samples\n"
320
+ status_msg += f"πŸ“Š Displaying sample #{sample_idx}"
321
+
322
+ return status_msg, sample_plot
323
+
324
+ except Exception as e:
325
+ import traceback
326
+ error_msg = f"❌ Error loading NPZ:\n{str(e)}\n\n{traceback.format_exc()}"
327
+ return error_msg, None
328
+
329
+
330
+ # ===========================
331
+ # Gradio App
332
+ # ===========================
333
+
334
+ def create_demo():
335
+ """Create Gradio interface."""
336
+
337
+ with gr.Blocks(title="TimesNet-Gen: Seismic Sample Generator") as demo:
338
+ gr.Markdown("""
339
+ # 🌍 TimesNet-Gen: Seismic Sample Generator
340
+
341
+ Generate synthetic seismic signals from pre-trained latent bank.
342
+
343
+ **Instructions:**
344
+ 1. Select a station ID (5 fine-tuned stations available)
345
+ 2. Choose number of samples to generate
346
+ 3. Click "Generate Samples" and wait
347
+ 4. Preview generated samples or download NPZ file
348
+ """)
349
+
350
+ with gr.Tab("Generate Samples"):
351
+ with gr.Row():
352
+ with gr.Column(scale=1):
353
+ station_dropdown = gr.Dropdown(
354
+ choices=['0205', '1716', '2020', '3130', '4628'],
355
+ value='0205',
356
+ label="Station ID",
357
+ info="Select target station"
358
+ )
359
+ num_samples_slider = gr.Slider(
360
+ minimum=1,
361
+ maximum=200,
362
+ value=50,
363
+ step=1,
364
+ label="Number of Samples",
365
+ info="How many samples to generate"
366
+ )
367
+ generate_btn = gr.Button("πŸš€ Generate Samples", variant="primary")
368
+
369
+ with gr.Column(scale=2):
370
+ status_text = gr.Textbox(
371
+ label="Status",
372
+ lines=5,
373
+ interactive=False
374
+ )
375
+ npz_file_output = gr.File(
376
+ label="Generated NPZ File",
377
+ interactive=False
378
+ )
379
+
380
+ gr.Markdown("### Preview (First Generated Sample)")
381
+ preview_plot = gr.Image(label="Sample Preview")
382
+
383
+ generate_btn.click(
384
+ fn=generate_samples_interface,
385
+ inputs=[station_dropdown, num_samples_slider],
386
+ outputs=[status_text, npz_file_output, preview_plot]
387
+ )
388
+
389
+ with gr.Tab("View Saved Samples"):
390
+ gr.Markdown("### Load and view samples from saved NPZ file")
391
+
392
+ with gr.Row():
393
+ with gr.Column(scale=1):
394
+ npz_upload = gr.File(
395
+ label="Upload NPZ File",
396
+ file_types=['.npz']
397
+ )
398
+ sample_idx_slider = gr.Slider(
399
+ minimum=0,
400
+ maximum=199,
401
+ value=0,
402
+ step=1,
403
+ label="Sample Index",
404
+ info="Which sample to display (0-based)"
405
+ )
406
+ load_btn = gr.Button("πŸ“Š Load and Display", variant="secondary")
407
+
408
+ with gr.Column(scale=2):
409
+ load_status = gr.Textbox(
410
+ label="Status",
411
+ lines=3,
412
+ interactive=False
413
+ )
414
+
415
+ display_plot = gr.Image(label="Sample Display")
416
+
417
+ load_btn.click(
418
+ fn=load_and_display_npz,
419
+ inputs=[npz_upload, sample_idx_slider],
420
+ outputs=[load_status, display_plot]
421
+ )
422
+
423
+ gr.Markdown("""
424
+ ---
425
+ **Model:** TimesNet-PointCloud
426
+ **Method:** Bootstrap aggregation from latent bank
427
+ **Stations:** 5 fine-tuned Turkish strong-motion stations
428
+ **Output:** 3-component acceleration signals (E, N, Z) @ 100 Hz
429
+ """)
430
+
431
+ return demo
432
+
433
+
434
+ if __name__ == "__main__":
435
+ demo = create_demo()
436
+ demo.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ numpy>=1.24.0
4
+ scipy>=1.10.0
5
+ matplotlib>=3.7.0
6
+ Pillow>=9.0.0
7
+ huggingface_hub>=0.20.0
8
+