#!/usr/bin/env python3
# TimesNet-Gen: generate_samples_git.py
# (uploaded by Barisylmz, commit 0dfdc08 "Upload 4 files", 33.6 kB)
"""
Simplified inference script for TimesNet-Gen.
Only loads data for the 5 fine-tuned stations.
Usage:
python generate_samples.py --num_samples 50
"""
import os
import argparse
import torch
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt
import glob
import scipy.io as sio
class SimpleArgs:
    """Fixed hyper-parameters used to build the model and drive generation.

    The architecture values are passed to the model constructor in
    ``load_model`` and therefore have to agree with the checkpoint whose
    weights are loaded there.
    """

    def __init__(self):
        # --- Model architecture ---
        self.seq_len, self.d_model, self.d_ff = 6000, 128, 256
        self.e_layers, self.d_layers = 2, 2
        self.num_kernels, self.top_k = 6, 2
        self.dropout, self.latent_dim = 0.1, 256
        # --- Runtime ---
        self.use_gpu = torch.cuda.is_available()
        self.seed = 0
        # --- Point-cloud (latent bootstrap) generation ---
        self.pcgen_k = 5  # number of latent vectors averaged per generated sample
        self.pcgen_jitter_std = 0.0  # not read by the generation code in this script
def _iter_np_arrays(obj):
    """Recursively yield every non-object numpy array found inside ``obj``.

    Handles the containers ``scipy.io.loadmat`` can produce:

    * ``np.ndarray`` with ``dtype=object`` (MATLAB cell arrays): recurse into
      each element; any other ndarray is yielded as-is;
    * ``dict`` (the top-level ``loadmat`` result): recurse into the values;
    * ``list`` / ``tuple``: recurse into the items;
    * ``np.void`` records (``struct_as_record=True``): recurse into each field;
    * ``mat_struct`` objects (``struct_as_record=False``, the mode used by
      ``load_mat_file``): recurse into each field listed in ``_fieldnames``.

    Anything else (scalars, strings, bytes, None) is ignored.
    """
    if isinstance(obj, np.ndarray):
        if obj.dtype == object:
            for item in obj.flat:
                yield from _iter_np_arrays(item)
        else:
            yield obj
    elif isinstance(obj, dict):
        for value in obj.values():
            yield from _iter_np_arrays(value)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            yield from _iter_np_arrays(item)
    elif isinstance(obj, np.void):
        if obj.dtype.names:
            for name in obj.dtype.names:
                yield from _iter_np_arrays(obj[name])
    else:
        # scipy's mat_struct exposes MATLAB struct fields as attributes and lists
        # them in ``_fieldnames``; without this branch arrays nested in structs
        # loaded with struct_as_record=False were never found.
        fieldnames = getattr(obj, "_fieldnames", None)
        if fieldnames:
            for name in fieldnames:
                yield from _iter_np_arrays(getattr(obj, name, None))
def _find_3ch_from_arrays(arrays):
    """Return the first plausible 3-channel signal found in ``arrays``, or None.

    Search order:

    1. the first 2-D ndarray having a dimension of size 3, returned unchanged
       (so it may be ``(3, N)`` or ``(N, 3)``);
    2. otherwise the first three 1-D ndarrays of equal length, stacked into a
       ``(3, N)`` array. "First" means lexicographically smallest index triple,
       exactly as an exhaustive i < j < k search would pick.

    The 1-D pass groups arrays by length in one sweep (O(n)) instead of
    checking every index triple (O(n^3)).
    """
    for arr in arrays:
        if isinstance(arr, np.ndarray) and arr.ndim == 2 and 3 in arr.shape:
            return arr

    # Dict insertion order == order of first appearance of each length, so the
    # first group holding >= 3 arrays is the one whose first member comes
    # earliest, i.e. the same triple the exhaustive search would return.
    groups = {}
    for arr in arrays:
        if isinstance(arr, np.ndarray) and arr.ndim == 1:
            groups.setdefault(arr.shape[0], []).append(arr)
    for same_length in groups.values():
        if len(same_length) >= 3:
            return np.stack(same_length[:3], axis=0)
    return None
def _resample_channels(data, seq_len):
    """Resample a (3, N) array along time to (3, seq_len) float32 with scipy's FFT resampler."""
    from scipy import signal as sp_signal
    resampled = np.zeros((3, seq_len), dtype=np.float32)
    for ch in range(3):
        resampled[ch] = sp_signal.resample(data[ch], seq_len)
    return resampled


