# Text-Conditional QuickDraw Diffusion Model
A text-conditional diffusion model for generating Google QuickDraw-style sketches from text prompts. This model uses DDPM (Denoising Diffusion Probabilistic Models) with CLIP text encoding and classifier-free guidance to generate 64x64 grayscale sketches.
## Model Description
This is a U-Net based diffusion model that generates sketches conditioned on text prompts. It uses:
- CLIP text encoder (`openai/clip-vit-base-patch32`) for text conditioning
- DDPM for the diffusion process (1000 timesteps)
- Classifier-free guidance for improved text-image alignment
- Trained on Google QuickDraw dataset
## Model Details
- Model Type: Text-conditional DDPM diffusion model
- Architecture: U-Net with cross-attention for text conditioning
- Image Size: 64x64 grayscale
- Base Channels: 256
- Text Encoder: CLIP ViT-B/32 (frozen)
- Training Duration: 100 epochs
- Diffusion Timesteps: 1000
- Guidance Scale: 5.0 (default)
## Training Configuration
- Dataset: Xenova/quickdraw-small (5 classes)
- Batch Size: 128 (32 per GPU × 4 GPUs)
- Learning Rate: 1e-4
- CFG Drop Probability: 0.15
- Optimizer: Adam
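A minimal sketch of a single training step under this configuration. It is an illustration, not the repo's training script: `scheduler.add_noise` and the model call signature are assumptions, and CFG dropout is implemented here by zeroing the text embedding for a random 15% of the batch.

```python
import torch
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def train_step(images, prompts):
    # images: (B, 1, 64, 64); the CLIP text encoder stays frozen
    text_emb = text_encoder(prompts)                       # (B, 512)

    # CFG dropout: replace conditioning with a null (zero) embedding ~15% of the time
    drop = torch.rand(text_emb.size(0), device=text_emb.device) < 0.15
    text_emb = torch.where(drop[:, None], torch.zeros_like(text_emb), text_emb)

    # Sample a random timestep per image and noise the batch
    t = torch.randint(0, 1000, (images.size(0),), device=images.device)
    noise = torch.randn_like(images)
    noisy = scheduler.add_noise(images, noise, t)          # assumed scheduler helper

    # Standard DDPM objective: predict the added noise
    loss = F.mse_loss(model(noisy, t, text_emb), noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```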
## Usage
### Installation
```bash
pip install torch torchvision transformers diffusers datasets matplotlib pillow tqdm
```
### Generate Images
```python
import torch
from model import TextConditionedUNet
from scheduler import SimpleDDPMScheduler
from text_encoder import CLIPTextEncoder
from generate import generate_samples

# Load checkpoint
checkpoint_path = "text_diffusion_final_epoch_100.pt"
checkpoint = torch.load(checkpoint_path)

# Initialize model
model = TextConditionedUNet(text_dim=512).cuda()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Initialize text encoder
text_encoder = CLIPTextEncoder(model_name="openai/clip-vit-base-patch32", freeze=True).cuda()
text_encoder.eval()

# Generate samples
scheduler = SimpleDDPMScheduler(1000)
prompt = "a drawing of a cat"
num_samples = 4
guidance_scale = 5.0

with torch.no_grad():
    text_embedding = text_encoder(prompt)
    text_embeddings = text_embedding.repeat(num_samples, 1)
    shape = (num_samples, 1, 64, 64)
    samples = scheduler.sample_text(model, shape, text_embeddings, 'cuda', guidance_scale)
```
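To save the result as an image grid, something like the following should work, assuming `samples` comes back in [-1, 1]:

```python
from torchvision.utils import save_image

# Rescale from [-1, 1] to [0, 1] before saving (adjust if your scheduler returns [0, 1])
images = (samples.clamp(-1, 1) + 1) / 2
save_image(images, "samples.png", nrow=2)
```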
### Command Line Usage
```bash
# Generate samples
python generate.py --checkpoint text_diffusion_final_epoch_100.pt \
    --prompt "a drawing of a fire truck" \
    --num-samples 4 \
    --guidance-scale 5.0

# Visualize denoising process
python visualize_generation.py --checkpoint text_diffusion_final_epoch_100.pt \
    --prompt "a drawing of a cat" \
    --num-steps 10
```
## Example Prompts
Try these prompts for best results:
- "a drawing of a cat"
- "a drawing of a fire truck"
- "a drawing of an airplane"
- "a drawing of a house"
- "a drawing of a tree"
Note: The model is trained on a limited set of QuickDraw classes, so it works best with simple object descriptions in the format "a drawing of a [object]".
## Classifier-Free Guidance
The model supports classifier-free guidance to improve text-image alignment:
- `guidance_scale = 1.0`: No guidance (pure conditional generation)
- `guidance_scale = 3.0`–`7.0`: Recommended range (default: 5.0)
- Higher values: Stronger adherence to the text prompt (may reduce diversity)
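Under the hood, classifier-free guidance runs the model twice per step and extrapolates from the unconditional prediction toward the conditional one. A minimal sketch of the standard combination (Ho & Salimans, 2022); the helper name and null-embedding convention are illustrative, not the exact code in `scheduler.py`:

```python
def guided_noise(model, x_t, t, text_emb, null_emb, guidance_scale):
    # eps_uncond: prediction with a null (e.g., zero) text embedding
    # eps_cond:   prediction with the actual prompt embedding
    eps_uncond = model(x_t, t, null_emb)
    eps_cond = model(x_t, t, text_emb)
    # guidance_scale = 1.0 recovers the pure conditional prediction
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```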
## Model Architecture
### U-Net Structure
```
Input: (batch, 1, 64, 64)
├── Down Block 1: 1 → 256 channels
├── Down Block 2: 256 → 512 channels
├── Down Block 3: 512 → 512 channels
├── Middle Block: 512 channels
├── Up Block 3: 1024 → 512 channels (with skip connections)
├── Up Block 2: 768 → 256 channels (with skip connections)
└── Up Block 1: 512 → 1 channel (with skip connections)
Output: (batch, 1, 64, 64) - predicted noise
```
### Text Conditioning
- Text prompts encoded via CLIP ViT-B/32
- 512-dimensional text embeddings
- Injected into U-Net via cross-attention
- Classifier-free guidance enabled by dropping the text conditioning with probability 0.15 during training
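A minimal sketch of cross-attention injection, where flattened image features attend to the text embedding as a single key/value token; the layer names and shapes are illustrative assumptions, not the exact blocks in `model.py`:

```python
import torch
import torch.nn as nn

class CrossAttentionBlock(nn.Module):
    """Illustrative cross-attention: image features attend to the text embedding."""
    def __init__(self, channels, text_dim=512, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.to_kv = nn.Linear(text_dim, channels)  # project text into feature space
        self.norm = nn.LayerNorm(channels)

    def forward(self, x, text_emb):
        b, c, h, w = x.shape
        q = x.flatten(2).transpose(1, 2)        # (b, h*w, c) image tokens as queries
        kv = self.to_kv(text_emb).unsqueeze(1)  # (b, 1, c) single text token
        out, _ = self.attn(self.norm(q), kv, kv)
        out = (q + out).transpose(1, 2).reshape(b, c, h, w)  # residual, back to feature map
        return out
```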
## Training Details
- Framework: PyTorch 2.0+
- Hardware: 4x NVIDIA GPUs
- Epochs: 100
- Dataset: Google QuickDraw sketches (5 classes)
- Noise Schedule: Linear (β from 0.0001 to 0.02)
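A minimal sketch of this linear schedule and the forward (noising) process it defines; variable names are illustrative:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)               # linear beta schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # a_bar_t, defines q(x_t | x_0)

def q_sample(x0, t, noise):
    # Forward process: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
```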
## Limitations
- Limited to 64x64 resolution
- Grayscale output only
- Best performance on simple objects from QuickDraw classes
- May not generalize well to complex or out-of-distribution prompts
## Citation
If you use this model, please cite:
```bibtex
@misc{quickdraw-text-diffusion,
  title={Text-Conditional QuickDraw Diffusion Model},
  author={Your Name},
  year={2024},
  howpublished={\url{https://huggingface.co/YOUR_USERNAME/quickdraw-text-diffusion}}
}
```
## License
MIT License
## Acknowledgments
- Google QuickDraw dataset
- OpenAI CLIP
- DDPM paper: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
- Classifier-free guidance: "Classifier-Free Diffusion Guidance" (Ho & Salimans, 2022)