mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add TCD scheduler (#6086)
Adds the TCD scheduler to better support. https://huggingface.co/h1t/TCD-SDXL-LoRA or checkpoints that have been made with TCD Example: TCD Lora with Euler A ![b0ad6174-cd2b-49fe-ae42-3a83bc6ae571](https://github.com/invoke-ai/InvokeAI/assets/82827604/d823cb2f-4d9c-4f93-9fc2-e63773a378b6) TCD Lora with TCD scheduler ![74495a51-eeac-45e6-9983-fb6551a5bdef](https://github.com/invoke-ai/InvokeAI/assets/82827604/c87604d8-a44e-4fb9-a7be-ef2600784727)
This commit is contained in:
commit
4a250bdf9c
@ -3,7 +3,7 @@ import inspect
|
||||
import math
|
||||
from contextlib import ExitStack
|
||||
from functools import singledispatchmethod
|
||||
from typing import Any, Iterator, List, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
@ -11,7 +11,6 @@ import numpy.typing as npt
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as T
|
||||
from diffusers import AutoencoderKL, AutoencoderTiny
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models.adapter import T2IAdapter
|
||||
@ -21,9 +20,12 @@ from diffusers.models.attention_processor import (
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers.scheduling_tcd import TCDScheduler
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
|
||||
from PIL import Image, ImageFilter
|
||||
from pydantic import field_validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
@ -521,9 +523,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
if is_sdxl:
|
||||
return SDXLConditioningInfo(
|
||||
embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids
|
||||
), regions
|
||||
return (
|
||||
SDXLConditioningInfo(embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids),
|
||||
regions,
|
||||
)
|
||||
return BasicConditioningInfo(embeds=text_embedding), regions
|
||||
|
||||
def get_conditioning_data(
|
||||
@ -825,7 +828,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_start: float,
|
||||
denoising_end: float,
|
||||
seed: int,
|
||||
) -> Tuple[int, List[int], int]:
|
||||
) -> Tuple[int, List[int], int, Dict[str, Any]]:
|
||||
assert isinstance(scheduler, ConfigMixin)
|
||||
if scheduler.config.get("cpu_only", False):
|
||||
scheduler.set_timesteps(steps, device="cpu")
|
||||
@ -853,13 +856,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
|
||||
num_inference_steps = len(timesteps) // scheduler.order
|
||||
|
||||
scheduler_step_kwargs = {}
|
||||
scheduler_step_kwargs: Dict[str, Any] = {}
|
||||
scheduler_step_signature = inspect.signature(scheduler.step)
|
||||
if "generator" in scheduler_step_signature.parameters:
|
||||
# At some point, someone decided that schedulers that accept a generator should use the original seed with
|
||||
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
|
||||
# reproducibility.
|
||||
scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
|
||||
scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
|
||||
if isinstance(scheduler, TCDScheduler):
|
||||
scheduler_step_kwargs.update({"eta": 1.0})
|
||||
|
||||
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
|
||||
|
||||
|
@ -13,6 +13,7 @@ from diffusers import (
|
||||
LCMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
TCDScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
|
||||
@ -40,4 +41,5 @@ SCHEDULER_MAP = {
|
||||
"dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}),
|
||||
"unipc": (UniPCMultistepScheduler, {"cpu_only": True}),
|
||||
"lcm": (LCMScheduler, {}),
|
||||
"tcd": (TCDScheduler, {}),
|
||||
}
|
||||
|
@ -49,6 +49,7 @@ export const zSchedulerField = z.enum([
|
||||
'euler_a',
|
||||
'kdpm_2_a',
|
||||
'lcm',
|
||||
'tcd',
|
||||
]);
|
||||
export type SchedulerField = z.infer<typeof zSchedulerField>;
|
||||
// #endregion
|
||||
|
@ -75,4 +75,5 @@ export const SCHEDULER_OPTIONS: ComboboxOption[] = [
|
||||
{ value: 'euler_a', label: 'Euler Ancestral' },
|
||||
{ value: 'kdpm_2_a', label: 'KDPM 2 Ancestral' },
|
||||
{ value: 'lcm', label: 'LCM' },
|
||||
{ value: 'tcd', label: 'TCD' },
|
||||
].sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user