def _extract_eq_accel(mat, debug=False):
    """Return ``EQ.anEQ.Accel`` from a loadmat dict as a (3, N) array, or None if unusable."""
    if 'EQ' not in mat:
        return None
    try:
        eq_obj = mat['EQ']
        if debug:
            print(f"[DEBUG] EQ type: {type(eq_obj)}")
            print(f"[DEBUG] EQ shape: {eq_obj.shape if hasattr(eq_obj, 'shape') else 'N/A'}")
        # With struct_as_record=False a MATLAB struct is a mat_struct object:
        # fields are read as attributes, not subscripts.
        if not hasattr(eq_obj, 'anEQ'):
            if debug:
                print("[DEBUG] EQ has no 'anEQ' attribute")
                if hasattr(eq_obj, '__dict__'):
                    print(f"[DEBUG] EQ attributes: {list(vars(eq_obj).keys())}")
            return None
        dataset = eq_obj.anEQ
        if debug:
            print(f"[DEBUG] Found anEQ, type: {type(dataset)}")
        if not hasattr(dataset, 'Accel'):
            if debug:
                print("[DEBUG] anEQ has no 'Accel' attribute")
                if hasattr(dataset, '__dict__'):
                    print(f"[DEBUG] anEQ attributes: {list(vars(dataset).keys())}")
            return None
        accel = dataset.Accel
        if debug:
            print(f"[DEBUG] Found Accel: type={type(accel)}, shape={accel.shape if hasattr(accel, 'shape') else 'N/A'}")
        if not isinstance(accel, np.ndarray):
            return None
        if accel.ndim != 2:
            if debug:
                print(f"[DEBUG] Accel is not 2D: ndim={accel.ndim}")
            return None
        if accel.shape[1] == 3:
            accel = accel.T  # (N, 3) -> (3, N)
        if accel.shape[0] != 3:
            if debug:
                print(f"[DEBUG] Unexpected Accel shape[0]: {accel.shape[0]} (expected 3)")
            return None
        if debug:
            print(f"[DEBUG] ✅ Successfully extracted 3-channel data! Shape: {accel.shape}")
        return accel
    except Exception as e:
        if debug:
            import traceback
            print(f"[DEBUG] Could not parse EQ structure: {e}")
            print(f"[DEBUG] Traceback: {traceback.format_exc()}")
        return None


def _extract_generic_3ch(mat, debug=False):
    """Search all arrays of a loadmat dict for a 3-channel signal; return it as (3, N) or None."""
    arrays = list(_iter_np_arrays(mat))
    if debug:
        print(f"[DEBUG] Found {len(arrays)} arrays")
        for i, arr in enumerate(arrays[:5]):  # Show first 5
            print(f"[DEBUG] Array {i}: shape={arr.shape}, dtype={arr.dtype}")

    # Well-known keys are tried first, in this order. Component keys are listed
    # E-W, N-S, U-D so that three 1-D components stack in the channel order the
    # HVSR and preview code in this file assume.
    preferred = []
    for key in ('signal', 'data', 'sig', 'x', 'X', 'signal3c', 'acc', 'EW', 'NS', 'UD'):
        value = mat.get(key)
        if isinstance(value, np.ndarray):
            preferred.append(value)
            if debug:
                print(f"[DEBUG] Found key '{key}': shape={value.shape}")
    # Drop the preferred arrays from the generic list so the same array cannot
    # be counted twice when three equal-length 1-D components are stacked.
    preferred_ids = {id(a) for a in preferred}
    remaining = [a for a in arrays if id(a) not in preferred_ids]

    data = _find_3ch_from_arrays(preferred + remaining)
    if data is None:
        if debug:
            print("[DEBUG] Could not find 3-channel array!")
        return None
    if debug:
        print(f"[DEBUG] Found 3-channel data: shape={data.shape}")
    if data.shape[0] != 3 and data.shape[1] == 3:
        data = data.T
        if debug:
            print(f"[DEBUG] Transposed to: shape={data.shape}")
    if data.shape[0] != 3:
        if debug:
            print(f"[DEBUG] Wrong number of channels: {data.shape[0]}")
        return None
    return data


