diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index fffb09e654..5905df8dd7 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -80,12 +80,12 @@ class CompelInvocation(BaseInvocation): with ( # apply all patches while the model is on the target device - text_encoder_info.model_on_device() as (model_state_dict, text_encoder), + text_encoder_info.model_on_device() as (cached_weights, text_encoder), tokenizer_info as tokenizer, ModelPatcher.apply_lora_text_encoder( text_encoder, loras=_lora_loader(), - model_state_dict=model_state_dict, + cached_weights=cached_weights, ), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers), @@ -175,13 +175,13 @@ class SDXLPromptInvocationBase: with ( # apply all patches while the model is on the target device - text_encoder_info.model_on_device() as (state_dict, text_encoder), + text_encoder_info.model_on_device() as (cached_weights, text_encoder), tokenizer_info as tokenizer, ModelPatcher.apply_lora( text_encoder, loras=_lora_loader(), prefix=lora_prefix, - model_state_dict=state_dict, + cached_weights=cached_weights, ), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers), diff --git a/invokeai/app/invocations/create_gradient_mask.py b/invokeai/app/invocations/create_gradient_mask.py index 089313463b..8db0b463ae 100644 --- a/invokeai/app/invocations/create_gradient_mask.py +++ b/invokeai/app/invocations/create_gradient_mask.py @@ -39,7 +39,7 @@ class GradientMaskOutput(BaseInvocationOutput): title="Create Gradient Mask", tags=["mask", "denoise"], category="latents", - version="1.1.0", + version="1.2.0", ) class CreateGradientMaskInvocation(BaseInvocation): """Creates mask for denoising model run.""" @@ -93,6 +93,7 @@ class CreateGradientMaskInvocation(BaseInvocation): # redistribute blur so that the original edges are 0 and blur outwards to 1 blur_tensor = (blur_tensor - 0.5) * 2 + blur_tensor[blur_tensor < 0] = 0.0 threshold = 1 - self.minimum_denoise diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 560bc9003c..d97f92d42c 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -62,6 +62,7 @@ from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetEx from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt +from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt @@ -845,6 +846,16 @@ class DenoiseLatentsInvocation(BaseInvocation): if self.unet.freeu_config: ext_manager.add_extension(FreeUExt(self.unet.freeu_config)) + ### lora + if self.unet.loras: + for lora_field in self.unet.loras: + ext_manager.add_extension( + LoRAExt( + node_context=context, + model_id=lora_field.lora, + weight=lora_field.weight, + ) + ) ### seamless if self.unet.seamless_axes: ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes)) @@ 
-964,14 +975,14 @@ class DenoiseLatentsInvocation(BaseInvocation): assert isinstance(unet_info.model, UNet2DConditionModel) with ( ExitStack() as exit_stack, - unet_info.model_on_device() as (model_state_dict, unet), + unet_info.model_on_device() as (cached_weights, unet), ModelPatcher.apply_freeu(unet, self.unet.freeu_config), SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME # Apply the LoRA after unet has been moved to its target device for faster patching. ModelPatcher.apply_lora_unet( unet, loras=_lora_loader(), - model_state_dict=model_state_dict, + cached_weights=cached_weights, ), ): assert isinstance(unet, UNet2DConditionModel) diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index 8ef81915f1..cec76ffea2 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -3,12 +3,13 @@ import bisect from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Set, Tuple, Union import torch from safetensors.torch import load_file from typing_extensions import Self +import invokeai.backend.util.logging as logger from invokeai.backend.model_manager import BaseModelType from invokeai.backend.raw_model import RawModel @@ -46,9 +47,19 @@ class LoRALayerBase: self.rank = None # set in layer implementation self.layer_key = layer_key - def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: raise NotImplementedError() + def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]: + return self.bias + + def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]: + params = {"weight": self.get_weight(orig_module.weight)} + bias = self.get_bias(orig_module.bias) + if bias is not None: + params["bias"] = bias + return params + def calc_size(self) -> int: model_size = 0 for val in [self.bias]: @@ -60,6 +71,17 @@ class LoRALayerBase: if self.bias is not None: self.bias = self.bias.to(device=device, dtype=dtype) + def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]): + """Log a warning if values contains unhandled keys.""" + # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by + # `LoRALayerBase`. Sub-classes should provide the known_keys that they handled. + all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"} + unknown_keys = set(values.keys()) - all_known_keys + if unknown_keys: + logger.warning( + f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! 
Keys: {unknown_keys}" + ) + # TODO: find and debug lora/locon with bias class LoRALayer(LoRALayerBase): @@ -76,14 +98,19 @@ class LoRALayer(LoRALayerBase): self.up = values["lora_up.weight"] self.down = values["lora_down.weight"] - if "lora_mid.weight" in values: - self.mid: Optional[torch.Tensor] = values["lora_mid.weight"] - else: - self.mid = None + self.mid = values.get("lora_mid.weight", None) self.rank = self.down.shape[0] + self.check_keys( + values, + { + "lora_up.weight", + "lora_down.weight", + "lora_mid.weight", + }, + ) - def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: if self.mid is not None: up = self.up.reshape(self.up.shape[0], self.up.shape[1]) down = self.down.reshape(self.down.shape[0], self.down.shape[1]) @@ -125,20 +152,23 @@ class LoHALayer(LoRALayerBase): self.w1_b = values["hada_w1_b"] self.w2_a = values["hada_w2_a"] self.w2_b = values["hada_w2_b"] - - if "hada_t1" in values: - self.t1: Optional[torch.Tensor] = values["hada_t1"] - else: - self.t1 = None - - if "hada_t2" in values: - self.t2: Optional[torch.Tensor] = values["hada_t2"] - else: - self.t2 = None + self.t1 = values.get("hada_t1", None) + self.t2 = values.get("hada_t2", None) self.rank = self.w1_b.shape[0] + self.check_keys( + values, + { + "hada_w1_a", + "hada_w1_b", + "hada_w2_a", + "hada_w2_b", + "hada_t1", + "hada_t2", + }, + ) - def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: if self.t1 is None: weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) @@ -186,37 +216,39 @@ class LoKRLayer(LoRALayerBase): ): super().__init__(layer_key, values) - if "lokr_w1" in values: - self.w1: Optional[torch.Tensor] = values["lokr_w1"] - self.w1_a = None - self.w1_b = None - else: - self.w1 = None + self.w1 = values.get("lokr_w1", None) + if self.w1 is None: self.w1_a = values["lokr_w1_a"] self.w1_b = values["lokr_w1_b"] - if "lokr_w2" in values: - self.w2: Optional[torch.Tensor] = values["lokr_w2"] - self.w2_a = None - self.w2_b = None - else: - self.w2 = None + self.w2 = values.get("lokr_w2", None) + if self.w2 is None: self.w2_a = values["lokr_w2_a"] self.w2_b = values["lokr_w2_b"] - if "lokr_t2" in values: - self.t2: Optional[torch.Tensor] = values["lokr_t2"] - else: - self.t2 = None + self.t2 = values.get("lokr_t2", None) - if "lokr_w1_b" in values: - self.rank = values["lokr_w1_b"].shape[0] - elif "lokr_w2_b" in values: - self.rank = values["lokr_w2_b"].shape[0] + if self.w1_b is not None: + self.rank = self.w1_b.shape[0] + elif self.w2_b is not None: + self.rank = self.w2_b.shape[0] else: self.rank = None # unscaled - def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + self.check_keys( + values, + { + "lokr_w1", + "lokr_w1_a", + "lokr_w1_b", + "lokr_w2", + "lokr_w2_a", + "lokr_w2_b", + "lokr_t2", + }, + ) + + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: w1: Optional[torch.Tensor] = self.w1 if w1 is None: assert self.w1_a is not None @@ -272,7 +304,9 @@ class LoKRLayer(LoRALayerBase): class FullLayer(LoRALayerBase): + # bias handled in LoRALayerBase(calc_size, to) # weight: torch.Tensor + # bias: Optional[torch.Tensor] def __init__( self, @@ -282,15 +316,12 @@ class FullLayer(LoRALayerBase): super().__init__(layer_key, values) self.weight = values["diff"] - - if len(values.keys()) > 1: - _keys = list(values.keys()) - _keys.remove("diff") - raise 
NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}") + self.bias = values.get("diff_b", None) self.rank = None # unscaled + self.check_keys(values, {"diff", "diff_b"}) - def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: return self.weight def calc_size(self) -> int: @@ -319,8 +350,9 @@ class IA3Layer(LoRALayerBase): self.on_input = values["on_input"] self.rank = None # unscaled + self.check_keys(values, {"weight", "on_input"}) - def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: weight = self.weight if not self.on_input: weight = weight.reshape(-1, 1) @@ -458,16 +490,19 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module): state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) for layer_key, values in state_dict.items(): + # Detect layers according to LyCORIS detection logic(`weight_list_det`) + # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules + # lora and locon - if "lora_down.weight" in values: + if "lora_up.weight" in values: layer: AnyLoRALayer = LoRALayer(layer_key, values) # loha - elif "hada_w1_b" in values: + elif "hada_w1_a" in values: layer = LoHALayer(layer_key, values) # lokr - elif "lokr_w1_b" in values or "lokr_w1" in values: + elif "lokr_w1" in values or "lokr_w1_a" in values: layer = LoKRLayer(layer_key, values) # diff @@ -475,7 +510,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module): layer = FullLayer(layer_key, values) # ia3 - elif "weight" in values and "on_input" in values: + elif "on_input" in values: layer = IA3Layer(layer_key, values) else: diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index d30f7b3167..e2f22ba019 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -17,8 +17,9 @@ from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import AnyModel from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel +from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw -from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage """ loras = [ @@ -85,13 +86,13 @@ class ModelPatcher: cls, unet: UNet2DConditionModel, loras: Iterator[Tuple[LoRAModelRaw, float]], - model_state_dict: Optional[Dict[str, torch.Tensor]] = None, + cached_weights: Optional[Dict[str, torch.Tensor]] = None, ) -> Generator[None, None, None]: with cls.apply_lora( unet, loras=loras, prefix="lora_unet_", - model_state_dict=model_state_dict, + cached_weights=cached_weights, ): yield @@ -101,9 +102,9 @@ class ModelPatcher: cls, text_encoder: CLIPTextModel, loras: Iterator[Tuple[LoRAModelRaw, float]], - model_state_dict: Optional[Dict[str, torch.Tensor]] = None, + cached_weights: Optional[Dict[str, torch.Tensor]] = None, ) -> Generator[None, None, None]: - with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict): + with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights): yield @classmethod @@ -113,7 +114,7 @@ class ModelPatcher: model: AnyModel, loras: Iterator[Tuple[LoRAModelRaw, 
float]], prefix: str, - model_state_dict: Optional[Dict[str, torch.Tensor]] = None, + cached_weights: Optional[Dict[str, torch.Tensor]] = None, ) -> Generator[None, None, None]: """ Apply one or more LoRAs to a model. @@ -121,66 +122,26 @@ class ModelPatcher: :param model: The model to patch. :param loras: An iterator that returns the LoRA to patch in and its patch weight. :param prefix: A string prefix that precedes keys used in the LoRAs weight layers. - :model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes. + :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes. """ - original_weights = {} + original_weights = OriginalWeightsStorage(cached_weights) try: - with torch.no_grad(): - for lora, lora_weight in loras: - # assert lora.device.type == "cpu" - for layer_key, layer in lora.layers.items(): - if not layer_key.startswith(prefix): - continue + for lora_model, lora_weight in loras: + LoRAExt.patch_model( + model=model, + prefix=prefix, + lora=lora_model, + lora_weight=lora_weight, + original_weights=original_weights, + ) + del lora_model - # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This - # should be improved in the following ways: - # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a - # LoRA model is applied. - # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the - # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA - # weights to have valid keys. - assert isinstance(model, torch.nn.Module) - module_key, module = cls._resolve_lora_key(model, layer_key, prefix) - - # All of the LoRA weight calculations will be done on the same device as the module weight. - # (Performance will be best if this is a CUDA device.) - device = module.weight.device - dtype = module.weight.dtype - - if module_key not in original_weights: - if model_state_dict is not None: # we were provided with the CPU copy of the state dict - original_weights[module_key] = model_state_dict[module_key + ".weight"] - else: - original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) - - layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 - - # We intentionally move to the target device first, then cast. Experimentally, this was found to - # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the - # same thing in a single call to '.to(...)'. - layer.to(device=device) - layer.to(dtype=torch.float32) - # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA - # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. - layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale) - layer.to(device=TorchDevice.CPU_DEVICE) - - assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! - if module.weight.shape != layer_weight.shape: - # TODO: debug on lycoris - assert hasattr(layer_weight, "reshape") - layer_weight = layer_weight.reshape(module.weight.shape) - - assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! 
- module.weight += layer_weight.to(dtype=dtype) - - yield # wait for context manager exit + yield finally: - assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule() with torch.no_grad(): - for module_key, weight in original_weights.items(): - model.get_submodule(module_key).weight.copy_(weight) + for param_key, weight in original_weights.get_changed_weights(): + model.get_parameter(param_key).copy_(weight) @classmethod @contextmanager diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py index 820d5d32a3..a3d27464a0 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -2,14 +2,14 @@ from __future__ import annotations from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, List -import torch from diffusers import UNet2DConditionModel if TYPE_CHECKING: from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType + from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage @dataclass @@ -56,5 +56,17 @@ class ExtensionBase: yield None @contextmanager - def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): - yield None + def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage): + """A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire + diffusion process. Weight unpatching is handled upstream, and relies on the original weights being saved via the + `original_weights.save` function. Restoring only the saved (i.e. changed) weights avoids redundant copy + operations. All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched + by this context manager. + + Args: + unet (UNet2DConditionModel): The UNet model, on the execution device, to patch. + original_weights (OriginalWeightsStorage): A store of CPU copies of the model's original weights, used for + unpatching. An extension should save every tensor it modifies to this store; it can also read the + original weight values from it.
+ """ + yield diff --git a/invokeai/backend/stable_diffusion/extensions/freeu.py b/invokeai/backend/stable_diffusion/extensions/freeu.py index 6ec4fea3fa..ff54e1a52f 100644 --- a/invokeai/backend/stable_diffusion/extensions/freeu.py +++ b/invokeai/backend/stable_diffusion/extensions/freeu.py @@ -1,15 +1,15 @@ from __future__ import annotations from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING -import torch from diffusers import UNet2DConditionModel from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase if TYPE_CHECKING: from invokeai.app.shared.models import FreeUConfig + from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage class FreeUExt(ExtensionBase): @@ -21,7 +21,7 @@ class FreeUExt(ExtensionBase): self._freeu_config = freeu_config @contextmanager - def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): + def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage): unet.enable_freeu( b1=self._freeu_config.b1, b2=self._freeu_config.b2, diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py new file mode 100644 index 0000000000..617bdcbbaf --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/lora.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Tuple + +import torch +from diffusers import UNet2DConditionModel + +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase +from invokeai.backend.util.devices import TorchDevice + +if TYPE_CHECKING: + from invokeai.app.invocations.model import ModelIdentifierField + from invokeai.app.services.shared.invocation_context import InvocationContext + from invokeai.backend.lora import LoRAModelRaw + from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage + + +class LoRAExt(ExtensionBase): + def __init__( + self, + node_context: InvocationContext, + model_id: ModelIdentifierField, + weight: float, + ): + super().__init__() + self._node_context = node_context + self._model_id = model_id + self._weight = weight + + @contextmanager + def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage): + lora_model = self._node_context.models.load(self._model_id).model + self.patch_model( + model=unet, + prefix="lora_unet_", + lora=lora_model, + lora_weight=self._weight, + original_weights=original_weights, + ) + del lora_model + + yield + + @classmethod + @torch.no_grad() + def patch_model( + cls, + model: torch.nn.Module, + prefix: str, + lora: LoRAModelRaw, + lora_weight: float, + original_weights: OriginalWeightsStorage, + ): + """ + Apply one or more LoRAs to a model. + :param model: The model to patch. + :param lora: LoRA model to patch in. + :param lora_weight: LoRA patch weight. + :param prefix: A string prefix that precedes keys used in the LoRAs weight layers. + :param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching. + """ + + if lora_weight == 0: + return + + # assert lora.device.type == "cpu" + for layer_key, layer in lora.layers.items(): + if not layer_key.startswith(prefix): + continue + + # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This + # should be improved in the following ways: + # 1. 
The key mapping could be more-efficiently pre-computed. This would save time every time a + # LoRA model is applied. + # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the + # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA + # weights to have valid keys. + assert isinstance(model, torch.nn.Module) + module_key, module = cls._resolve_lora_key(model, layer_key, prefix) + + # All of the LoRA weight calculations will be done on the same device as the module weight. + # (Performance will be best if this is a CUDA device.) + device = module.weight.device + dtype = module.weight.dtype + + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + + # We intentionally move to the target device first, then cast. Experimentally, this was found to + # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the + # same thing in a single call to '.to(...)'. + layer.to(device=device) + layer.to(dtype=torch.float32) + + # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA + # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. + for param_name, lora_param_weight in layer.get_parameters(module).items(): + param_key = module_key + "." + param_name + module_param = module.get_parameter(param_name) + + # save original weight + original_weights.save(param_key, module_param) + + if module_param.shape != lora_param_weight.shape: + # TODO: debug on lycoris + lora_param_weight = lora_param_weight.reshape(module_param.shape) + + lora_param_weight *= lora_weight * layer_scale + module_param += lora_param_weight.to(dtype=dtype) + + layer.to(device=TorchDevice.CPU_DEVICE) + + @staticmethod + def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: + assert "." not in lora_key + + if not lora_key.startswith(prefix): + raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}") + + module = model + module_key = "" + key_parts = lora_key[len(prefix) :].split("_") + + submodule_name = key_parts.pop(0) + + while len(key_parts) > 0: + try: + module = module.get_submodule(submodule_name) + module_key += "." + submodule_name + submodule_name = key_parts.pop(0) + except Exception: + submodule_name += "_" + key_parts.pop(0) + + module = module.get_submodule(submodule_name) + module_key = (module_key + "." 
+ submodule_name).lstrip(".") + + return (module_key, module) diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index c8d585406a..3783bb422e 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -7,6 +7,7 @@ import torch from diffusers import UNet2DConditionModel from invokeai.app.services.session_processor.session_processor_common import CanceledException +from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage if TYPE_CHECKING: from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext @@ -67,9 +68,15 @@ class ExtensionsManager: if self._is_canceled and self._is_canceled(): raise CanceledException - # TODO: create weight patch logic in PR with extension which uses it - with ExitStack() as exit_stack: - for ext in self._extensions: - exit_stack.enter_context(ext.patch_unet(unet, cached_weights)) + original_weights = OriginalWeightsStorage(cached_weights) + try: + with ExitStack() as exit_stack: + for ext in self._extensions: + exit_stack.enter_context(ext.patch_unet(unet, original_weights)) - yield None + yield None + + finally: + with torch.no_grad(): + for param_key, weight in original_weights.get_changed_weights(): + unet.get_parameter(param_key).copy_(weight) diff --git a/invokeai/backend/stable_diffusion/schedulers/schedulers.py b/invokeai/backend/stable_diffusion/schedulers/schedulers.py index 7d6851e278..c8836b316a 100644 --- a/invokeai/backend/stable_diffusion/schedulers/schedulers.py +++ b/invokeai/backend/stable_diffusion/schedulers/schedulers.py @@ -20,10 +20,14 @@ from diffusers import ( ) from diffusers.schedulers.scheduling_utils import SchedulerMixin +# TODO: add dpmpp_3s/dpmpp_3s_k when fix released +# https://github.com/huggingface/diffusers/issues/9007 + SCHEDULER_NAME_VALUES = Literal[ "ddim", "ddpm", "deis", + "deis_k", "lms", "lms_k", "pndm", @@ -33,16 +37,21 @@ SCHEDULER_NAME_VALUES = Literal[ "euler_k", "euler_a", "kdpm_2", + "kdpm_2_k", "kdpm_2_a", + "kdpm_2_a_k", "dpmpp_2s", "dpmpp_2s_k", "dpmpp_2m", "dpmpp_2m_k", "dpmpp_2m_sde", "dpmpp_2m_sde_k", + "dpmpp_3m", + "dpmpp_3m_k", "dpmpp_sde", "dpmpp_sde_k", "unipc", + "unipc_k", "lcm", "tcd", ] @@ -50,7 +59,8 @@ SCHEDULER_NAME_VALUES = Literal[ SCHEDULER_MAP: dict[SCHEDULER_NAME_VALUES, tuple[Type[SchedulerMixin], dict[str, Any]]] = { "ddim": (DDIMScheduler, {}), "ddpm": (DDPMScheduler, {}), - "deis": (DEISMultistepScheduler, {}), + "deis": (DEISMultistepScheduler, {"use_karras_sigmas": False}), + "deis_k": (DEISMultistepScheduler, {"use_karras_sigmas": True}), "lms": (LMSDiscreteScheduler, {"use_karras_sigmas": False}), "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}), "pndm": (PNDMScheduler, {}), @@ -59,17 +69,28 @@ SCHEDULER_MAP: dict[SCHEDULER_NAME_VALUES, tuple[Type[SchedulerMixin], dict[str, "euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}), "euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}), "euler_a": (EulerAncestralDiscreteScheduler, {}), - "kdpm_2": (KDPM2DiscreteScheduler, {}), - "kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}), - "dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}), - "dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}), - "dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False}), - "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}), - "dpmpp_2m_sde": 
(DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}), - "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}), + "kdpm_2": (KDPM2DiscreteScheduler, {"use_karras_sigmas": False}), + "kdpm_2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}), + "kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": False}), + "kdpm_2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}), + "dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False, "solver_order": 2}), + "dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, "solver_order": 2}), + "dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "solver_order": 2}), + "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 2}), + "dpmpp_2m_sde": ( + DPMSolverMultistepScheduler, + {"use_karras_sigmas": False, "solver_order": 2, "algorithm_type": "sde-dpmsolver++"}, + ), + "dpmpp_2m_sde_k": ( + DPMSolverMultistepScheduler, + {"use_karras_sigmas": True, "solver_order": 2, "algorithm_type": "sde-dpmsolver++"}, + ), + "dpmpp_3m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "solver_order": 3}), + "dpmpp_3m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 3}), "dpmpp_sde": (DPMSolverSDEScheduler, {"use_karras_sigmas": False, "noise_sampler_seed": 0}), "dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}), - "unipc": (UniPCMultistepScheduler, {"cpu_only": True}), + "unipc": (UniPCMultistepScheduler, {"use_karras_sigmas": False, "cpu_only": True}), + "unipc_k": (UniPCMultistepScheduler, {"use_karras_sigmas": True, "cpu_only": True}), "lcm": (LCMScheduler, {}), "tcd": (TCDScheduler, {}), } diff --git a/invokeai/backend/util/original_weights_storage.py b/invokeai/backend/util/original_weights_storage.py new file mode 100644 index 0000000000..af945b086f --- /dev/null +++ b/invokeai/backend/util/original_weights_storage.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Dict, Iterator, Optional, Tuple + +import torch + +from invokeai.backend.util.devices import TorchDevice + + +class OriginalWeightsStorage: + """A class for tracking the original weights of a model for patch/unpatch operations.""" + + def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None): + # The original weights of the model. + self._weights: dict[str, torch.Tensor] = {} + # The keys of the weights that have been changed (via `save()`) during the lifetime of this instance. 
+ self._changed_weights: set[str] = set() + if cached_weights: + self._weights.update(cached_weights) + + def save(self, key: str, weight: torch.Tensor, copy: bool = True): + self._changed_weights.add(key) + if key in self._weights: + return + + self._weights[key] = weight.detach().to(device=TorchDevice.CPU_DEVICE, copy=copy) + + def get(self, key: str, copy: bool = False) -> Optional[torch.Tensor]: + weight = self._weights.get(key, None) + if weight is not None and copy: + weight = weight.clone() + return weight + + def contains(self, key: str) -> bool: + return key in self._weights + + def get_changed_weights(self) -> Iterator[Tuple[str, torch.Tensor]]: + for key in self._changed_weights: + yield key, self._weights[key] diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 659df78d9b..3300f7c7fa 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -31,7 +31,8 @@ "deleteBoard": "Delete Board", "deleteBoardAndImages": "Delete Board and Images", "deleteBoardOnly": "Delete Board Only", - "deletedBoardsCannotbeRestored": "Deleted boards cannot be restored", + "deletedBoardsCannotbeRestored": "Deleted boards cannot be restored. Selecting 'Delete Board Only' will move images to an uncategorized state.", + "deletedPrivateBoardsCannotbeRestored": "Deleted boards cannot be restored. Selecting 'Delete Board Only' will move images to a private uncategorized state for the image's creator.", "hideBoards": "Hide Boards", "loading": "Loading...", "menuItemAutoAdd": "Auto-add to this Board", diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index 2d878d96e7..760eddbee8 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -16,6 +16,8 @@ import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterM import { configChanged } from 'features/system/store/configSlice'; import { languageSelector } from 'features/system/store/systemSelectors'; import InvokeTabs from 'features/ui/components/InvokeTabs'; +import type { InvokeTabName } from 'features/ui/store/tabMap'; +import { setActiveTab } from 'features/ui/store/uiSlice'; import { AnimatePresence } from 'framer-motion'; import i18n from 'i18n'; import { size } from 'lodash-es'; @@ -34,9 +36,10 @@ interface Props { imageName: string; action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters'; }; + destination?: InvokeTabName | undefined; } -const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => { +const App = ({ config = DEFAULT_CONFIG, selectedImage, destination }: Props) => { const language = useAppSelector(languageSelector); const logger = useLogger('system'); const dispatch = useAppDispatch(); @@ -67,6 +70,12 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => { } }, [dispatch, config, logger]); + useEffect(() => { + if (destination) { + dispatch(setActiveTab(destination)); + } + }, [dispatch, destination]); + useEffect(() => { dispatch(appStarted()); }, [dispatch]); diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx index 1dd1a265fb..0a80b7e92d 100644 --- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx +++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx @@ -19,6 +19,7 @@ import type { PartialAppConfig } from 'app/types/invokeai'; import Loading from 
'common/components/Loading/Loading'; import AppDndContext from 'features/dnd/components/AppDndContext'; import type { WorkflowCategory } from 'features/nodes/types/workflow'; +import type { InvokeTabName } from 'features/ui/store/tabMap'; import type { PropsWithChildren, ReactNode } from 'react'; import React, { lazy, memo, useEffect, useMemo } from 'react'; import { Provider } from 'react-redux'; @@ -43,6 +44,7 @@ interface Props extends PropsWithChildren { imageName: string; action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters'; }; + destination?: InvokeTabName; customStarUi?: CustomStarUi; socketOptions?: Partial; isDebugging?: boolean; @@ -62,6 +64,7 @@ const InvokeAIUI = ({ projectUrl, queueId, selectedImage, + destination, customStarUi, socketOptions, isDebugging = false, @@ -218,7 +221,7 @@ const InvokeAIUI = ({ }> - + diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/DeleteBoardModal.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/DeleteBoardModal.tsx index 377636d0d0..3707c24440 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/DeleteBoardModal.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/DeleteBoardModal.tsx @@ -120,7 +120,11 @@ const DeleteBoardModal = (props: Props) => { bottomMessage={t('boards.bottomMessage')} /> )} - {t('boards.deletedBoardsCannotbeRestored')} + + {boardToDelete.is_private + ? t('boards.deletedPrivateBoardsCannotbeRestored') + : t('boards.deletedBoardsCannotbeRestored')} + {canRestoreDeletedImagesFromBin ? t('gallery.deleteImageBin') : t('gallery.deleteImagePermanent')} diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index 2ea8900281..c84b2dae62 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -32,6 +32,7 @@ export const zSchedulerField = z.enum([ 'ddpm', 'dpmpp_2s', 'dpmpp_2m', + 'dpmpp_3m', 'dpmpp_2m_sde', 'dpmpp_sde', 'heun', @@ -40,12 +41,17 @@ export const zSchedulerField = z.enum([ 'pndm', 'unipc', 'euler_k', + 'deis_k', 'dpmpp_2s_k', 'dpmpp_2m_k', + 'dpmpp_3m_k', 'dpmpp_2m_sde_k', 'dpmpp_sde_k', 'heun_k', + 'kdpm_2_k', + 'kdpm_2_a_k', 'lms_k', + 'unipc_k', 'euler_a', 'kdpm_2_a', 'lcm', diff --git a/invokeai/frontend/web/src/features/parameters/types/constants.ts b/invokeai/frontend/web/src/features/parameters/types/constants.ts index 6d7b4f9248..678b2b37f3 100644 --- a/invokeai/frontend/web/src/features/parameters/types/constants.ts +++ b/invokeai/frontend/web/src/features/parameters/types/constants.ts @@ -52,28 +52,34 @@ export const CLIP_SKIP_MAP = { * Mapping of schedulers to human readable name */ export const SCHEDULER_OPTIONS: ComboboxOption[] = [ - { value: 'euler', label: 'Euler' }, - { value: 'deis', label: 'DEIS' }, { value: 'ddim', label: 'DDIM' }, { value: 'ddpm', label: 'DDPM' }, - { value: 'dpmpp_sde', label: 'DPM++ SDE' }, + { value: 'deis', label: 'DEIS' }, + { value: 'deis_k', label: 'DEIS Karras' }, { value: 'dpmpp_2s', label: 'DPM++ 2S' }, - { value: 'dpmpp_2m', label: 'DPM++ 2M' }, - { value: 'dpmpp_2m_sde', label: 'DPM++ 2M SDE' }, - { value: 'heun', label: 'Heun' }, - { value: 'kdpm_2', label: 'KDPM 2' }, - { value: 'lms', label: 'LMS' }, - { value: 'pndm', label: 'PNDM' }, - { value: 'unipc', label: 'UniPC' }, - { value: 'euler_k', label: 'Euler Karras' }, - { value: 'dpmpp_sde_k', label: 'DPM++ SDE Karras' }, { value: 'dpmpp_2s_k', label: 'DPM++ 2S Karras' }, + { 
value: 'dpmpp_2m', label: 'DPM++ 2M' }, { value: 'dpmpp_2m_k', label: 'DPM++ 2M Karras' }, + { value: 'dpmpp_2m_sde', label: 'DPM++ 2M SDE' }, { value: 'dpmpp_2m_sde_k', label: 'DPM++ 2M SDE Karras' }, - { value: 'heun_k', label: 'Heun Karras' }, - { value: 'lms_k', label: 'LMS Karras' }, + { value: 'dpmpp_3m', label: 'DPM++ 3M' }, + { value: 'dpmpp_3m_k', label: 'DPM++ 3M Karras' }, + { value: 'dpmpp_sde', label: 'DPM++ SDE' }, + { value: 'dpmpp_sde_k', label: 'DPM++ SDE Karras' }, + { value: 'euler', label: 'Euler' }, + { value: 'euler_k', label: 'Euler Karras' }, { value: 'euler_a', label: 'Euler Ancestral' }, + { value: 'heun', label: 'Heun' }, + { value: 'heun_k', label: 'Heun Karras' }, + { value: 'kdpm_2', label: 'KDPM 2' }, + { value: 'kdpm_2_k', label: 'KDPM 2 Karras' }, { value: 'kdpm_2_a', label: 'KDPM 2 Ancestral' }, + { value: 'kdpm_2_a_k', label: 'KDPM 2 Ancestral Karras' }, { value: 'lcm', label: 'LCM' }, + { value: 'lms', label: 'LMS' }, + { value: 'lms_k', label: 'LMS Karras' }, + { value: 'pndm', label: 'PNDM' }, { value: 'tcd', label: 'TCD' }, -].sort((a, b) => a.label.localeCompare(b.label)); + { value: 'unipc', label: 'UniPC' }, + { value: 'unipc_k', label: 'UniPC Karras' }, +]; diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 59f9897f74..79b82a23fa 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -3553,7 +3553,7 @@ export type components = { * @default euler * @enum {string} */ - scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd"; + scheduler?: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd"; /** * UNet * @description UNet (scheduler, LoRAs) @@ -8553,7 +8553,7 @@ export type components = { * Scheduler * @description Default scheduler for this model */ - scheduler?: ("ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd") | null; + scheduler?: ("ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd") | null; /** * Steps * @description Default number of steps for this model @@ -11467,7 +11467,7 @@ export type components = { * @default euler * @enum {string} */ - scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd"; + scheduler?: "ddim" | "ddpm" | 
"deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd"; /** * type * @default scheduler @@ -11483,7 +11483,7 @@ export type components = { * @description Scheduler to use during inference * @enum {string} */ - scheduler: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd"; + scheduler: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd"; /** * type * @default scheduler_output @@ -13261,7 +13261,7 @@ export type components = { * @default euler * @enum {string} */ - scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd"; + scheduler?: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd"; /** * UNet * @description UNet (scheduler, LoRAs)