Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit fa3c0c81b3: Merge branch 'main' into stalker7779/fix_gradient_mask
@@ -55,6 +55,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
 FROM node:20-slim AS web-builder
 ENV PNPM_HOME="/pnpm"
 ENV PATH="$PNPM_HOME:$PATH"
+RUN corepack use pnpm@8.x
 RUN corepack enable
 
 WORKDIR /build
@@ -37,9 +37,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_manager import BaseModelType
+from invokeai.backend.model_manager import BaseModelType, ModelVariantType
 from invokeai.backend.model_patcher import ModelPatcher
-from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
     ControlNetData,
@@ -60,8 +60,12 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
 from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
 from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
+from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
+from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
+from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
+from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -498,6 +502,33 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 )
             )
 
+    @staticmethod
+    def parse_t2i_adapter_field(
+        exit_stack: ExitStack,
+        context: InvocationContext,
+        t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
+        ext_manager: ExtensionsManager,
+    ) -> None:
+        if t2i_adapters is None:
+            return
+
+        # Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
+        if isinstance(t2i_adapters, T2IAdapterField):
+            t2i_adapters = [t2i_adapters]
+
+        for t2i_adapter_field in t2i_adapters:
+            ext_manager.add_extension(
+                T2IAdapterExt(
+                    node_context=context,
+                    model_id=t2i_adapter_field.t2i_adapter_model,
+                    image=context.images.get_pil(t2i_adapter_field.image.image_name),
+                    weight=t2i_adapter_field.weight,
+                    begin_step_percent=t2i_adapter_field.begin_step_percent,
+                    end_step_percent=t2i_adapter_field.end_step_percent,
+                    resize_mode=t2i_adapter_field.resize_mode,
+                )
+            )
+
     def prep_ip_adapter_image_prompts(
         self,
         context: InvocationContext,
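A standalone sketch (illustrative names, not part of this commit) of the normalize-to-list pattern used by parse_t2i_adapter_field above, which reduces a field that may be None, a single value, or a list to a uniform list before iterating:

from typing import Optional, Union

def normalize(value: Optional[Union[int, list[int]]]) -> list[int]:
    # None -> empty list; scalar -> one-element list; list -> unchanged
    if value is None:
        return []
    if not isinstance(value, list):
        return [value]
    return value

assert normalize(None) == []
assert normalize(3) == [3]
assert normalize([1, 2]) == [1, 2]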
@@ -707,7 +738,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         else:
             masked_latents = torch.where(mask < 0.5, 0.0, latents)
 
-        return 1 - mask, masked_latents, self.denoise_mask.gradient
+        return mask, masked_latents, self.denoise_mask.gradient
 
     @staticmethod
     def prepare_noise_and_latents(
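Context for the one-line change above: prep_inpaint_mask now returns the mask in its new convention (0 = leave unchanged, 1 = inpaint) instead of pre-inverting it; callers that target the old backend invert it themselves (see the `mask = 1 - mask` hunk further down). A tiny standalone sketch of the two conventions, not part of this commit:

import torch

mask = torch.tensor([[0.0, 1.0], [0.25, 0.75]])  # new convention: 0 = keep, 1 = inpaint
legacy_mask = 1 - mask                           # old backend: 1 = keep, 0 = inpaint
assert torch.allclose(mask + legacy_mask, torch.ones_like(mask))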
@@ -765,10 +796,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
         dtype = TorchDevice.choose_torch_dtype()
 
         seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
-        latents = latents.to(device=device, dtype=dtype)
-        if noise is not None:
-            noise = noise.to(device=device, dtype=dtype)
-
         _, _, latent_height, latent_width = latents.shape
 
         conditioning_data = self.get_conditioning_data(
@@ -801,21 +828,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
             denoising_end=self.denoising_end,
         )
 
-        denoise_ctx = DenoiseContext(
-            inputs=DenoiseInputs(
-                orig_latents=latents,
-                timesteps=timesteps,
-                init_timestep=init_timestep,
-                noise=noise,
-                seed=seed,
-                scheduler_step_kwargs=scheduler_step_kwargs,
-                conditioning_data=conditioning_data,
-                attention_processor_cls=CustomAttnProcessor2_0,
-            ),
-            unet=None,
-            scheduler=scheduler,
-        )
-
         # get the unet's config so that we can pass the base to sd_step_callback()
         unet_config = context.models.get_config(self.unet.unet.key)
 
@@ -833,6 +845,40 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if self.unet.freeu_config:
             ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
 
+        ### seamless
+        if self.unet.seamless_axes:
+            ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
+
+        ### inpaint
+        mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
+        # NOTE: We used to identify inpainting models by inspecting the shape of the loaded UNet model weights. Now we
+        # use the ModelVariantType config. During testing, there was a report of a user with models that had an
+        # incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
+        # prevalent, we will have to revisit how we initialize the inpainting extensions.
+        if unet_config.variant == ModelVariantType.Inpaint:
+            ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
+        elif mask is not None:
+            ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
+
+        # Initialize context for modular denoise
+        latents = latents.to(device=device, dtype=dtype)
+        if noise is not None:
+            noise = noise.to(device=device, dtype=dtype)
+        denoise_ctx = DenoiseContext(
+            inputs=DenoiseInputs(
+                orig_latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
+                seed=seed,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                attention_processor_cls=CustomAttnProcessor2_0,
+            ),
+            unet=None,
+            scheduler=scheduler,
+        )
+
         # context for loading additional models
         with ExitStack() as exit_stack:
             # later should be something like:
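The hunk above registers the new extensions and then builds the DenoiseContext their callbacks receive. A toy sketch of the callback-ordering idea involved — callbacks for one event run sorted by their `order` value, which is why the inpaint hooks in the new files below use order=-100 to run before default (order=0) callbacks. This is a loose illustration only; the real ExtensionsManager API differs:

from collections import defaultdict

class ToyManager:
    def __init__(self):
        self._callbacks = defaultdict(list)  # event -> [(order, fn)]

    def register(self, event, fn, order=0):
        self._callbacks[event].append((order, fn))

    def run(self, event):
        # lower order runs first, mirroring the order=-100 trick
        for _, fn in sorted(self._callbacks[event], key=lambda pair: pair[0]):
            fn()

mgr = ToyManager()
trace = []
mgr.register("PRE_STEP", lambda: trace.append("default"))
mgr.register("PRE_STEP", lambda: trace.append("inpaint"), order=-100)
mgr.run("PRE_STEP")
assert trace == ["inpaint", "default"]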
@@ -840,6 +886,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             # ext = extension_field.to_extension(exit_stack, context, ext_manager)
             # ext_manager.add_extension(ext)
             self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
+            self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
 
             # ext: t2i/ip adapter
             ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
@@ -871,6 +918,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
 
         mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
+        # At this point, the mask ranges from 0 (leave unchanged) to 1 (inpaint).
+        # We invert the mask here for compatibility with the old backend implementation.
+        if mask is not None:
+            mask = 1 - mask
 
         # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
         # below. Investigate whether this is appropriate.
@@ -915,7 +966,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             ExitStack() as exit_stack,
             unet_info.model_on_device() as (model_state_dict, unet),
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
-            set_seamless(unet, self.unet.seamless_axes),  # FIXME
+            SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(
                 unet,
@@ -24,7 +24,7 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.stable_diffusion import set_seamless
+from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
 from invokeai.backend.util.devices import TorchDevice
 
@@ -59,7 +59,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
         vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
-        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
+        with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             latents = latents.to(vae.device)
             if self.fp32:
@@ -7,11 +7,9 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import (  # noqa: F401
     StableDiffusionGeneratorPipeline,
 )
 from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent  # noqa: F401
-from invokeai.backend.stable_diffusion.seamless import set_seamless  # noqa: F401
 
 __all__ = [
     "PipelineIntermediateState",
     "StableDiffusionGeneratorPipeline",
     "InvokeAIDiffuserComponent",
-    "set_seamless",
 ]
invokeai/backend/stable_diffusion/extensions/inpaint.py (new file, 120 lines)
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+import einops
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+
+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+
+
+class InpaintExt(ExtensionBase):
+    """An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
+    models.
+    """
+
+    def __init__(
+        self,
+        mask: torch.Tensor,
+        is_gradient_mask: bool,
+    ):
+        """Initialize InpaintExt.
+        Args:
+            mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
+                expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
+                inpainted.
+            is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
+                from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
+                1.
+        """
+        super().__init__()
+        self._mask = mask
+        self._is_gradient_mask = is_gradient_mask
+
+        # Noise used to re-noise the unmasked part of the image.
+        # If noise was provided in the context, it is used; otherwise noise is
+        # generated here from the seed.
+        self._noise: Optional[torch.Tensor] = None
+
+    @staticmethod
+    def _is_normal_model(unet: UNet2DConditionModel):
+        """Checks if the provided UNet belongs to a regular model.
+        The `in_channels` of a UNet vary depending on model type:
+        - normal - 4
+        - depth - 5
+        - inpaint - 9
+        """
+        return unet.conv_in.in_channels == 4
+
+    def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        batch_size = latents.size(0)
+        mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
+        if t.dim() == 0:
+            # some schedulers expect t to be one-dimensional.
+            # TODO: file diffusers bug about inconsistency?
+            t = einops.repeat(t, "-> batch", batch=batch_size)
+        # Noise shouldn't be re-randomized between steps here. The multistep schedulers
+        # get very confused about what is happening from step to step when we do that.
+        mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t)
+        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
+        # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
+        mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
+        if self._is_gradient_mask:
+            threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
+            mask_bool = mask < 1 - threshold
+            masked_input = torch.where(mask_bool, latents, mask_latents)
+        else:
+            masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))
+        return masked_input
+
+    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
+    def init_tensors(self, ctx: DenoiseContext):
+        if not self._is_normal_model(ctx.unet):
+            raise ValueError(
+                "InpaintExt should be used only on normal (non-inpainting) models. This could be caused by an "
+                "inpainting model that was incorrectly marked as a non-inpainting model. In some cases, this can be "
+                "fixed by removing and re-adding the model (so that it gets re-probed)."
+            )
+
+        self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+
+        self._noise = ctx.inputs.noise
+        # 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
+        # We still need noise for inpainting, so we generate it from the seed here.
+        if self._noise is None:
+            self._noise = torch.randn(
+                ctx.latents.shape,
+                dtype=torch.float32,
+                device="cpu",
+                generator=torch.Generator(device="cpu").manual_seed(ctx.seed),
+            ).to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+
+    # Use negative order to make extensions with default order work with patched latents
+    @callback(ExtensionCallbackType.PRE_STEP, order=-100)
+    def apply_mask_to_initial_latents(self, ctx: DenoiseContext):
+        ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)
+
+    # TODO: redo this with preview events rewrite
+    # Use negative order to make extensions with default order work with patched latents
+    @callback(ExtensionCallbackType.POST_STEP, order=-100)
+    def apply_mask_to_step_output(self, ctx: DenoiseContext):
+        timestep = ctx.scheduler.timesteps[-1]
+        if hasattr(ctx.step_output, "denoised"):
+            ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep)
+        elif hasattr(ctx.step_output, "pred_original_sample"):
+            ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep)
+        else:
+            ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)
+
+    # Restore the unmasked part after the last step completes
+    @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
+    def restore_unmasked(self, ctx: DenoiseContext):
+        if self._is_gradient_mask:
+            ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents)
+        else:
+            ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask)
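A standalone sketch (not this file's code) of the soft-mask blend performed in _apply_mask above: torch.lerp mixes the current latents with the re-noised original latents, where mask = 1 ("do not inpaint") selects the noised original and mask = 0 keeps the denoised latents:

import torch

latents = torch.zeros(1, 4, 2, 2)          # current working latents
noised_original = torch.ones(1, 4, 2, 2)   # original latents re-noised to timestep t
mask = torch.full((1, 1, 2, 2), 0.5)       # soft mask, broadcast over channels

blended = torch.lerp(latents, noised_original, mask)
assert torch.allclose(blended, torch.full_like(latents, 0.5))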
invokeai/backend/stable_diffusion/extensions/inpaint_model.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+
+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+
+
+class InpaintModelExt(ExtensionBase):
+    """An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting
+    models.
+    """
+
+    def __init__(
+        self,
+        mask: Optional[torch.Tensor],
+        masked_latents: Optional[torch.Tensor],
+        is_gradient_mask: bool,
+    ):
+        """Initialize InpaintModelExt.
+        Args:
+            mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
+                expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
+                inpainted.
+            masked_latents (Optional[torch.Tensor]): Latents of the initial image, with the inpainted area masked out
+                in black. If a mask is provided, masked_latents must be provided too.
+                Shape: (1, 1, latent_height, latent_width)
+            is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
+                from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
+                1.
+        """
+        super().__init__()
+        if mask is not None and masked_latents is None:
+            raise ValueError("Source image required for inpaint mask when inpaint model used!")
+
+        # Invert the mask, because inpaint models treat the mask as: 0 - remain the same, 1 - inpaint
+        self._mask = None
+        if mask is not None:
+            self._mask = 1 - mask
+        self._masked_latents = masked_latents
+        self._is_gradient_mask = is_gradient_mask
+
+    @staticmethod
+    def _is_inpaint_model(unet: UNet2DConditionModel):
+        """Checks if the provided UNet belongs to an inpainting model.
+        The `in_channels` of a UNet vary depending on model type:
+        - normal - 4
+        - depth - 5
+        - inpaint - 9
+        """
+        return unet.conv_in.in_channels == 9
+
+    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
+    def init_tensors(self, ctx: DenoiseContext):
+        if not self._is_inpaint_model(ctx.unet):
+            raise ValueError("InpaintModelExt should be used only on inpaint models!")
+
+        if self._mask is None:
+            self._mask = torch.ones_like(ctx.latents[:1, :1])
+        self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+
+        if self._masked_latents is None:
+            self._masked_latents = torch.zeros_like(ctx.latents[:1])
+        self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+
+    # Run last so that other extensions work with the normal latents
+    @callback(ExtensionCallbackType.PRE_UNET, order=1000)
+    def append_inpaint_layers(self, ctx: DenoiseContext):
+        batch_size = ctx.unet_kwargs.sample.shape[0]
+        b_mask = torch.cat([self._mask] * batch_size)
+        b_masked_latents = torch.cat([self._masked_latents] * batch_size)
+        ctx.unet_kwargs.sample = torch.cat(
+            [ctx.unet_kwargs.sample, b_mask, b_masked_latents],
+            dim=1,
+        )
+
+    # Restore the unmasked part, as the inpaint model can change it slightly
+    @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
+    def restore_unmasked(self, ctx: DenoiseContext):
+        if self._is_gradient_mask:
+            ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
+        else:
+            ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)
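A shape-only sketch (illustrative, not this file's code) of the 9-channel UNet input assembled by append_inpaint_layers: 4 latent channels, 1 mask channel, and 4 masked-image latent channels, concatenated along dim=1:

import torch

sample = torch.randn(2, 4, 64, 64)          # batched latents entering the UNet
mask = torch.ones(1, 1, 64, 64)             # inverted convention: 1 = keep
masked_latents = torch.zeros(1, 4, 64, 64)  # VAE-encoded masked source image

batch = sample.shape[0]
unet_input = torch.cat(
    [sample, torch.cat([mask] * batch), torch.cat([masked_latents] * batch)],
    dim=1,
)
assert unet_input.shape == (2, 9, 64, 64)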
invokeai/backend/stable_diffusion/extensions/seamless.py (new file, 71 lines)
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from diffusers import UNet2DConditionModel
+from diffusers.models.lora import LoRACompatibleConv
+
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
+
+
+class SeamlessExt(ExtensionBase):
+    def __init__(
+        self,
+        seamless_axes: List[str],
+    ):
+        super().__init__()
+        self._seamless_axes = seamless_axes
+
+    @contextmanager
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        with self.static_patch_model(
+            model=unet,
+            seamless_axes=self._seamless_axes,
+        ):
+            yield
+
+    @staticmethod
+    @contextmanager
+    def static_patch_model(
+        model: torch.nn.Module,
+        seamless_axes: List[str],
+    ):
+        if not seamless_axes:
+            yield
+            return
+
+        x_mode = "circular" if "x" in seamless_axes else "constant"
+        y_mode = "circular" if "y" in seamless_axes else "constant"
+
+        # override conv_forward
+        # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
+        def _conv_forward_asymmetric(
+            self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
+        ):
+            self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
+            self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
+            working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
+            working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
+            return torch.nn.functional.conv2d(
+                working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
+            )
+
+        original_layers: List[Tuple[nn.Conv2d, Callable]] = []
+        try:
+            for layer in model.modules():
+                if not isinstance(layer, torch.nn.Conv2d):
+                    continue
+
+                if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
+                    layer.lora_layer = lambda *x: 0
+
+                original_layers.append((layer, layer._conv_forward))
+                layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
+
+            yield
+
+        finally:
+            for layer, orig_conv_forward in original_layers:
+                layer._conv_forward = orig_conv_forward
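A self-contained demo of the asymmetric padding trick that static_patch_model patches into every Conv2d: pad the x axis in 'circular' mode so the left edge wraps around to the right, pad the y axis in 'constant' (zero) mode, then convolve with zero internal padding. Plain torch, not part of the commit:

import torch
import torch.nn.functional as F

x = torch.arange(16.0).reshape(1, 1, 4, 4)
padded = F.pad(x, (1, 1, 0, 0), mode="circular")       # wrap left/right edges
padded = F.pad(padded, (0, 0, 1, 1), mode="constant")  # zero-pad top/bottom
assert padded.shape == (1, 1, 6, 6)
assert padded[0, 0, 1, 0] == x[0, 0, 0, -1]  # left padding came from the right edge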
invokeai/backend/stable_diffusion/extensions/t2i_adapter.py (new file, 120 lines)
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+import math
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import torch
+from diffusers import T2IAdapter
+from PIL.Image import Image
+
+from invokeai.app.util.controlnet_utils import prepare_control_image
+from invokeai.backend.model_manager import BaseModelType
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+
+if TYPE_CHECKING:
+    from invokeai.app.invocations.model import ModelIdentifierField
+    from invokeai.app.services.shared.invocation_context import InvocationContext
+    from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
+    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+
+
+class T2IAdapterExt(ExtensionBase):
+    def __init__(
+        self,
+        node_context: InvocationContext,
+        model_id: ModelIdentifierField,
+        image: Image,
+        weight: Union[float, List[float]],
+        begin_step_percent: float,
+        end_step_percent: float,
+        resize_mode: CONTROLNET_RESIZE_VALUES,
+    ):
+        super().__init__()
+        self._node_context = node_context
+        self._model_id = model_id
+        self._image = image
+        self._weight = weight
+        self._resize_mode = resize_mode
+        self._begin_step_percent = begin_step_percent
+        self._end_step_percent = end_step_percent
+
+        self._adapter_state: Optional[List[torch.Tensor]] = None
+
+        # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
+        model_config = self._node_context.models.get_config(self._model_id.key)
+        if model_config.base == BaseModelType.StableDiffusion1:
+            self._max_unet_downscale = 8
+        elif model_config.base == BaseModelType.StableDiffusionXL:
+            self._max_unet_downscale = 4
+        else:
+            raise ValueError(f"Unexpected T2I-Adapter base model type: '{model_config.base}'.")
+
+    @callback(ExtensionCallbackType.SETUP)
+    def setup(self, ctx: DenoiseContext):
+        t2i_model: T2IAdapter
+        with self._node_context.models.load(self._model_id) as t2i_model:
+            _, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
+
+            self._adapter_state = self._run_model(
+                model=t2i_model,
+                image=self._image,
+                latents_height=latents_height,
+                latents_width=latents_width,
+            )
+
+    def _run_model(
+        self,
+        model: T2IAdapter,
+        image: Image,
+        latents_height: int,
+        latents_width: int,
+    ):
+        # Resize the T2I-Adapter input image.
+        # We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
+        # result will match the latent image's dimensions after max_unet_downscale is applied.
+        input_height = latents_height // self._max_unet_downscale * model.total_downscale_factor
+        input_width = latents_width // self._max_unet_downscale * model.total_downscale_factor
+
+        # Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
+        # a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
+        # T2I-Adapter model.
+        #
+        # Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
+        # of the same requirements (e.g. preserving binary masks during resize).
+        t2i_image = prepare_control_image(
+            image=image,
+            do_classifier_free_guidance=False,
+            width=input_width,
+            height=input_height,
+            num_channels=model.config["in_channels"],
+            device=model.device,
+            dtype=model.dtype,
+            resize_mode=self._resize_mode,
+        )
+
+        return model(t2i_image)
+
+    @callback(ExtensionCallbackType.PRE_UNET)
+    def pre_unet_step(self, ctx: DenoiseContext):
+        # skip if model not active in current step
+        total_steps = len(ctx.inputs.timesteps)
+        first_step = math.floor(self._begin_step_percent * total_steps)
+        last_step = math.ceil(self._end_step_percent * total_steps)
+        if ctx.step_index < first_step or ctx.step_index > last_step:
+            return
+
+        weight = self._weight
+        if isinstance(weight, list):
+            weight = weight[ctx.step_index]
+
+        adapter_state = self._adapter_state
+        if ctx.conditioning_mode == ConditioningMode.Both:
+            adapter_state = [torch.cat([v] * 2) for v in adapter_state]
+
+        if ctx.unet_kwargs.down_intrablock_additional_residuals is None:
+            ctx.unet_kwargs.down_intrablock_additional_residuals = [v * weight for v in adapter_state]
+        else:
+            for i, value in enumerate(adapter_state):
+                ctx.unet_kwargs.down_intrablock_additional_residuals[i] += value * weight
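A standalone sketch of the step-gating arithmetic in pre_unet_step: the adapter contributes residuals only between begin_step_percent and end_step_percent of the schedule (illustrative helper, not this file's code):

import math

def active_steps(total_steps: int, begin: float, end: float) -> range:
    # floor/ceil match the bounds used in pre_unet_step above
    first = math.floor(begin * total_steps)
    last = math.ceil(end * total_steps)
    return range(first, last + 1)

steps = active_steps(total_steps=30, begin=0.2, end=0.8)
assert steps.start == 6 and steps.stop == 25  # adapter active on steps 6..24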
@@ -20,10 +20,14 @@ from diffusers import (
 )
 from diffusers.schedulers.scheduling_utils import SchedulerMixin
 
+# TODO: add dpmpp_3s/dpmpp_3s_k when fix released
+# https://github.com/huggingface/diffusers/issues/9007
+
 SCHEDULER_NAME_VALUES = Literal[
     "ddim",
     "ddpm",
     "deis",
+    "deis_k",
     "lms",
     "lms_k",
     "pndm",
@@ -33,16 +37,21 @@ SCHEDULER_NAME_VALUES = Literal[
     "euler_k",
     "euler_a",
     "kdpm_2",
+    "kdpm_2_k",
     "kdpm_2_a",
+    "kdpm_2_a_k",
     "dpmpp_2s",
     "dpmpp_2s_k",
     "dpmpp_2m",
     "dpmpp_2m_k",
     "dpmpp_2m_sde",
     "dpmpp_2m_sde_k",
+    "dpmpp_3m",
+    "dpmpp_3m_k",
     "dpmpp_sde",
     "dpmpp_sde_k",
     "unipc",
+    "unipc_k",
     "lcm",
     "tcd",
 ]
@@ -50,7 +59,8 @@ SCHEDULER_NAME_VALUES = Literal[
 SCHEDULER_MAP: dict[SCHEDULER_NAME_VALUES, tuple[Type[SchedulerMixin], dict[str, Any]]] = {
     "ddim": (DDIMScheduler, {}),
     "ddpm": (DDPMScheduler, {}),
-    "deis": (DEISMultistepScheduler, {}),
+    "deis": (DEISMultistepScheduler, {"use_karras_sigmas": False}),
+    "deis_k": (DEISMultistepScheduler, {"use_karras_sigmas": True}),
     "lms": (LMSDiscreteScheduler, {"use_karras_sigmas": False}),
     "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
     "pndm": (PNDMScheduler, {}),
@@ -59,17 +69,28 @@ SCHEDULER_MAP: dict[SCHEDULER_NAME_VALUES, tuple[Type[SchedulerMixin], dict[str, Any]]] = {
     "euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
     "euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
     "euler_a": (EulerAncestralDiscreteScheduler, {}),
-    "kdpm_2": (KDPM2DiscreteScheduler, {}),
-    "kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
-    "dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
-    "dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
-    "dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False}),
-    "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
-    "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
-    "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
+    "kdpm_2": (KDPM2DiscreteScheduler, {"use_karras_sigmas": False}),
+    "kdpm_2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
+    "kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": False}),
+    "kdpm_2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
+    "dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False, "solver_order": 2}),
+    "dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, "solver_order": 2}),
+    "dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "solver_order": 2}),
+    "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 2}),
+    "dpmpp_2m_sde": (
+        DPMSolverMultistepScheduler,
+        {"use_karras_sigmas": False, "solver_order": 2, "algorithm_type": "sde-dpmsolver++"},
+    ),
+    "dpmpp_2m_sde_k": (
+        DPMSolverMultistepScheduler,
+        {"use_karras_sigmas": True, "solver_order": 2, "algorithm_type": "sde-dpmsolver++"},
+    ),
+    "dpmpp_3m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "solver_order": 3}),
+    "dpmpp_3m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 3}),
     "dpmpp_sde": (DPMSolverSDEScheduler, {"use_karras_sigmas": False, "noise_sampler_seed": 0}),
     "dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}),
-    "unipc": (UniPCMultistepScheduler, {"cpu_only": True}),
+    "unipc": (UniPCMultistepScheduler, {"use_karras_sigmas": False, "cpu_only": True}),
+    "unipc_k": (UniPCMultistepScheduler, {"use_karras_sigmas": True, "cpu_only": True}),
     "lcm": (LCMScheduler, {}),
     "tcd": (TCDScheduler, {}),
 }
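A minimal sketch, assuming only that diffusers is installed, of how SCHEDULER_MAP entries are typically consumed: look up the scheduler class and its extra kwargs by name, then rebuild the scheduler from an existing scheduler's config. The base scheduler below is a stand-in for the pipeline's current one:

from diffusers import DDPMScheduler, DPMSolverMultistepScheduler

SCHEDULER_MAP = {
    "dpmpp_3m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 3}),
}

base = DDPMScheduler()  # stands in for the pipeline's current scheduler
cls, extra_kwargs = SCHEDULER_MAP["dpmpp_3m_k"]
scheduler = cls.from_config(base.config, **extra_kwargs)
assert scheduler.config.solver_order == 3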
invokeai/backend/stable_diffusion/seamless.py (deleted file, 51 lines)
@@ -1,51 +0,0 @@
-from contextlib import contextmanager
-from typing import Callable, List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
-from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
-from diffusers.models.lora import LoRACompatibleConv
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-
-
-@contextmanager
-def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
-    if not seamless_axes:
-        yield
-        return
-
-    # override conv_forward
-    # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
-    def _conv_forward_asymmetric(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
-        self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
-        self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
-        working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
-        working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
-        return torch.nn.functional.conv2d(
-            working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
-        )
-
-    original_layers: List[Tuple[nn.Conv2d, Callable]] = []
-
-    try:
-        x_mode = "circular" if "x" in seamless_axes else "constant"
-        y_mode = "circular" if "y" in seamless_axes else "constant"
-
-        conv_layers: List[torch.nn.Conv2d] = []
-
-        for module in model.modules():
-            if isinstance(module, torch.nn.Conv2d):
-                conv_layers.append(module)
-
-        for layer in conv_layers:
-            if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
-                layer.lora_layer = lambda *x: 0
-            original_layers.append((layer, layer._conv_forward))
-            layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
-
-        yield
-
-    finally:
-        for layer, orig_conv_forward in original_layers:
-            layer._conv_forward = orig_conv_forward
@@ -31,7 +31,8 @@
     "deleteBoard": "Delete Board",
     "deleteBoardAndImages": "Delete Board and Images",
     "deleteBoardOnly": "Delete Board Only",
-    "deletedBoardsCannotbeRestored": "Deleted boards cannot be restored",
+    "deletedBoardsCannotbeRestored": "Deleted boards cannot be restored. Selecting 'Delete Board Only' will move images to an uncategorized state.",
+    "deletedPrivateBoardsCannotbeRestored": "Deleted boards cannot be restored. Selecting 'Delete Board Only' will move images to a private uncategorized state for the image's creator.",
     "hideBoards": "Hide Boards",
     "loading": "Loading...",
     "menuItemAutoAdd": "Auto-add to this Board",
@@ -10,32 +10,32 @@ import {
 import { boardsApi } from 'services/api/endpoints/boards';
 import { imagesApi } from 'services/api/endpoints/images';
 
+// Type inference doesn't work for this if you inline it in the listener for some reason
+const matchAnyBoardDeleted = isAnyOf(
+  imagesApi.endpoints.deleteBoard.matchFulfilled,
+  imagesApi.endpoints.deleteBoardAndImages.matchFulfilled
+);
+
 export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartListening) => {
   /**
    * The auto-add board shouldn't be set to an archived board or deleted board. When we archive a board, delete
    * a board, or change the archived board visibility flag, we may need to reset the auto-add board.
    */
   startAppListening({
-    matcher: isAnyOf(
-      // If a board is deleted, we'll need to reset the auto-add board
-      imagesApi.endpoints.deleteBoard.matchFulfilled,
-      imagesApi.endpoints.deleteBoardAndImages.matchFulfilled
-    ),
+    matcher: matchAnyBoardDeleted,
     effect: async (action, { dispatch, getState }) => {
       const state = getState();
-      const queryArgs = selectListBoardsQueryArgs(state);
-      const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
+      const deletedBoardId = action.meta.arg.originalArgs;
       const { autoAddBoardId, selectedBoardId } = state.gallery;
 
-      if (!queryResult.data) {
-        return;
-      }
-
-      if (!queryResult.data.find((board) => board.board_id === selectedBoardId)) {
+      // If the deleted board was currently selected, we should reset the selected board to uncategorized
+      if (deletedBoardId === selectedBoardId) {
         dispatch(boardIdSelected({ boardId: 'none' }));
         dispatch(galleryViewChanged('images'));
       }
-      if (!queryResult.data.find((board) => board.board_id === autoAddBoardId)) {
+
+      // If the deleted board was selected for auto-add, we should reset the auto-add board to uncategorized
+      if (deletedBoardId === autoAddBoardId) {
         dispatch(autoAddBoardIdChanged('none'));
       }
     },
@@ -46,14 +46,8 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartListening) => {
     matcher: boardsApi.endpoints.updateBoard.matchFulfilled,
     effect: async (action, { dispatch, getState }) => {
       const state = getState();
-      const queryArgs = selectListBoardsQueryArgs(state);
-      const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
       const { shouldShowArchivedBoards } = state.gallery;
 
-      if (!queryResult.data) {
-        return;
-      }
-
       const wasArchived = action.meta.arg.originalArgs.changes.archived === true;
 
       if (wasArchived && !shouldShowArchivedBoards) {
@@ -71,7 +65,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartListening) => {
       const shouldShowArchivedBoards = action.payload;
 
       // We only need to take action if we have just hidden archived boards.
-      if (!shouldShowArchivedBoards) {
+      if (shouldShowArchivedBoards) {
         return;
       }
@@ -86,14 +80,16 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartListening) => {
 
       // Handle the case where selected board is archived
       const selectedBoard = queryResult.data.find((b) => b.board_id === selectedBoardId);
-      if (selectedBoard && selectedBoard.archived) {
+      if (!selectedBoard || selectedBoard.archived) {
+        // If we can't find the selected board or it's archived, we should reset the selected board to uncategorized
         dispatch(boardIdSelected({ boardId: 'none' }));
         dispatch(galleryViewChanged('images'));
       }
 
       // Handle the case where auto-add board is archived
       const autoAddBoard = queryResult.data.find((b) => b.board_id === autoAddBoardId);
-      if (autoAddBoard && autoAddBoard.archived) {
+      if (!autoAddBoard || autoAddBoard.archived) {
+        // If we can't find the auto-add board or it's archived, we should reset the auto-add board to uncategorized
         dispatch(autoAddBoardIdChanged('none'));
       }
     },
@@ -120,7 +120,11 @@ const DeleteBoardModal = (props: Props) => {
           bottomMessage={t('boards.bottomMessage')}
         />
       )}
-      <Text>{t('boards.deletedBoardsCannotbeRestored')}</Text>
+      <Text>
+        {boardToDelete.is_private
+          ? t('boards.deletedPrivateBoardsCannotbeRestored')
+          : t('boards.deletedBoardsCannotbeRestored')}
+      </Text>
       <Text>
         {canRestoreDeletedImagesFromBin ? t('gallery.deleteImageBin') : t('gallery.deleteImagePermanent')}
       </Text>
@@ -32,6 +32,7 @@ export const zSchedulerField = z.enum([
   'ddpm',
   'dpmpp_2s',
   'dpmpp_2m',
+  'dpmpp_3m',
   'dpmpp_2m_sde',
   'dpmpp_sde',
   'heun',
|
|||||||
'pndm',
|
'pndm',
|
||||||
'unipc',
|
'unipc',
|
||||||
'euler_k',
|
'euler_k',
|
||||||
|
'deis_k',
|
||||||
'dpmpp_2s_k',
|
'dpmpp_2s_k',
|
||||||
'dpmpp_2m_k',
|
'dpmpp_2m_k',
|
||||||
|
'dpmpp_3m_k',
|
||||||
'dpmpp_2m_sde_k',
|
'dpmpp_2m_sde_k',
|
||||||
'dpmpp_sde_k',
|
'dpmpp_sde_k',
|
||||||
'heun_k',
|
'heun_k',
|
||||||
|
'kdpm_2_k',
|
||||||
|
'kdpm_2_a_k',
|
||||||
'lms_k',
|
'lms_k',
|
||||||
|
'unipc_k',
|
||||||
'euler_a',
|
'euler_a',
|
||||||
'kdpm_2_a',
|
'kdpm_2_a',
|
||||||
'lcm',
|
'lcm',
|
||||||
|
@@ -52,28 +52,34 @@ export const CLIP_SKIP_MAP = {
  * Mapping of schedulers to human readable name
  */
 export const SCHEDULER_OPTIONS: ComboboxOption[] = [
-  { value: 'euler', label: 'Euler' },
-  { value: 'deis', label: 'DEIS' },
   { value: 'ddim', label: 'DDIM' },
   { value: 'ddpm', label: 'DDPM' },
-  { value: 'dpmpp_sde', label: 'DPM++ SDE' },
+  { value: 'deis', label: 'DEIS' },
+  { value: 'deis_k', label: 'DEIS Karras' },
   { value: 'dpmpp_2s', label: 'DPM++ 2S' },
-  { value: 'dpmpp_2m', label: 'DPM++ 2M' },
-  { value: 'dpmpp_2m_sde', label: 'DPM++ 2M SDE' },
-  { value: 'heun', label: 'Heun' },
-  { value: 'kdpm_2', label: 'KDPM 2' },
-  { value: 'lms', label: 'LMS' },
-  { value: 'pndm', label: 'PNDM' },
-  { value: 'unipc', label: 'UniPC' },
-  { value: 'euler_k', label: 'Euler Karras' },
-  { value: 'dpmpp_sde_k', label: 'DPM++ SDE Karras' },
   { value: 'dpmpp_2s_k', label: 'DPM++ 2S Karras' },
+  { value: 'dpmpp_2m', label: 'DPM++ 2M' },
   { value: 'dpmpp_2m_k', label: 'DPM++ 2M Karras' },
+  { value: 'dpmpp_2m_sde', label: 'DPM++ 2M SDE' },
   { value: 'dpmpp_2m_sde_k', label: 'DPM++ 2M SDE Karras' },
-  { value: 'heun_k', label: 'Heun Karras' },
-  { value: 'lms_k', label: 'LMS Karras' },
+  { value: 'dpmpp_3m', label: 'DPM++ 3M' },
+  { value: 'dpmpp_3m_k', label: 'DPM++ 3M Karras' },
+  { value: 'dpmpp_sde', label: 'DPM++ SDE' },
+  { value: 'dpmpp_sde_k', label: 'DPM++ SDE Karras' },
+  { value: 'euler', label: 'Euler' },
+  { value: 'euler_k', label: 'Euler Karras' },
   { value: 'euler_a', label: 'Euler Ancestral' },
+  { value: 'heun', label: 'Heun' },
+  { value: 'heun_k', label: 'Heun Karras' },
+  { value: 'kdpm_2', label: 'KDPM 2' },
+  { value: 'kdpm_2_k', label: 'KDPM 2 Karras' },
   { value: 'kdpm_2_a', label: 'KDPM 2 Ancestral' },
+  { value: 'kdpm_2_a_k', label: 'KDPM 2 Ancestral Karras' },
   { value: 'lcm', label: 'LCM' },
+  { value: 'lms', label: 'LMS' },
+  { value: 'lms_k', label: 'LMS Karras' },
+  { value: 'pndm', label: 'PNDM' },
   { value: 'tcd', label: 'TCD' },
-].sort((a, b) => a.label.localeCompare(b.label));
+  { value: 'unipc', label: 'UniPC' },
+  { value: 'unipc_k', label: 'UniPC Karras' },
+];
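Changes like the ones above must stay in sync across the backend SCHEDULER_NAME_VALUES Literal, the frontend zod enum, and the combobox option list. A hedged sketch of a parity check one could run; the Literal and option list below are abbreviated stand-ins, not the real data:

from typing import Literal, get_args

SCHEDULER_NAME_VALUES = Literal["ddim", "deis", "deis_k"]  # stand-in for the real Literal
frontend_options = [{"value": "ddim"}, {"value": "deis"}, {"value": "deis_k"}]

backend_names = set(get_args(SCHEDULER_NAME_VALUES))
frontend_names = {option["value"] for option in frontend_options}
assert backend_names == frontend_names, "frontend scheduler list drifted from backend"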
@@ -3553,7 +3553,7 @@ export type components = {
      * @default euler
      * @enum {string}
      */
-    scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
+    scheduler?: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
     /**
      * UNet
      * @description UNet (scheduler, LoRAs)
@@ -8553,7 +8553,7 @@ export type components = {
      * Scheduler
      * @description Default scheduler for this model
      */
-    scheduler?: ("ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd") | null;
+    scheduler?: ("ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd") | null;
     /**
     * Steps
     * @description Default number of steps for this model
@@ -11467,7 +11467,7 @@ export type components = {
      * @default euler
      * @enum {string}
      */
-    scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
+    scheduler?: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
     /**
      * type
      * @default scheduler
@@ -11483,7 +11483,7 @@ export type components = {
      * @description Scheduler to use during inference
      * @enum {string}
      */
-    scheduler: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
+    scheduler: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
     /**
      * type
      * @default scheduler_output
@@ -13261,7 +13261,7 @@ export type components = {
      * @default euler
      * @enum {string}
      */
-    scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
+    scheduler?: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
     /**
      * UNet
      * @description UNet (scheduler, LoRAs)