def load_mat_file(filepath, seq_len=6000, debug=False):
    """Load one .mat record and return it as a float32 tensor of shape (3, seq_len).

    Mirrors the data_loader_gen.py logic: first try the ``EQ.anEQ.Accel`` struct
    layout; if it is absent or unusable, fall back to a generic search of every
    array in the file (well-known keys first) for a 3-channel signal. Records
    whose length differs from ``seq_len`` are resampled channel by channel with
    ``scipy.signal.resample``.

    Args:
        filepath: Path to the .mat file.
        seq_len: Number of time samples in the returned tensor.
        debug: Print diagnostic information while parsing.

    Returns:
        A float32 tensor of shape (3, seq_len), or None when the file cannot be
        read or no 3-channel signal is found. Errors are swallowed on purpose
        so callers can simply skip bad files.
    """
    try:
        if debug:
            print(f"\n[DEBUG] Loading: {os.path.basename(filepath)}")
        mat = sio.loadmat(filepath, squeeze_me=True, struct_as_record=False)
        if debug:
            print(f"[DEBUG] Keys in mat file: {[k for k in mat if not k.startswith('__')]}")

        data = _extract_eq_accel(mat, debug)
        if data is None:
            data = _extract_generic_3ch(mat, debug)
        if data is None:
            return None

        if data.shape[1] != seq_len:
            data = _resample_channels(data, seq_len)
            if debug:
                print(f"[DEBUG] Resampled to: shape={data.shape}")
        if debug:
            print("[DEBUG] ✅ Successfully loaded!")
        return torch.from_numpy(np.ascontiguousarray(data, dtype=np.float32))
    except Exception as e:
        if debug:
            print(f"[DEBUG] ❌ Exception: {e}")
        return None
def load_model(checkpoint_path, args):
    """Build a TimesNetPointCloud from ``args`` and load weights from ``checkpoint_path``.

    The checkpoint may be a bare state dict or a dict holding the weights under
    ``'model_state_dict'`` (the layout ``fine_tune_model`` writes). The model is
    switched to eval mode and moved to CUDA when ``args.use_gpu`` is set.
    """
    from models.TimesNet_PointCloud import TimesNetPointCloud

    # Architecture fields taken verbatim from ``args``.
    copied_fields = ('seq_len', 'd_model', 'd_ff', 'num_kernels', 'top_k',
                     'e_layers', 'd_layers', 'dropout', 'latent_dim')
    # Fields fixed for 3-component reconstruction.
    fixed_fields = {'pred_len': 0, 'enc_in': 3, 'c_out': 3, 'embed': 'timeF', 'freq': 'h'}

    class ModelConfig:
        def __init__(self, source):
            for field in copied_fields:
                setattr(self, field, getattr(source, field))
            for field, value in fixed_fields.items():
                setattr(self, field, value)

    model = TimesNetPointCloud(ModelConfig(args))

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    model.load_state_dict(state_dict)
    model.eval()

    if args.use_gpu:
        model = model.cuda()
    print(f"[INFO] Model loaded successfully from {checkpoint_path}")
    return model
def generate_samples_from_latent_bank(model, latent_bank_path, station_id, num_samples, args, encoder_std=None):
    """
    Generate samples by decoding bootstrap mixtures of pre-computed latent vectors.

    No real data is needed. For every sample, ``args.pcgen_k`` bank entries
    (at least 1, at most the number available) are drawn with replacement.
    Their latents, means and stdevs are averaged, and the mixture is decoded
    with ``model.project_features_for_reconstruction``.

    Args:
        model: TimesNet model exposing ``project_features_for_reconstruction``.
        latent_bank_path: Path to the latent bank NPZ (e.g. latent_bank_phase1.npz)
            holding ``latents_<id>``, ``means_<id>`` and ``stdev_<id>`` arrays.
        station_id: Station ID (e.g. '0205').
        num_samples: Number of samples to generate.
        args: Config providing ``pcgen_k`` and ``use_gpu``.
        encoder_std: Accepted for API compatibility and intentionally unused.
            No noise is injected at generation time; noise is only used during
            fine-tuning.

    Returns:
        ``(generated_signals, real_names_used)``.

        * ``generated_signals`` is an array of decoded samples, each transposed
          from the decoder's (T, C) output to (C, T). It is presumably
          (num_samples, 3, seq_len).
        * ``real_names_used`` holds, per sample, the ``"latent_<idx>"`` labels
          of the mixed bank entries.

        Returns ``(None, None)`` if the bank cannot be read, lacks any of the
        three arrays for ``station_id``, or is empty for it.
    """
    print(f"[INFO] Loading latent bank from {latent_bank_path}...")
    latents_key = f'latents_{station_id}'
    means_key = f'means_{station_id}'
    stdev_key = f'stdev_{station_id}'

    try:
        # Copy the needed arrays out and close the NPZ file handle right away.
        with np.load(latent_bank_path) as latent_data:
            if latents_key not in latent_data:
                print(f"[ERROR] Station {station_id} not found in latent bank!")
                print(f"Available stations: {[k.replace('latents_', '') for k in latent_data.keys() if k.startswith('latents_')]}")
                return None, None
            missing = [key for key in (means_key, stdev_key) if key not in latent_data]
            if missing:
                print(f"[ERROR] Latent bank is missing {missing} for station {station_id}!")
                return None, None
            # Shapes are whatever the bank stores; presumably
            # (N_samples, seq_len, d_model) for latents. TODO confirm for means/stdevs.
            latents = latent_data[latents_key]
            means = latent_data[means_key]
            stdevs = latent_data[stdev_key]
    except Exception as e:
        print(f"[ERROR] Could not load latent bank: {e}")
        return None, None

    if len(latents) == 0:
        print(f"[ERROR] Latent bank for station {station_id} is empty!")
        return None, None

    print(f"[INFO] Loaded {len(latents)} latent vectors for station {station_id}")
    print(f"[INFO] Generating {num_samples} samples via bootstrap aggregation...")

    # Mixture size is loop-invariant. Clamp to >= 1 so an empty selection can
    # never produce NaN means.
    k = max(1, min(args.pcgen_k, len(latents)))

    generated_signals = []
    real_names_used = []
    model.eval()
    with torch.no_grad():
        for i in range(num_samples):
            # Bootstrap: pick k bank entries with replacement and average them.
            # No noise is added here (matching untitled1_gen.py, which only
            # injects noise during Phase 1 training).
            selected_indices = np.random.choice(len(latents), size=k, replace=True)
            mixed = [np.mean(bank[selected_indices], axis=0) for bank in (latents, means, stdevs)]

            # Add a batch dimension of 1 and move to the model's device.
            feats_b, means_b, stdev_b = (torch.from_numpy(arr).float().unsqueeze(0) for arr in mixed)
            if args.use_gpu:
                feats_b, means_b, stdev_b = feats_b.cuda(), means_b.cuda(), stdev_b.cuda()

            xg = model.project_features_for_reconstruction(feats_b, means_b, stdev_b)
            # Drop the batch dim and transpose (T, C) -> (C, T).
            generated_signals.append(xg.squeeze(0).cpu().numpy().T)
            real_names_used.append([f"latent_{idx}" for idx in selected_indices])

            if (i + 1) % 10 == 0:
                print(f"  Generated {i + 1}/{num_samples} samples...")

    return np.array(generated_signals), real_names_used
