File size: 2,984 Bytes
8abfb97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torchvision
from torchvision.utils import save_image
import os
from config import Config

def simple_sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    """Generate images with standard DDPM ancestral sampling.

    Starts from Gaussian noise and runs the reverse diffusion chain for
    ``config.T`` steps (t = T-1 down to 0), using ``model(x, t)`` as the
    noise predictor. The model's train/eval mode is restored on exit, so
    this is safe to call in the middle of a training loop.

    Args:
        model: Noise-prediction module called as ``model(x, t_tensor)``;
            presumably returns a tensor shaped like ``x`` (TODO confirm).
        noise_scheduler: Object exposing ``alphas``, ``alpha_bars`` and
            ``betas`` tensors indexed by integer timestep.
        device: Device on which sampling runs.
        epoch: Optional epoch number. When given, the grid is saved to
            ``samples/epoch_{epoch}.png``; it is also passed as the
            ``global_step`` when logging to ``writer``.
        writer: Optional summary writer; if provided, the grid is logged
            under the ``'Samples'`` tag.
        n_samples: Number of images to generate.

    Returns:
        ``(x, grid)``: the samples clamped to ``[-1, 1]`` and the image grid
        with ``[-1, 1]`` mapped onto ``[0, 1]``.
    """
    config = Config()
    # Remember the caller's mode so sampling mid-training does not leave
    # dropout / batch-norm layers stuck in eval mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Start from pure Gaussian noise x_T (3-channel square images).
            x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device)

            print(f"Starting reverse diffusion for {n_samples} samples...")

            # Move the schedule tensors to the sampling device once, outside the loop.
            alphas = noise_scheduler.alphas.to(device)
            alpha_bars = noise_scheduler.alpha_bars.to(device)
            betas = noise_scheduler.betas.to(device)

            # Reverse diffusion: t = T-1, T-2, ..., 0.
            for step, t in enumerate(reversed(range(config.T))):
                if step % 100 == 0:
                    print(f"Step {step}/{config.T}, t={t}")

                t_tensor = torch.full((n_samples,), t, device=device, dtype=torch.long)

                # Predict the noise component of x_t.
                pred_noise = model(x, t_tensor)

                alpha_t = alphas[t]
                alpha_bar_t = alpha_bars[t]
                beta_t = betas[t]

                # x0 estimate implied by the predicted noise (inverse of the
                # closed-form forward process x_t = sqrt(a_bar)*x0 + sqrt(1-a_bar)*eps).
                pred_x0 = (x - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t)

                if t > 0:
                    alpha_bar_prev = alpha_bars[t - 1]

                    # Mean of the Gaussian posterior q(x_{t-1} | x_t, x0).
                    mean = (torch.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)) * pred_x0 + \
                           (torch.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)) * x

                    # Posterior variance (beta-tilde_t); sample x_{t-1}.
                    noise = torch.randn_like(x)
                    variance = (1 - alpha_bar_prev) / (1 - alpha_bar_t) * beta_t
                    x = mean + torch.sqrt(variance) * noise
                else:
                    # Final step is deterministic: the output is the x0 estimate.
                    x = pred_x0

            # Clamp to the valid data range.
            x = torch.clamp(x, -1, 1)

            # Debug: print sample statistics every 10 epochs.
            if epoch is not None and epoch % 10 == 0:
                print(f"Sample stats at epoch {epoch}: range [{x.min().item():.3f}, {x.max().item():.3f}], mean {x.mean().item():.3f}")

            # Map the known [-1, 1] range onto [0, 1]. Plain normalize=True would
            # rescale by this batch's own min/max and distort brightness/contrast.
            grid = torchvision.utils.make_grid(x, nrow=2, normalize=True, value_range=(-1, 1))

            if writer is not None:
                writer.add_image('Samples', grid, epoch)

            if epoch is not None:
                os.makedirs("samples", exist_ok=True)
                save_image(grid, f"samples/epoch_{epoch}.png")

            return x, grid
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

# Use the simple sampler
def sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    """Public sampling entry point; delegates to ``simple_sample``.

    Every argument is forwarded unchanged, and the ``(x, grid)`` tuple
    produced by ``simple_sample`` is handed straight back to the caller.
    """
    return simple_sample(
        model,
        noise_scheduler,
        device,
        epoch=epoch,
        writer=writer,
        n_samples=n_samples,
    )