File size: 5,657 Bytes
8abfb97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import torchvision
from torchvision.utils import save_image, make_grid
import os
from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from sample import frequency_aware_sample
import numpy as np

def _find_latest_checkpoint(logs_root='./logs'):
    """Return the path of the highest-epoch checkpoint in the newest run directory.

    The newest run directory is the last sub-directory of ``logs_root`` in
    sorted (lexicographic) name order. Inside it, files named
    ``model_epoch_<N>.pth`` are ranked by the integer ``<N>``.

    Args:
        logs_root: Directory that contains one sub-directory per run.

    Returns:
        The checkpoint path as a string, or ``None`` (after printing the
        reason) when no run directory or no usable checkpoint exists.
    """
    run_dirs = []
    # isdir (rather than exists) so a stray regular file called "logs" is
    # treated as "nothing to load" instead of crashing os.listdir.
    if os.path.isdir(logs_root):
        run_dirs = [
            entry for entry in os.listdir(logs_root)
            if os.path.isdir(os.path.join(logs_root, entry))
        ]

    if not run_dirs:
        print("No log directories found!")
        return None

    # NOTE(review): "latest" means last by name, which presumably relies on
    # timestamped or zero-padded run names - confirm against the trainer.
    run_path = os.path.join(logs_root, sorted(run_dirs)[-1])

    candidates = []
    for filename in os.listdir(run_path):
        if not (filename.startswith('model_epoch_') and filename.endswith('.pth')):
            continue
        try:
            epoch = int(filename.split('_')[2].split('.')[0])
        except ValueError:
            # Non-numeric suffix (e.g. "model_epoch_best.pth"): it cannot be
            # ranked by epoch, so skip it instead of aborting the whole run.
            continue
        candidates.append((epoch, filename))

    if not candidates:
        print("No checkpoint files found!")
        return None

    # Tuples sort by epoch first, so max() picks the highest epoch number.
    _, latest_file = max(candidates)
    return os.path.join(run_path, latest_file)


def _print_range(label, tensor):
    """Print the min/max of ``tensor`` under ``label``."""
    print(f"  {label} range: [{tensor.min().item():.3f}, {tensor.max().item():.3f}]")


def _print_mean_std(label, tensor):
    """Print the mean and standard deviation of ``tensor`` under ``label``."""
    print(f"  {label} mean/std: {tensor.mean().item():.3f} / {tensor.std().item():.3f}")


def debug_model_predictions():
    """Load the latest checkpoint and print diagnostics about its predictions.

    Steps, in order:
      1. Locate the newest checkpoint under ``./logs`` and load the model.
      2. Feed one random input through the model at several timesteps and
         report the statistics of the predicted noise.
      3. Noise a random "clean" image with the scheduler, predict the noise
         and report the MSE against the scheduler's target.
      4. Run ``frequency_aware_sample`` and save the grid to
         ``debug_samples.png``.

    Returns:
        None. All results are printed; the function returns early when no
        run directory or checkpoint can be found.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    checkpoint_path = _find_latest_checkpoint()
    if checkpoint_path is None:
        return

    print(f"Loading {os.path.basename(checkpoint_path)}")

    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints you trust. Newer PyTorch releases may default to
    # weights_only=True and reject a pickled config object - confirm against
    # the installed version.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = checkpoint.get('config', Config())

    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    # A checkpoint is either a dict wrapping the weights or the bare state dict.
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)

    model.eval()

    print("\n=== DEBUGGING MODEL PREDICTIONS ===")

    with torch.no_grad():
        # One random input, reused for every timestep so only t varies.
        x_test = torch.randn(1, 3, 64, 64, device=device)

        # NOTE(review): assumes the model accepts timesteps up to 499 and
        # 64x64 RGB inputs - TODO confirm against Config.
        timesteps_to_test = [0, 50, 100, 250, 499]

        for t_val in timesteps_to_test:
            t_tensor = torch.full((1,), t_val, device=device, dtype=torch.long)

            pred_noise = model(x_test, t_tensor)

            print(f"\nTimestep {t_val}:")
            _print_range("Input", x_test)
            _print_mean_std("Input", x_test)
            _print_range("Predicted noise", pred_noise)
            _print_mean_std("Predicted noise", pred_noise)

            # NaN is checked first because every comparison against a NaN
            # std would be False and fall through to the "reasonable" branch.
            pred_std = pred_noise.std().item()
            if torch.isnan(pred_noise).any():
                print("  ❌ NaN detected in predictions!")
            elif pred_std < 0.01:
                print("  ⚠️  Very low variance - model might be collapsed")
            elif pred_std > 10:
                print("  ⚠️  Very high variance - model might be unstable")
            else:
                print("  ✓ Prediction variance looks reasonable")

    print("\n=== TESTING TRAINING DATA SIMULATION ===")

    with torch.no_grad():
        # Random stand-in for a clean image, scaled down to a narrower range.
        x0 = torch.randn(1, 3, 64, 64, device=device) * 0.5
        # Single timestep drawn from [100, 400).
        t = torch.randint(100, 400, (1,), device=device)

        xt, noise_target = noise_scheduler.apply_noise(x0, t)

        pred_noise = model(xt, t)

        print("\nTraining simulation:")
        _print_range("Clean image", x0)
        _print_range("Noisy image", xt)
        _print_range("Target noise", noise_target)
        _print_mean_std("Target noise", noise_target)
        _print_range("Predicted noise", pred_noise)
        _print_mean_std("Predicted noise", pred_noise)

        mse_value = torch.mean((pred_noise - noise_target) ** 2).item()
        print(f"  MSE between prediction and target: {mse_value:.6f}")

        if mse_value > 1.0:
            print("  ⚠️  High MSE suggests poor training")
        elif mse_value < 0.001:
            print("  ✓ Very low MSE - model learned well")
        else:
            print("  ✓ Reasonable MSE")

    print("\n=== ATTEMPTING CORRECTED SAMPLING ===")

    # Sampling is best-effort: a failure is reported with its traceback but
    # must not hide the diagnostics already printed above.
    try:
        samples, grid = frequency_aware_sample(model, noise_scheduler, device, n_samples=4)
        save_image(grid, "debug_samples.png", normalize=False)
        print("Samples saved to debug_samples.png")

        print("Sample statistics:")
        print(f"  Range: [{samples.min().item():.3f}, {samples.max().item():.3f}]")
        print(f"  Mean: {samples.mean().item():.3f}")
        print(f"  Std: {samples.std().item():.3f}")

    except Exception as e:
        print(f"Sampling failed: {e}")
        import traceback
        traceback.print_exc()

# Run the diagnostics only when this file is executed as a script; importing
# the module defines debug_model_predictions without running it.
if __name__ == "__main__":
    debug_model_predictions()