def _preprocess_component_boore(data: np.ndarray, fs: float, corner_freq: float, filter_order: int = 2) -> np.ndarray:
    """Prepare one component Boore (2005)-style: linear detrend, zero-pad, zero-phase high-pass.

    The padding is NOT removed, so the result has ``len(data) + 2 * pad`` samples,
    where ``pad = round(1.5 * filter_order / corner_freq * fs)``.
    """
    from scipy.signal import butter, filtfilt

    samples = np.asarray(data, dtype=np.float64)
    time_idx = np.arange(samples.shape[0], dtype=np.float64)

    # Closed-form least-squares straight-line fit, then subtract it.
    idx_mean = time_idx.mean()
    val_mean = samples.mean()
    idx_centred = time_idx - idx_mean
    spread = np.sum(idx_centred ** 2)
    if spread == 0:
        slope = 0.0
    else:
        slope = float(np.sum(idx_centred * (samples - val_mean)) / spread)
    intercept = float(val_mean - slope * idx_mean)
    detrended = samples - (slope * time_idx + intercept)

    # Zero-pad both ends by Tzpad = 1.5 * order / fc seconds.
    pad_seconds = (1.5 * filter_order) / max(corner_freq, 1e-6)
    n_pad = int(round(pad_seconds * fs))
    padded = np.pad(detrended, n_pad)

    # Butterworth high-pass applied forward and backward (zero phase); the
    # normalised cutoff is clamped strictly inside (0, 1).
    cutoff = float(np.clip(corner_freq / (fs / 2.0), 1e-6, 0.999999))
    num, den = butter(filter_order, cutoff, btype='high')
    return filtfilt(num, den, padded)
def _konno_ohmachi_smoothing(spectrum: np.ndarray, freq: np.ndarray, b: float = 40.0) -> np.ndarray:
    """Smooth ``spectrum`` with a log-frequency window, ported from a MATLAB reference.

    Each output bin is the weighted mean of the whole spectrum with weights
    ``exp(-b * ln(f / fc) ** 2)`` centred on that bin's frequency ``fc``.
    Zero frequencies are replaced by 1e-12 so the logarithm stays finite, and
    non-finite weights are zeroed. A bin whose weights sum to zero yields 0.
    Cost is O(n^2) in the number of bins.
    """
    freqs = np.asarray(freq, dtype=np.float64).ravel()
    values = np.asarray(spectrum, dtype=np.float64).ravel()
    freqs = np.where(freqs == 0.0, 1e-12, freqs)

    smoothed = np.zeros_like(values)
    for i, centre in enumerate(freqs):
        weights = np.exp(-b * np.log(freqs / centre) ** 2)
        weights[~np.isfinite(weights)] = 0.0
        total = np.sum(weights)
        smoothed[i] = float(np.sum(weights * values) / total) if total != 0 else 0.0
    return smoothed
