Text-Conditional QuickDraw Diffusion Model

A text-conditional diffusion model for generating Google QuickDraw-style sketches from text prompts. This model uses DDPM (Denoising Diffusion Probabilistic Models) with CLIP text encoding and classifier-free guidance to generate 64x64 grayscale sketches.

Model Description

This is a U-Net based diffusion model that generates sketches conditioned on text prompts. It uses:

  • CLIP text encoder (openai/clip-vit-base-patch32) for text conditioning
  • DDPM for the diffusion process (1000 timesteps)
  • Classifier-free guidance for improved text-image alignment
  • The Google QuickDraw dataset as training data

Model Details

  • Model Type: Text-conditional DDPM diffusion model
  • Architecture: U-Net with cross-attention for text conditioning
  • Image Size: 64x64 grayscale
  • Base Channels: 256
  • Text Encoder: CLIP ViT-B/32 (frozen)
  • Training Duration: 100 epochs
  • Diffusion Timesteps: 1000
  • Guidance Scale: 5.0 (default)

Training Configuration

  • Dataset: Xenova/quickdraw-small (5 classes)
  • Batch Size: 128 (32 per GPU × 4 GPUs)
  • Learning Rate: 1e-4
  • CFG Drop Probability: 0.15
  • Optimizer: Adam (see the training-step sketch after this list)
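
The training script itself is not included in this card. The following is a minimal sketch of one training step under these settings, reusing the model, text encoder, and scheduler objects from the usage example below. The model/encoder call signatures, the zero embedding as the "null" prompt, and the scheduler's add_noise helper are assumptions rather than the repository's exact API.

import torch
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
cfg_drop_prob = 0.15  # probability of dropping the text conditioning

def training_step(images, prompts):
    # images: (B, 1, 64, 64) sketches, prompts: list of B strings
    with torch.no_grad():
        text_emb = text_encoder(prompts)  # (B, 512) CLIP embeddings

    # Classifier-free guidance: randomly replace embeddings with zeros (assumed null conditioning)
    drop = torch.rand(text_emb.shape[0], device=text_emb.device) < cfg_drop_prob
    text_emb = torch.where(drop[:, None], torch.zeros_like(text_emb), text_emb)

    # Standard DDPM objective: noise the image at a random timestep, predict the noise
    t = torch.randint(0, 1000, (images.shape[0],), device=images.device)
    noise = torch.randn_like(images)
    noisy = scheduler.add_noise(images, noise, t)  # assumed scheduler helper

    loss = F.mse_loss(model(noisy, t, text_emb), noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()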

Usage

Installation

pip install torch torchvision transformers diffusers datasets matplotlib pillow tqdm

Generate Images

import torch
from model import TextConditionedUNet
from scheduler import SimpleDDPMScheduler
from text_encoder import CLIPTextEncoder
from generate import generate_samples

# Load checkpoint
checkpoint_path = "text_diffusion_final_epoch_100.pt"
checkpoint = torch.load(checkpoint_path)

# Initialize model
model = TextConditionedUNet(text_dim=512).cuda()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Initialize text encoder
text_encoder = CLIPTextEncoder(model_name="openai/clip-vit-base-patch32", freeze=True).cuda()
text_encoder.eval()

# Generate samples
scheduler = SimpleDDPMScheduler(1000)
prompt = "a drawing of a cat"
num_samples = 4
guidance_scale = 5.0

with torch.no_grad():
    text_embedding = text_encoder(prompt)
    text_embeddings = text_embedding.repeat(num_samples, 1)

    shape = (num_samples, 1, 64, 64)
    samples = scheduler.sample_text(model, shape, text_embeddings, 'cuda', guidance_scale)
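
The returned samples tensor has shape (num_samples, 1, 64, 64). Assuming values in [-1, 1] (a common DDPM convention, not stated in this card), a minimal sketch for saving the results as PNGs:

from PIL import Image

# Assumes samples lie in [-1, 1]; rescale to 8-bit grayscale and write to disk.
imgs = ((samples.clamp(-1, 1) + 1) / 2 * 255).to(torch.uint8).cpu()
for i, img in enumerate(imgs):
    Image.fromarray(img.squeeze(0).numpy(), mode="L").save(f"sample_{i}.png")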

Command Line Usage

# Generate samples
python generate.py --checkpoint text_diffusion_final_epoch_100.pt \
                   --prompt "a drawing of a fire truck" \
                   --num-samples 4 \
                   --guidance-scale 5.0

# Visualize denoising process
python visualize_generation.py --checkpoint text_diffusion_final_epoch_100.pt \
                               --prompt "a drawing of a cat" \
                               --num-steps 10

Example Prompts

Try these prompts for best results:

  • "a drawing of a cat"
  • "a drawing of a fire truck"
  • "a drawing of an airplane"
  • "a drawing of a house"
  • "a drawing of a tree"

Note: The model is trained on a limited set of QuickDraw classes, so it works best with simple object descriptions in the format "a drawing of a [object]".

Classifier-Free Guidance

The model supports classifier-free guidance to improve text-image alignment:

  • guidance_scale = 1.0: No guidance (pure conditional generation)
  • guidance_scale = 3.0-7.0: Recommended range (default: 5.0)
  • Higher values: Stronger adherence to the text prompt (may reduce diversity); see the sketch below for how the guided prediction is formed
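
scheduler.sample_text applies guidance internally; conceptually, each denoising step blends a conditional and an unconditional noise prediction. A minimal sketch of that combination, where the zero embedding standing in for the null prompt and the model(x, t, emb) signature are assumptions:

# One guided noise estimate at timestep t (conceptual sketch)
eps_cond = model(x_t, t, text_embeddings)                      # prompt-conditioned prediction
eps_uncond = model(x_t, t, torch.zeros_like(text_embeddings))  # unconditional prediction
eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
# guidance_scale = 1.0 reduces to eps_cond, i.e. pure conditional generation.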

Model Architecture

U-Net Structure

Input: (batch, 1, 64, 64)
├── Down Block 1: 1 → 256 channels
├── Down Block 2: 256 → 512 channels
├── Down Block 3: 512 → 512 channels
├── Middle Block: 512 channels
├── Up Block 3: 1024 → 512 channels (with skip connections)
├── Up Block 2: 768 → 256 channels (with skip connections)
└── Up Block 1: 512 → 1 channel (with skip connections)
Output: (batch, 1, 64, 64) - predicted noise
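
Assuming a forward signature of model(noisy_image, timestep, text_embedding) (not stated explicitly in this card), a quick shape check of the input/output contract:

x = torch.randn(4, 1, 64, 64).cuda()       # batch of noisy sketches
t = torch.randint(0, 1000, (4,)).cuda()    # diffusion timesteps
emb = torch.randn(4, 512).cuda()           # CLIP text embeddings
noise_pred = model(x, t, emb)
assert noise_pred.shape == (4, 1, 64, 64)  # noise predicted at the input resolution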

Text Conditioning

  • Text prompts encoded via CLIP ViT-B/32
  • 512-dimensional text embeddings
  • Injected into the U-Net via cross-attention (see the sketch after this list)
  • Classifier-free guidance with 15% dropout during training
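
With a single 512-dimensional CLIP embedding there is only one key/value token, so cross-attention effectively injects a learned projection of the text embedding into each spatial location. A hypothetical sketch of such a conditioning layer (the repository's actual block may differ):

import torch
import torch.nn as nn

class TextCrossAttention(nn.Module):
    """Hypothetical sketch: U-Net features attend to the CLIP text embedding."""
    def __init__(self, channels, text_dim=512):
        super().__init__()
        self.to_q = nn.Linear(channels, channels)
        self.to_k = nn.Linear(text_dim, channels)
        self.to_v = nn.Linear(text_dim, channels)

    def forward(self, feats, text_emb):
        # feats: (B, C, H, W) feature map, text_emb: (B, 512)
        b, c, h, w = feats.shape
        q = self.to_q(feats.flatten(2).transpose(1, 2))  # (B, H*W, C)
        k = self.to_k(text_emb).unsqueeze(1)             # (B, 1, C)
        v = self.to_v(text_emb).unsqueeze(1)             # (B, 1, C)
        # With one text token the softmax weights are trivially 1, so this
        # reduces to adding a projected text feature at every position.
        attn = torch.softmax(q @ k.transpose(1, 2) / c ** 0.5, dim=-1)  # (B, H*W, 1)
        out = (attn @ v).transpose(1, 2).reshape(b, c, h, w)
        return feats + out  # residual text conditioning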

Training Details

  • Framework: PyTorch 2.0+
  • Hardware: 4x NVIDIA GPUs
  • Training Duration: ~100 epochs
  • Dataset: Google QuickDraw sketches (5 classes)
  • Noise Schedule: Linear (β from 0.0001 to 0.02; see the sketch after this list)
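
A minimal sketch of the quantities a DDPM scheduler derives from this linear schedule (variable names are illustrative, not the repository's):

import torch

betas = torch.linspace(1e-4, 0.02, 1000)        # linear beta schedule over 1000 steps
alphas_cumprod = torch.cumprod(1.0 - betas, 0)  # alpha-bar_t

# Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
def add_noise(x0, noise, t):
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a.sqrt() * x0 + (1.0 - a).sqrt() * noise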

Limitations

  • Limited to 64x64 resolution
  • Grayscale output only
  • Best performance on simple objects from QuickDraw classes
  • May not generalize well to complex or out-of-distribution prompts

Citation

If you use this model, please cite:

@misc{quickdraw-text-diffusion,
  title={Text-Conditional QuickDraw Diffusion Model},
  author={Your Name},
  year={2024},
  howpublished={\url{https://huggingface.co/jamesaasher/quickdraw-text-diffusion}}
}

License

MIT License

Acknowledgments

  • Google QuickDraw dataset
  • OpenAI CLIP
  • DDPM paper: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
  • Classifier-free guidance: "Classifier-Free Diffusion Guidance" (Ho & Salimans, 2022)