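"""Debug script for a frequency-aware diffusion model.

Loads the latest checkpoint found under ./logs, inspects the noise
predictions of SmoothDiffusionUNet at several timesteps, simulates a single
training step with FrequencyAwareNoise, and finally attempts sampling via
frequency_aware_sample, saving the resulting grid to debug_samples.png.
"""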
import torch
import torchvision
from torchvision.utils import save_image, make_grid
import os
from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from sample import frequency_aware_sample
import numpy as np


def debug_model_predictions():
    """Debug what the model is actually predicting"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Find latest checkpoint
    log_dirs = []
    if os.path.exists('./logs'):
        for item in os.listdir('./logs'):
            if os.path.isdir(os.path.join('./logs', item)):
                log_dirs.append(item)

    if not log_dirs:
        print("No log directories found!")
        return

    latest_log = sorted(log_dirs)[-1]
    log_path = os.path.join('./logs', latest_log)

    checkpoint_files = []
    for file in os.listdir(log_path):
        if file.startswith('model_epoch_') and file.endswith('.pth'):
            epoch = int(file.split('_')[2].split('.')[0])
            checkpoint_files.append((epoch, file))

    if not checkpoint_files:
        print("No checkpoint files found!")
        return

    # Get latest checkpoint
    checkpoint_files.sort()
    latest_epoch, latest_file = checkpoint_files[-1]
    checkpoint_path = os.path.join(log_path, latest_file)
    print(f"Loading {latest_file}")

    # Load model
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = checkpoint.get('config', Config())
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.eval()

    print("\n=== DEBUGGING MODEL PREDICTIONS ===")
    with torch.no_grad():
        # Create a simple test input
        x_test = torch.randn(1, 3, 64, 64, device=device)

        # Test at different timesteps
        timesteps_to_test = [0, 50, 100, 250, 499]
        for t_val in timesteps_to_test:
            t_tensor = torch.full((1,), t_val, device=device, dtype=torch.long)

            # Get model prediction
            pred_noise = model(x_test, t_tensor)

            print(f"\nTimestep {t_val}:")
            print(f" Input range: [{x_test.min().item():.3f}, {x_test.max().item():.3f}]")
            print(f" Input mean/std: {x_test.mean().item():.3f} / {x_test.std().item():.3f}")
            print(f" Predicted noise range: [{pred_noise.min().item():.3f}, {pred_noise.max().item():.3f}]")
            print(f" Predicted noise mean/std: {pred_noise.mean().item():.3f} / {pred_noise.std().item():.3f}")

            # Check if prediction is reasonable
            if torch.isnan(pred_noise).any():
                print(" ❌ NaN detected in predictions!")
            elif pred_noise.std().item() < 0.01:
                print(" ⚠️ Very low variance - model might be collapsed")
            elif pred_noise.std().item() > 10:
                print(" ⚠️ Very high variance - model might be unstable")
            else:
                print(" ✓ Prediction variance looks reasonable")

    print("\n=== TESTING TRAINING DATA SIMULATION ===")
    # Simulate what happens during training
    with torch.no_grad():
        # Create clean image
        x0 = torch.randn(1, 3, 64, 64, device=device) * 0.5  # More reasonable range
        t = torch.randint(100, 400, (1,), device=device)  # Mid-range timestep

        # Apply noise like in training
        xt, noise_target = noise_scheduler.apply_noise(x0, t)

        # Get model prediction
        pred_noise = model(xt, t)

        print("\nTraining simulation:")
        print(f" Clean image range: [{x0.min().item():.3f}, {x0.max().item():.3f}]")
        print(f" Noisy image range: [{xt.min().item():.3f}, {xt.max().item():.3f}]")
        print(f" Target noise range: [{noise_target.min().item():.3f}, {noise_target.max().item():.3f}]")
        print(f" Target noise mean/std: {noise_target.mean().item():.3f} / {noise_target.std().item():.3f}")
        print(f" Predicted noise range: [{pred_noise.min().item():.3f}, {pred_noise.max().item():.3f}]")
        print(f" Predicted noise mean/std: {pred_noise.mean().item():.3f} / {pred_noise.std().item():.3f}")

        # Calculate MSE
        mse = torch.mean((pred_noise - noise_target) ** 2)
        print(f" MSE between prediction and target: {mse.item():.6f}")

        if mse.item() > 1.0:
            print(" ⚠️ High MSE suggests poor training")
        elif mse.item() < 0.001:
            print(" ✓ Very low MSE - model learned well")
        else:
            print(" ✓ Reasonable MSE")

    print("\n=== ATTEMPTING CORRECTED SAMPLING ===")
    # Try different sampling approaches
    try:
        samples, grid = frequency_aware_sample(model, noise_scheduler, device, n_samples=4)
        save_image(grid, "debug_samples.png", normalize=False)
        print("Samples saved to debug_samples.png")
        print("Sample statistics:")
        print(f" Range: [{samples.min().item():.3f}, {samples.max().item():.3f}]")
        print(f" Mean: {samples.mean().item():.3f}")
        print(f" Std: {samples.std().item():.3f}")
    except Exception as e:
        print(f"Sampling failed: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    debug_model_predictions()