def _compute_hvsr_simple(signal: np.ndarray, fs: float = 100.0):
    """Return ``(freq, hvsr)`` limited to 1-20 Hz for a (T, 3) record with columns EW, NS, UD.

    MATLAB-style pipeline:

    1. Boore preprocessing of each component with a 0.05 Hz high-pass.
    2. One-sided FFT amplitude spectra.
    3. Geometric mean of the EW and NS spectra as the horizontal spectrum.
    4. Konno-Ohmachi smoothing (b = 40) of the horizontal and vertical spectra.
    5. Horizontal / vertical ratio.

    Returns ``(None, None)`` for a wrongly shaped or non-finite input, fewer
    than 16 samples after preprocessing, no bins in 1-20 Hz, or any exception.
    """
    try:
        if signal.ndim != 2 or signal.shape[1] != 3:
            return None, None
        if not np.all(np.isfinite(signal)):
            return None, None

        ew, ns, ud = (_preprocess_component_boore(signal[:, c], fs, 0.05, 2) for c in range(3))
        n = int(min(map(len, (ew, ns, ud))))
        if n < 16:
            return None, None
        ew, ns, ud = ew[:n], ns[:n], ud[:n]

        half = n // 2
        if half <= 1:
            return None, None
        freq = np.arange(0, half, dtype=np.float64) * (fs / n)
        amp_ew, amp_ns, amp_ud = (np.abs(np.fft.fft(comp))[:half] for comp in (ew, ns, ud))

        horizontal = np.sqrt(np.maximum(amp_ew, 0.0) * np.maximum(amp_ns, 0.0))
        smooth_h = _konno_ohmachi_smoothing(horizontal, freq, 40.0)
        smooth_v = _konno_ohmachi_smoothing(amp_ud, freq, 40.0)
        # Guard the division against non-positive vertical amplitudes.
        ratio = smooth_h / np.where(smooth_v <= 0.0, 1e-12, smooth_v)

        band = (freq >= 1.0) & (freq <= 20.0)
        if not band.any():
            return None, None
        return freq[band], ratio[band]
    except Exception:
        return None, None
def save_generated_samples(generated_signals, real_names, station_id, output_dir):
    """Write generated signals plus HVSR/f0 summaries to ``station_<id>_generated_timeseries.npz``.

    For each signal (stored as (3, T)) an HVSR curve is computed at fs = 100 Hz.
    f0 is the frequency of that curve's maximum.

    The file also stores:

    * the point-wise median of all valid curves, after interpolation onto a
      fixed 400-point 1-20 Hz grid;
    * a normalised 20-bin f0 histogram over 1-20 Hz.

    Signals and station IDs are additionally stored under alias keys for
    compatibility with other tools.
    """
    os.makedirs(output_dir, exist_ok=True)
    sample_rate = 100.0
    print(f"[INFO] Computing HVSR and f0 for {len(generated_signals)} generated samples...")

    curves = []
    for sig in generated_signals:
        freq, hvsr = _compute_hvsr_simple(sig.T, sample_rate)  # (3, T) -> (T, 3)
        if freq is None or hvsr is None:
            continue
        curves.append((freq, hvsr))
    # f0 = frequency at the HVSR peak of each valid curve.
    f0_list = [float(freq[np.argmax(hvsr)]) for freq, hvsr in curves]

    if curves:
        hvsr_freq = np.linspace(1.0, 20.0, 400)
        resampled = [np.interp(hvsr_freq, f, h, left=h[0], right=h[-1]) for f, h in curves]
        hvsr_median = np.median(np.vstack(resampled), axis=0)
    else:
        hvsr_freq = None
        hvsr_median = None

    # f0 histogram normalised to a probability mass function.
    f0_bins = np.linspace(1.0, 20.0, 21)
    f0_array = np.array(f0_list)
    counts, _ = np.histogram(f0_array, bins=f0_bins)
    f0_pdf = counts.astype(float)
    total = f0_pdf.sum()
    if total > 0:
        f0_pdf = f0_pdf / total

    empty = np.array([])
    output_path = os.path.join(output_dir, f'station_{station_id}_generated_timeseries.npz')
    np.savez_compressed(
        output_path,
        generated_signals=generated_signals,
        signals_generated=generated_signals,  # Alias for compatibility
        real_names=real_names,
        station_id=station_id,
        station=station_id,  # Alias for compatibility
        f0_timesnet=f0_array,
        f0_bins=f0_bins,
        pdf_timesnet=f0_pdf,
        hvsr_freq_timesnet=empty if hvsr_freq is None else hvsr_freq,
        hvsr_median_timesnet=empty if hvsr_median is None else hvsr_median,
    )

    print(f"[INFO] Saved {len(generated_signals)} generated samples to {output_path}")
    if f0_list:
        print(f"[INFO]   - f0 samples: {len(f0_list)}, median f0: {np.median(f0_array):.2f} Hz")
    else:
        print(f"[INFO]   - No valid HVSR computed")
