| | from ml_collections import config_dict |
| | import yaml |
| | from diffusers.schedulers import ( |
| | DDIMScheduler, |
| | EulerAncestralDiscreteScheduler, |
| | EulerDiscreteScheduler, |
| | DDPMScheduler, |
| | ) |
| | from inversion_utils import ( |
| | deterministic_ddim_step, |
| | deterministic_ddpm_step, |
| | deterministic_euler_step, |
| | deterministic_non_ancestral_euler_step, |
| | ) |
| |
|
# Candidate values for ``cfg.breakdown``; the built-in default config uses index 1.
BREAKDOWNS = [
    "x_t_c_hat",
    "x_t_hat_c",
    "no_breakdown",
    "x_t_hat_c_with_zeros",
]

# Names accepted for ``cfg.scheduler_type``. ``get_config`` maps each one to a
# diffusers scheduler class and a deterministic step function; the built-in
# default config uses index 0 ("ddpm").
SCHEDULERS = [
    "ddpm",
    "ddim",
    "euler",
    "euler_non_ancestral",
]

# Model identifiers selectable via ``cfg.model`` (these look like Hugging Face
# Hub repo ids — confirm against the loader). The built-in default uses index 1.
MODELS = [
    "stabilityai/sdxl-turbo",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "CompVis/stable-diffusion-v1-4",
]
| |
|
def get_num_steps_actual(cfg):
    """Return the number of steps implied by ``cfg``.

    The base count is ``cfg.num_steps_inversion - cfg.step_start`` when
    ``cfg.timesteps`` is None, otherwise ``len(cfg.timesteps)``. One extra
    step is added when ``cfg.clean_step_timestep`` is positive.

    Args:
        cfg: Any object exposing ``timesteps`` and ``clean_step_timestep``
            (plus ``num_steps_inversion`` and ``step_start`` when
            ``timesteps`` is None), e.g. a ``ConfigDict``.

    Returns:
        The step count as an int.
    """
    if cfg.timesteps is None:
        base_steps = cfg.num_steps_inversion - cfg.step_start
    else:
        base_steps = len(cfg.timesteps)
    # bool -> int: contributes 1 only when a clean step is configured.
    return base_steps + int(cfg.clean_step_timestep > 0)
| |
|
| |
|
def _default_config():
    """Return the built-in configuration used when no YAML config file is given.

    Per-step lists (``ws1``, ``ws2``, ``max_norm_zs``) are sized with
    ``get_num_steps_actual`` so they agree with the step settings chosen here.
    """
    cfg = config_dict.ConfigDict()

    cfg.seed = 2
    cfg.self_r = 0.5
    cfg.cross_r = 0.9
    cfg.eta = 1
    cfg.scheduler_type = SCHEDULERS[0]

    cfg.num_steps_inversion = 50
    cfg.step_start = 20
    cfg.timesteps = None
    cfg.noise_timesteps = None
    # get_num_steps_actual() reads clean_step_timestep, so it must be assigned
    # before the call below; ConfigDict raises AttributeError for missing keys.
    cfg.clean_step_timestep = 0
    num_steps_actual = get_num_steps_actual(cfg)

    cfg.ws1 = [2] * num_steps_actual
    cfg.ws2 = [1] * num_steps_actual
    cfg.real_cfg_scale = 0
    cfg.real_cfg_scale_save = 0
    cfg.breakdown = BREAKDOWNS[1]
    cfg.noise_shift_delta = 1
    cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]

    cfg.model = MODELS[1]
    return cfg


