Add TCD scheduler (#6086)

Adds the TCD scheduler to better support checkpoints such as
https://huggingface.co/h1t/TCD-SDXL-LoRA, and other checkpoints that have
been made with TCD.

Example:
TCD Lora with Euler A

![b0ad6174-cd2b-49fe-ae42-3a83bc6ae571](https://github.com/invoke-ai/InvokeAI/assets/82827604/d823cb2f-4d9c-4f93-9fc2-e63773a378b6)

TCD Lora with TCD scheduler

![74495a51-eeac-45e6-9983-fb6551a5bdef](https://github.com/invoke-ai/InvokeAI/assets/82827604/c87604d8-a44e-4fb9-a7be-ef2600784727)
This commit is contained in:
blessedcoolant 2024-05-01 12:57:01 +05:30 committed by GitHub
commit 4a250bdf9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 73 additions and 16 deletions

View File

@@ -3,7 +3,7 @@ import inspect
import math import math
from contextlib import ExitStack from contextlib import ExitStack
from functools import singledispatchmethod from functools import singledispatchmethod
from typing import Any, Iterator, List, Literal, Optional, Tuple, Union from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
import einops import einops
import numpy as np import numpy as np
@@ -11,7 +11,6 @@ import numpy.typing as npt
import torch import torch
import torchvision import torchvision
import torchvision.transforms as T import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.models.adapter import T2IAdapter from diffusers.models.adapter import T2IAdapter
@@ -21,9 +20,12 @@ from diffusers.models.attention_processor import (
LoRAXFormersAttnProcessor, LoRAXFormersAttnProcessor,
XFormersAttnProcessor, XFormersAttnProcessor,
) )
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers.scheduling_tcd import TCDScheduler
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
from PIL import Image, ImageFilter from PIL import Image, ImageFilter
from pydantic import field_validator from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
@@ -521,9 +523,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
if is_sdxl: if is_sdxl:
return SDXLConditioningInfo( return (
embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids SDXLConditioningInfo(embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids),
), regions regions,
)
return BasicConditioningInfo(embeds=text_embedding), regions return BasicConditioningInfo(embeds=text_embedding), regions
def get_conditioning_data( def get_conditioning_data(
@@ -825,7 +828,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_start: float, denoising_start: float,
denoising_end: float, denoising_end: float,
seed: int, seed: int,
) -> Tuple[int, List[int], int]: ) -> Tuple[int, List[int], int, Dict[str, Any]]:
assert isinstance(scheduler, ConfigMixin) assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False): if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu") scheduler.set_timesteps(steps, device="cpu")
@@ -853,13 +856,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx] timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order num_inference_steps = len(timesteps) // scheduler.order
scheduler_step_kwargs = {} scheduler_step_kwargs: Dict[str, Any] = {}
scheduler_step_signature = inspect.signature(scheduler.step) scheduler_step_signature = inspect.signature(scheduler.step)
if "generator" in scheduler_step_signature.parameters: if "generator" in scheduler_step_signature.parameters:
# At some point, someone decided that schedulers that accept a generator should use the original seed with # At some point, someone decided that schedulers that accept a generator should use the original seed with
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for # all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
# reproducibility. # reproducibility.
scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)} scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
if isinstance(scheduler, TCDScheduler):
scheduler_step_kwargs.update({"eta": 1.0})
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs

View File

@@ -13,6 +13,7 @@ from diffusers import (
LCMScheduler, LCMScheduler,
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
TCDScheduler,
UniPCMultistepScheduler, UniPCMultistepScheduler,
) )
@@ -40,4 +41,5 @@ SCHEDULER_MAP = {
"dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}), "dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}),
"unipc": (UniPCMultistepScheduler, {"cpu_only": True}), "unipc": (UniPCMultistepScheduler, {"cpu_only": True}),
"lcm": (LCMScheduler, {}), "lcm": (LCMScheduler, {}),
"tcd": (TCDScheduler, {}),
} }

View File

@@ -49,6 +49,7 @@ export const zSchedulerField = z.enum([
'euler_a', 'euler_a',
'kdpm_2_a', 'kdpm_2_a',
'lcm', 'lcm',
'tcd',
]); ]);
export type SchedulerField = z.infer<typeof zSchedulerField>; export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion // #endregion

View File

@@ -75,4 +75,5 @@ export const SCHEDULER_OPTIONS: ComboboxOption[] = [
{ value: 'euler_a', label: 'Euler Ancestral' }, { value: 'euler_a', label: 'Euler Ancestral' },
{ value: 'kdpm_2_a', label: 'KDPM 2 Ancestral' }, { value: 'kdpm_2_a', label: 'KDPM 2 Ancestral' },
{ value: 'lcm', label: 'LCM' }, { value: 'lcm', label: 'LCM' },
{ value: 'tcd', label: 'TCD' },
].sort((a, b) => a.label.localeCompare(b.label)); ].sort((a, b) => a.label.localeCompare(b.label));

File diff suppressed because one or more lines are too long