def fine_tune_model(model, all_station_files, args, encoder_std, epochs=10, lr=1e-4):
    """
    Fine-tune the model on the given station files with latent noise injection.

    Follows the Phase 1 recipe of untitled1_gen.py:

    * AdamW (weight_decay=1e-4) with a cosine LR schedule over ``epochs``;
    * AMP when ``args.use_gpu`` is set;
    * gradient-norm clipping at 1.0;
    * MSE reconstruction loss on encode -> (+noise) -> decode round trips.

    Args:
        model: TimesNet model exposing ``encode_features_for_reconstruction`` and
            ``project_features_for_reconstruction``.
        all_station_files: Mapping station_id -> list of .mat file paths.
        args: Config providing ``seq_len`` and ``use_gpu``.
        encoder_std: Optional per-feature std vector. When given, Gaussian
            noise with this std (broadcast over the last dimension) is added to
            the encoder output.
        epochs: Number of training epochs.
        lr: AdamW learning rate.

    Returns:
        The same model instance.

        * If no file could be loaded, it is returned untouched.
        * Otherwise it is trained, switched to eval mode, and a checkpoint is
          written to ./checkpoints/timesnet_pointcloud_phase1_finetuned.pth.
    """
    print("\n" + "="*80)
    print("Phase 1: Fine-Tuning with Noise Injection")
    print("="*80)

    all_data = []
    for files in all_station_files.values():
        for fpath in files:
            data = load_mat_file(fpath, args.seq_len, debug=False)
            if data is not None:
                all_data.append(data)
    if not all_data:
        print("[WARN] No data loaded for fine-tuning!")
        return model
    print(f"[INFO] Loaded {len(all_data)} samples for fine-tuning")

    # Optimisation settings (matching untitled1_gen.py Phase 1).
    batch_size = 32
    weight_decay = 1e-4
    grad_clip = 1.0
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = torch.cuda.amp.GradScaler(enabled=(args.use_gpu))

    # Build the noise-std tensor once instead of once per batch
    # (noise_std_scale = 1.0, matching untitled1_gen.py).
    noise_std = None
    if encoder_std is not None:
        noise_std = torch.as_tensor(encoder_std, dtype=torch.float32).view(1, 1, -1)
        if args.use_gpu:
            noise_std = noise_std.cuda()

    train_losses_p1 = []
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        total_rec = 0.0
        num_batches = 0

        np.random.shuffle(all_data)
        for start in range(0, len(all_data), batch_size):
            batch = all_data[start:start + batch_size]
            # Each sample is (3, T); the model expects (batch, T, 3).
            x = torch.stack([sig.transpose(0, 1) for sig in batch], dim=0)
            if args.use_gpu:
                x = x.cuda()

            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=(args.use_gpu)):
                enc_out, means_b, stdev_b = model.encode_features_for_reconstruction(x)
                if noise_std is not None:
                    enc_out = enc_out + torch.randn_like(enc_out) * noise_std
                x_hat = model.project_features_for_reconstruction(enc_out, means_b, stdev_b)
                loss_rec = torch.nn.functional.mse_loss(x_hat, x)
                loss = loss_rec

            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient norm. Clipping the
            # loss-scaled gradients would shrink every update by the loss scale.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            total_rec += loss_rec.item()
            num_batches += 1

        scheduler.step()
        avg_loss = total_loss / max(1, num_batches)
        avg_rec = total_rec / max(1, num_batches)
        train_losses_p1.append(avg_loss)
        print(f"[P1] epoch {epoch+1}/{epochs} loss={avg_loss:.4f} (rec={avg_rec:.4f})")

    print("[INFO] Phase 1 fine-tuning complete!")

    # Save fine-tuned model (matching untitled1_gen.py Phase 1 checkpoint).
    checkpoint_dir = './checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    fine_tuned_path = os.path.join(checkpoint_dir, 'timesnet_pointcloud_phase1_finetuned.pth')
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses_phase1': train_losses_p1,
        'phase': 'phase1'
    }, fine_tuned_path)
    print(f"[INFO] ✓ Fine-tuned model saved to: {fine_tuned_path}")

    # Hand back an inference-ready model (dropout off).
    model.eval()
    return model