def get_config(args):
    """Build the run configuration from a YAML file or from built-in defaults.

    Args:
        args: Parsed command-line arguments. ``args.config_from_file`` selects
            the source: a non-empty path is parsed with ``yaml.safe_load``;
            otherwise the built-in defaults are used.

    Returns:
        A ``config_dict.ConfigDict`` augmented with ``scheduler_class`` and
        ``step_function`` (chosen from ``scheduler_type``), per-step lists
        ``ws1`` / ``ws2`` / ``max_norm_zs`` of length
        ``get_num_steps_actual(cfg)``, and ``update_eta`` / ``save_timesteps`` /
        ``scheduler_timesteps`` defaulted when absent.

    Raises:
        ValueError: if ``scheduler_type`` is not a supported scheduler name.
        AssertionError: if ``timesteps`` is set for a non-DDPM scheduler, or a
            per-step list has the wrong length.
    """
    if args.config_from_file:
        with open(args.config_from_file, "r", encoding="utf-8") as f:
            cfg = config_dict.ConfigDict(yaml.safe_load(f))
    else:
        cfg = _default_config()

    num_steps_actual = get_num_steps_actual(cfg)

    # Pair each scheduler name with its diffusers class and deterministic step.
    if cfg.scheduler_type == "ddim":
        cfg.scheduler_class = DDIMScheduler
        cfg.step_function = deterministic_ddim_step
    elif cfg.scheduler_type == "ddpm":
        cfg.scheduler_class = DDPMScheduler
        cfg.step_function = deterministic_ddpm_step
    elif cfg.scheduler_type == "euler":
        cfg.scheduler_class = EulerAncestralDiscreteScheduler
        cfg.step_function = deterministic_euler_step
    elif cfg.scheduler_type == "euler_non_ancestral":
        cfg.scheduler_class = EulerDiscreteScheduler
        cfg.step_function = deterministic_non_ancestral_euler_step
    else:
        raise ValueError(f"Unknown scheduler type: {cfg.scheduler_type}")

    # Broadcast scalar settings to one value per step. ignore_type() is needed
    # because this changes a field's type from number to list.
    with cfg.ignore_type():
        if isinstance(cfg.max_norm_zs, (int, float)):
            cfg.max_norm_zs = [cfg.max_norm_zs] * num_steps_actual
        if isinstance(cfg.ws1, (int, float)):
            cfg.ws1 = [cfg.ws1] * num_steps_actual
        if isinstance(cfg.ws2, (int, float)):
            cfg.ws2 = [cfg.ws2] * num_steps_actual

    # Optional fields that a config file may omit.
    if not hasattr(cfg, "update_eta"):
        cfg.update_eta = False
    if not hasattr(cfg, "save_timesteps"):
        cfg.save_timesteps = None
    if not hasattr(cfg, "scheduler_timesteps"):
        cfg.scheduler_timesteps = None

    assert (
        cfg.scheduler_type == "ddpm" or cfg.timesteps is None
    ), "timesteps must be None for ddim/euler"

    # NOTE(review): this unconditionally replaces max_norm_zs (including a value
    # read from the config file and normalized above) with -1 for every step
    # except a final 15.5. Kept for compatibility; confirm it is intended.
    cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]

    for name in ("max_norm_zs", "ws1", "ws2"):
        values = getattr(cfg, name)
        assert (
            len(values) == num_steps_actual
        ), f"len(cfg.{name}) ({len(values)}) != num_steps_actual ({num_steps_actual})"

    # These lists have one entry per step, excluding the extra clean step that
    # get_num_steps_actual() counts when clean_step_timestep > 0.
    num_noisy_steps = num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0)
    for name in ("noise_timesteps", "save_timesteps"):
        values = getattr(cfg, name)
        assert (
            values is None or len(values) == num_noisy_steps
        ), f"len(cfg.{name}) ({len(values)}) != expected length ({num_noisy_steps})"

    return cfg
| |
|
| |
|
def get_config_name(config, args):
    """Return a human-readable run name.

    ``args.folder_name`` wins when it is neither None nor empty. Otherwise the
    name is a single-line, space-separated list of ``key value`` pairs built
    from ``config`` and ``args.fp16``.
    """
    folder_name = args.folder_name
    if folder_name is not None and folder_name != "":
        return folder_name

    if config.timesteps is None:
        timesteps_str = f"step_start {config.step_start}"
    else:
        timesteps_str = f"timesteps {config.timesteps}"

    # Order matters: it defines the layout of the generated name.
    parts = [
        f"ws1 {config.ws1[0]}",
        f"ws2 {config.ws2[0]}",
        f"real_cfg_scale {config.real_cfg_scale}",
        timesteps_str,
        f"real_cfg_scale_save {config.real_cfg_scale_save}",
        f"seed {config.seed}",
        f"max_norm_zs {config.max_norm_zs[-1]}",
        f"noise_shift_delta {config.noise_shift_delta}",
        f"scheduler_type {config.scheduler_type}",
        f"fp16 {args.fp16}",
    ]
    return " ".join(parts)
| |
|