def plot_sample_preview(generated_signals, station_id, output_dir, num_preview=2):
    """Save a 3-panel PNG (E-W, N-S, U-D) for each of the first ``num_preview`` samples.

    Files are named ``station_<id>_preview_<i>.png`` inside ``output_dir``,
    which is created if missing. Each sample is indexed as ``sample[channel]``,
    i.e. channels along the first axis.
    """
    os.makedirs(output_dir, exist_ok=True)
    n_plots = min(num_preview, len(generated_signals))
    labels = ('E-W', 'N-S', 'U-D')

    for idx in range(n_plots):
        sample = generated_signals[idx]
        fig, axes = plt.subplots(3, 1, figsize=(12, 8))
        for ch, label in enumerate(labels):
            ax = axes[ch]
            ax.plot(sample[ch], linewidth=0.8)
            ax.set_ylabel(f'{label}\nAmplitude', fontsize=10, fontweight='bold')
            ax.grid(True, alpha=0.3)
        axes[-1].set_xlabel('Time Steps', fontsize=10, fontweight='bold')
        fig.suptitle(f'Generated Sample - Station {station_id}', fontsize=12, fontweight='bold')
        fig.tight_layout()

        png_path = os.path.join(output_dir, f'station_{station_id}_preview_{idx}.png')
        fig.savefig(png_path, dpi=150, bbox_inches='tight')
        plt.close(fig)

    print(f"[INFO] Saved {n_plots} preview plots to {output_dir}")
def _find_station_files(data_root, stations):
    """Map each station ID to its ``*<id>*.mat`` files under ``data_root``; stations with no files are omitted."""
    station_files = {}
    for station_id in stations:
        matches = glob.glob(os.path.join(data_root, f"*{station_id}*.mat"))
        if not matches:
            print(f"[WARN] No files found for station {station_id}")
        else:
            print(f"[INFO] Found {len(matches)} files for station {station_id}")
            station_files[station_id] = matches
    return station_files


def _report_generated_counts(npz_output_dir, stations):
    """Print how many generated samples each station's NPZ file contains."""
    print("\n[DEBUG] Generated samples per station:")
    for station_id in stations:
        npz_path = os.path.join(npz_output_dir, f'station_{station_id}_generated_timeseries.npz')
        if not os.path.exists(npz_path):
            continue
        try:
            # Context manager closes the NPZ file handle after reading.
            with np.load(npz_path, allow_pickle=True) as data:
                if 'signals_generated' in data:
                    n_samples = data['signals_generated'].shape[0]
                    print(f"  Station {station_id}: {n_samples} samples")
        except Exception as e:
            print(f"  Station {station_id}: Error loading NPZ - {e}")
    print("="*80)


def _run_hvsr_analysis(npz_output_dir, output_dir):
    """Run plot_combined_hvsr_all_sources.main() on the generated NPZs (best effort; sys.argv is always restored)."""
    import sys

    hvsr_out = os.path.join(output_dir, 'hvsr_analysis')
    print("\n[INFO] Creating HVSR comparison plots (matrices, HVSR curves, f0 distributions)...")
    print("[INFO] Only plotting TimesNet-Gen vs Real (no Recon/VAE)")
    original_argv = sys.argv
    try:
        import plot_combined_hvsr_all_sources as hvsr_plotter
        # The plotter parses sys.argv, so pass our arguments through it. Only
        # gen_dir/gen_ts_dir are provided; the other sources are disabled with
        # empty strings to prevent their auto-defaults.
        sys.argv = [
            'plot_combined_hvsr_all_sources.py',
            '--gen_dir', npz_output_dir,
            '--gen_ts_dir', npz_output_dir,
            '--out', hvsr_out,
            '--recon_dir', '',
            '--vae_dir', '',
            '--vae_gen_dir', '',
        ]
        hvsr_plotter.main()
        print(f"[INFO] ✅ HVSR analysis complete! Plots saved to: {hvsr_out}")
    except Exception as e:
        import traceback
        print(f"[WARN] Could not create HVSR plots: {e}")
        traceback.print_exc()
    finally:
        sys.argv = original_argv


def main():
    """Command-line entry point: validate inputs, load the model, generate, save and plot.

    Steps:

    1. Check the checkpoint and latent bank paths.
    2. Load the model.
    3. Optionally fine-tune it on real ``.mat`` data.
    4. For every requested station: decode samples from the latent bank, save
       them as NPZ with HVSR/f0 summaries, and write preview plots.
    5. Run the external HVSR plotting script on the results.
    """
    parser = argparse.ArgumentParser(description='Generate seismic samples (simplified version)')
    parser.add_argument('--checkpoint', type=str,
                        default=r'D:\Baris\codes\Time-Series-Library-main\checkpoints\timesnet_pointcloud_phase1_final.pth',
                        help='Path to pre-trained model checkpoint')
    parser.add_argument('--latent_bank', type=str,
                        default=r'D:\Baris\codes\Time-Series-Library-main\checkpoints\latent_bank_phase1.npz',
                        help='Path to latent bank NPZ file')
    parser.add_argument('--num_samples', type=int, default=50,
                        help='Number of samples to generate per station')
    parser.add_argument('--output_dir', type=str, default='./generated_samples',
                        help='Output directory')
    parser.add_argument('--num_preview', type=int, default=2,
                        help='Number of preview plots per station')
    parser.add_argument('--stations', type=str, nargs='+', default=['0205', '1716', '2020', '3130', '4628'],
                        help='Target station IDs')
    parser.add_argument('--data_root', type=str, default=r"D:\Baris\5stats/",
                        help='Root path to seismic data (only needed if --fine_tune is used)')
    parser.add_argument('--fine_tune', action='store_true',
                        help='Fine-tune the model before generation (use with Phase 0 checkpoint)')
    parser.add_argument('--fine_tune_epochs', type=int, default=10,
                        help='Number of fine-tuning epochs')
    parser.add_argument('--fine_tune_lr', type=float, default=1e-4,
                        help='Learning rate for fine-tuning')
    args_cli = parser.parse_args()

    if not os.path.exists(args_cli.checkpoint):
        print(f"\n{'='*80}")
        print("❌ ERROR: Checkpoint not found!")
        print(f"{'='*80}")
        print(f"\nLooking for: {args_cli.checkpoint}")
        return

    args = SimpleArgs()
    print("="*80)
    print("TimesNet-Gen Sample Generation (Simplified)")
    print("="*80)
    print(f"Checkpoint: {args_cli.checkpoint}")
    print(f"Target stations: {args_cli.stations}")
    print(f"Samples per station: {args_cli.num_samples}")
    print(f"Output directory: {args_cli.output_dir}")
    print("="*80)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Validate the latent bank before the (slow) model load so a bad path fails fast.
    if not os.path.exists(args_cli.latent_bank):
        print("\n❌ ERROR: Latent bank not found!")
        print(f"Looking for: {args_cli.latent_bank}")
        print("\nPlease run untitled1_gen.py first to generate the latent bank.")
        return

    model = load_model(args_cli.checkpoint, args)

    # encoder_std from Phase 0 is only used for noise injection during fine-tuning.
    encoder_std_path = './pcgen_stats/encoder_feature_std.npy'
    encoder_std = None
    if os.path.exists(encoder_std_path):
        encoder_std = np.load(encoder_std_path)
        print(f"[INFO] Loaded encoder_std from {encoder_std_path} (shape: {encoder_std.shape})")
        print("[INFO] encoder_std loaded (used only for fine-tuning, NOT for generation)")
    else:
        print("[INFO] No encoder_std found (not needed for generation, only for fine-tuning)")

    print(f"[INFO] Using latent bank: {args_cli.latent_bank}")

    if args_cli.fine_tune:
        print("\n[INFO] Fine-tuning enabled! Loading real data...")
        all_station_files = _find_station_files(args_cli.data_root, args_cli.stations)
        if not all_station_files:
            print(f"\n❌ ERROR: No data files found in {args_cli.data_root}")
            return
        # NOTE(review): the latent bank is not recomputed after fine-tuning, so
        # the fine-tuned decoder is fed latents from the earlier encoder; confirm
        # this is intended.
        model = fine_tune_model(model, all_station_files, args, encoder_std,
                                epochs=args_cli.fine_tune_epochs,
                                lr=args_cli.fine_tune_lr)

    npz_output_dir = os.path.join(args_cli.output_dir, 'generated_timeseries_npz')
    plot_output_dir = os.path.join(args_cli.output_dir, 'preview_plots')

    print("\n[INFO] Generating samples from latent bank...")
    for station_id in args_cli.stations:
        print(f"\n{'='*60}")
        print(f"Processing Station: {station_id}")
        print(f"{'='*60}")
        generated_signals, real_names = generate_samples_from_latent_bank(
            model, args_cli.latent_bank, station_id, args_cli.num_samples, args, encoder_std
        )
        if generated_signals is not None:
            save_generated_samples(generated_signals, real_names, station_id, npz_output_dir)
            plot_sample_preview(generated_signals, station_id, plot_output_dir, args_cli.num_preview)

    print("\n" + "="*80)
    print("Generation Complete!")
    print("="*80)
    print(f"Generated samples saved to: {npz_output_dir}")
    print(f"Preview plots saved to: {plot_output_dir}")

    _report_generated_counts(npz_output_dir, args_cli.stations)
    _run_hvsr_analysis(npz_output_dir, args_cli.output_dir)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()