mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Merge branch 'main' into stalker-modular_inpaint-2
@@ -354,7 +354,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase):
     """Model config for CLIPVision."""

     type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
-    format: Literal[ModelFormat.Diffusers]
+    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

     @staticmethod
     def get_tag() -> Tag:
@@ -365,7 +365,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
     """Model config for T2I."""

     type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
-    format: Literal[ModelFormat.Diffusers]
+    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

     @staticmethod
     def get_tag() -> Tag:
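The two hunks above give the `format` discriminator field a default, so these configs can be instantiated without passing `format` explicitly while tag-based union dispatch keeps working. A minimal sketch of the pattern, assuming plain pydantic and hypothetical names rather than InvokeAI's actual classes:

from typing import Literal

from pydantic import BaseModel


class DiffusersConfigSketch(BaseModel):  # hypothetical stand-in, not InvokeAI code
    format: Literal["diffusers"] = "diffusers"  # Literal discriminator with a default


cfg = DiffusersConfigSketch()  # format no longer needs to be supplied by the caller
assert cfg.format == "diffusers"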
@@ -98,6 +98,9 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
                 ModelVariantType.Normal: StableDiffusionXLPipeline,
                 ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
             },
+            BaseModelType.StableDiffusionXLRefiner: {
+                ModelVariantType.Normal: StableDiffusionXLPipeline,
+            },
         }
         assert isinstance(config, MainCheckpointConfig)
         try:
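The nested dict above maps (base model, variant) to a diffusers pipeline class, with the SDXL Refiner entry being the addition here. A hedged sketch of the lookup the surrounding loader presumably performs (the `load_classes` name and error handling are assumptions based on the visible context lines):

try:
    load_class = load_classes[config.base][config.variant]
except KeyError:
    raise ValueError(f"No pipeline class known for base={config.base}, variant={config.variant}")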
@@ -187,164 +187,171 @@ STARTER_MODELS: list[StarterModel] = [
     # endregion
     # region ControlNet
     StarterModel(
-        name="QRCode Monster",
+        name="QRCode Monster v2 (SD1.5)",
         base=BaseModelType.StableDiffusion1,
-        source="monster-labs/control_v1p_sd15_qrcode_monster",
-        description="Controlnet model that generates scannable creative QR codes",
+        source="monster-labs/control_v1p_sd15_qrcode_monster::v2",
+        description="ControlNet model that generates scannable creative QR codes",
         type=ModelType.ControlNet,
     ),
+    StarterModel(
+        name="QRCode Monster (SDXL)",
+        base=BaseModelType.StableDiffusionXL,
+        source="monster-labs/control_v1p_sdxl_qrcode_monster",
+        description="ControlNet model that generates scannable creative QR codes",
+        type=ModelType.ControlNet,
+    ),
     StarterModel(
         name="canny",
         base=BaseModelType.StableDiffusion1,
         source="lllyasviel/control_v11p_sd15_canny",
-        description="Controlnet weights trained on sd-1.5 with canny conditioning.",
+        description="ControlNet weights trained on sd-1.5 with canny conditioning.",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="inpaint",
         base=BaseModelType.StableDiffusion1,
         source="lllyasviel/control_v11p_sd15_inpaint",
-        description="Controlnet weights trained on sd-1.5 with canny conditioning, inpaint version",
+        description="ControlNet weights trained on sd-1.5 with canny conditioning, inpaint version",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="mlsd",
         base=BaseModelType.StableDiffusion1,
         source="lllyasviel/control_v11p_sd15_mlsd",
-        description="Controlnet weights trained on sd-1.5 with canny conditioning, MLSD version",
+        description="ControlNet weights trained on sd-1.5 with canny conditioning, MLSD version",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="depth",
         base=BaseModelType.StableDiffusion1,
         source="lllyasviel/control_v11f1p_sd15_depth",
-        description="Controlnet weights trained on sd-1.5 with depth conditioning",
+        description="ControlNet weights trained on sd-1.5 with depth conditioning",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="normal_bae",
         base=BaseModelType.StableDiffusion1,
         source="lllyasviel/control_v11p_sd15_normalbae",
-        description="Controlnet weights trained on sd-1.5 with normalbae image conditioning",
+        description="ControlNet weights trained on sd-1.5 with normalbae image conditioning",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="seg",
         base=BaseModelType.StableDiffusion1,
         source="lllyasviel/control_v11p_sd15_seg",
-        description="Controlnet weights trained on sd-1.5 with seg image conditioning",
+        description="ControlNet weights trained on sd-1.5 with seg image conditioning",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="lineart",
         base=BaseModelType.StableDiffusion1,
         source="lllyasviel/control_v11p_sd15_lineart",
-        description="Controlnet weights trained on sd-1.5 with lineart image conditioning",
+        description="ControlNet weights trained on sd-1.5 with lineart image conditioning",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="lineart_anime",
         base=BaseModelType.StableDiffusion1,
         source="lllyasviel/control_v11p_sd15s2_lineart_anime",
-        description="Controlnet weights trained on sd-1.5 with anime image conditioning",
+        description="ControlNet weights trained on sd-1.5 with anime image conditioning",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="openpose",
         base=BaseModelType.StableDiffusion1,
         source="lllyasviel/control_v11p_sd15_openpose",
-        description="Controlnet weights trained on sd-1.5 with openpose image conditioning",
+        description="ControlNet weights trained on sd-1.5 with openpose image conditioning",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="scribble",
         base=BaseModelType.StableDiffusion1,
         source="lllyasviel/control_v11p_sd15_scribble",
-        description="Controlnet weights trained on sd-1.5 with scribble image conditioning",
+        description="ControlNet weights trained on sd-1.5 with scribble image conditioning",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="softedge",
         base=BaseModelType.StableDiffusion1,
         source="lllyasviel/control_v11p_sd15_softedge",
-        description="Controlnet weights trained on sd-1.5 with soft edge conditioning",
+        description="ControlNet weights trained on sd-1.5 with soft edge conditioning",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="shuffle",
         base=BaseModelType.StableDiffusion1,
         source="lllyasviel/control_v11e_sd15_shuffle",
-        description="Controlnet weights trained on sd-1.5 with shuffle image conditioning",
+        description="ControlNet weights trained on sd-1.5 with shuffle image conditioning",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="tile",
         base=BaseModelType.StableDiffusion1,
         source="lllyasviel/control_v11f1e_sd15_tile",
-        description="Controlnet weights trained on sd-1.5 with tiled image conditioning",
+        description="ControlNet weights trained on sd-1.5 with tiled image conditioning",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="ip2p",
         base=BaseModelType.StableDiffusion1,
         source="lllyasviel/control_v11e_sd15_ip2p",
-        description="Controlnet weights trained on sd-1.5 with ip2p conditioning.",
+        description="ControlNet weights trained on sd-1.5 with ip2p conditioning.",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="canny-sdxl",
         base=BaseModelType.StableDiffusionXL,
         source="xinsir/controlnet-canny-sdxl-1.0",
-        description="Controlnet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
+        description="ControlNet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="depth-sdxl",
         base=BaseModelType.StableDiffusionXL,
         source="diffusers/controlnet-depth-sdxl-1.0",
-        description="Controlnet weights trained on sdxl-1.0 with depth conditioning.",
+        description="ControlNet weights trained on sdxl-1.0 with depth conditioning.",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="softedge-dexined-sdxl",
         base=BaseModelType.StableDiffusionXL,
         source="SargeZT/controlnet-sd-xl-1.0-softedge-dexined",
-        description="Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
+        description="ControlNet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="depth-16bit-zoe-sdxl",
         base=BaseModelType.StableDiffusionXL,
         source="SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe",
-        description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
+        description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="depth-zoe-sdxl",
         base=BaseModelType.StableDiffusionXL,
         source="diffusers/controlnet-zoe-depth-sdxl-1.0",
-        description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
+        description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="openpose-sdxl",
         base=BaseModelType.StableDiffusionXL,
         source="xinsir/controlnet-openpose-sdxl-1.0",
-        description="Controlnet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
+        description="ControlNet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="scribble-sdxl",
         base=BaseModelType.StableDiffusionXL,
         source="xinsir/controlnet-scribble-sdxl-1.0",
-        description="Controlnet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
+        description="ControlNet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
         type=ModelType.ControlNet,
     ),
     StarterModel(
         name="tile-sdxl",
         base=BaseModelType.StableDiffusionXL,
         source="xinsir/controlnet-tile-sdxl-1.0",
-        description="Controlnet weights trained on sdxl-1.0 with tiled image conditioning",
+        description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
         type=ModelType.ControlNet,
     ),
     # endregion
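One substantive change in the region above, beyond the `Controlnet` -> `ControlNet` casing fixes: the QRCode Monster source gains a `::v2` suffix. This appears to be InvokeAI's `repo_id::subfolder` convention for installing from a subfolder of a Hugging Face repo (the v2 weights in monster-labs' repo live in a `v2/` folder). A small illustrative parse, not the installer's actual code:

source = "monster-labs/control_v1p_sd15_qrcode_monster::v2"
repo_id, _, subfolder = source.partition("::")
assert repo_id == "monster-labs/control_v1p_sd15_qrcode_monster"
assert subfolder == "v2"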
@@ -7,11 +7,9 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import (  # noqa: F401
     StableDiffusionGeneratorPipeline,
 )
 from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent  # noqa: F401
-from invokeai.backend.stable_diffusion.seamless import set_seamless  # noqa: F401

 __all__ = [
     "PipelineIntermediateState",
     "StableDiffusionGeneratorPipeline",
     "InvokeAIDiffuserComponent",
-    "set_seamless",
 ]
@@ -83,47 +83,47 @@ class DenoiseContext:
     unet: Optional[UNet2DConditionModel] = None

     # Current state of latent-space image in denoising process.
-    # None until `pre_denoise_loop` callback.
+    # None until `PRE_DENOISE_LOOP` callback.
     # Shape: [batch, channels, latent_height, latent_width]
     latents: Optional[torch.Tensor] = None

     # Current denoising step index.
-    # None until `pre_step` callback.
+    # None until `PRE_STEP` callback.
     step_index: Optional[int] = None

     # Current denoising step timestep.
-    # None until `pre_step` callback.
+    # None until `PRE_STEP` callback.
     timestep: Optional[torch.Tensor] = None

     # Arguments which will be passed to UNet model.
-    # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
+    # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
     unet_kwargs: Optional[UNetKwargs] = None

     # SchedulerOutput class returned from step function(normally, generated by scheduler).
-    # Supposed to be used only in `post_step` callback, otherwise can be None.
+    # Supposed to be used only in `POST_STEP` callback, otherwise can be None.
     step_output: Optional[SchedulerOutput] = None

     # Scaled version of `latents`, which will be passed to unet_kwargs initialization.
-    # Available in events inside step(between `pre_step` and `post_stop`).
+    # Available in events inside step(between `PRE_STEP` and `POST_STEP`).
     # Shape: [batch, channels, latent_height, latent_width]
     latent_model_input: Optional[torch.Tensor] = None

     # [TMP] Defines on which conditionings current unet call will be runned.
-    # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
+    # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
     conditioning_mode: Optional[ConditioningMode] = None

     # [TMP] Noise predictions from negative conditioning.
-    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
+    # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
     # Shape: [batch, channels, latent_height, latent_width]
     negative_noise_pred: Optional[torch.Tensor] = None

     # [TMP] Noise predictions from positive conditioning.
-    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
+    # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
     # Shape: [batch, channels, latent_height, latent_width]
     positive_noise_pred: Optional[torch.Tensor] = None

     # Combined noise prediction from passed conditionings.
-    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
+    # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
     # Shape: [batch, channels, latent_height, latent_width]
     noise_pred: Optional[torch.Tensor] = None
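The uppercase names in the rewritten comments refer to `ExtensionCallbackType` members (renamed later in this diff). A rough sketch of how an extension consumes these fields, assuming the `callback` decorator from `extensions.base` and a hypothetical extension class:

from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback


class StepLoggerExt(ExtensionBase):  # hypothetical, for illustration only
    @callback(ExtensionCallbackType.POST_STEP)
    def log_step(self, ctx):
        # By POST_STEP, step_index, latents and step_output are all populated.
        print(f"step {ctx.step_index}: latents {tuple(ctx.latents.shape)}")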
@@ -76,12 +76,12 @@ class StableDiffusionBackend:
         both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
         ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)

-        # ext: override apply_cfg
-        ctx.noise_pred = self.apply_cfg(ctx)
+        # ext: override combine_noise_preds
+        ctx.noise_pred = self.combine_noise_preds(ctx)

         # ext: cfg_rescale [modify_noise_prediction]
         # TODO: rename
-        ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
+        ext_manager.run_callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS, ctx)

         # compute the previous noisy sample x_t -> x_t-1
         step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
@@ -95,13 +95,15 @@ class StableDiffusionBackend:
         return step_output

     @staticmethod
-    def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
+    def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
         guidance_scale = ctx.inputs.conditioning_data.guidance_scale
         if isinstance(guidance_scale, list):
            guidance_scale = guidance_scale[ctx.step_index]

-        return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
-        # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
+        # Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
+        # in slightly different outputs. It is suspected that this is caused by small precision differences.
+        # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
+        return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)

     def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
         sample = ctx.latent_model_input
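The two return lines swapped above are algebraically identical, since torch.lerp(a, b, w) = a + w * (b - a) and CFG simply extrapolates with w > 1; the commit keeps the explicit form because of the precision caveat in the comment. A quick numerical check of the equivalence:

import torch

neg = torch.randn(2, 4, 8, 8)
pos = torch.randn(2, 4, 8, 8)
w = 7.5  # a typical guidance scale

cfg = neg + w * (pos - neg)
# equal to torch.lerp up to floating-point rounding
assert torch.allclose(torch.lerp(neg, pos, w), cfg, atol=1e-5)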
@@ -9,4 +9,4 @@ class ExtensionCallbackType(Enum):
     POST_STEP = "post_step"
     PRE_UNET = "pre_unet"
     POST_UNET = "post_unet"
-    POST_APPLY_CFG = "post_apply_cfg"
+    POST_COMBINE_NOISE_PREDS = "post_combine_noise_preds"
@@ -2,7 +2,7 @@ from __future__ import annotations

 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional

 import torch
 from diffusers import UNet2DConditionModel
@@ -52,9 +52,9 @@ class ExtensionBase:
         return self._callbacks

     @contextmanager
-    def patch_extension(self, context: DenoiseContext):
+    def patch_extension(self, ctx: DenoiseContext):
         yield None

     @contextmanager
-    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
         yield None
invokeai/backend/stable_diffusion/extensions/controlnet.py (new file; 158 lines added)
@@ -0,0 +1,158 @@
from __future__ import annotations

import math
from contextlib import contextmanager
from typing import TYPE_CHECKING, List, Optional, Union

import torch
from PIL.Image import Image

from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.util.hotfixes import ControlNetModel


class ControlNetExt(ExtensionBase):
    def __init__(
        self,
        model: ControlNetModel,
        image: Image,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
        control_mode: CONTROLNET_MODE_VALUES,
        resize_mode: CONTROLNET_RESIZE_VALUES,
    ):
        super().__init__()
        self._model = model
        self._image = image
        self._weight = weight
        self._begin_step_percent = begin_step_percent
        self._end_step_percent = end_step_percent
        self._control_mode = control_mode
        self._resize_mode = resize_mode

        self._image_tensor: Optional[torch.Tensor] = None

    @contextmanager
    def patch_extension(self, ctx: DenoiseContext):
        original_processors = self._model.attn_processors
        try:
            self._model.set_attn_processor(ctx.inputs.attention_processor_cls())

            yield None
        finally:
            self._model.set_attn_processor(original_processors)

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def resize_image(self, ctx: DenoiseContext):
        _, _, latent_height, latent_width = ctx.latents.shape
        image_height = latent_height * LATENT_SCALE_FACTOR
        image_width = latent_width * LATENT_SCALE_FACTOR

        self._image_tensor = prepare_control_image(
            image=self._image,
            do_classifier_free_guidance=False,
            width=image_width,
            height=image_height,
            device=ctx.latents.device,
            dtype=ctx.latents.dtype,
            control_mode=self._control_mode,
            resize_mode=self._resize_mode,
        )

    @callback(ExtensionCallbackType.PRE_UNET)
    def pre_unet_step(self, ctx: DenoiseContext):
        # skip if model not active in current step
        total_steps = len(ctx.inputs.timesteps)
        first_step = math.floor(self._begin_step_percent * total_steps)
        last_step = math.ceil(self._end_step_percent * total_steps)
        if ctx.step_index < first_step or ctx.step_index > last_step:
            return

        # convert mode to internal flags
        soft_injection = self._control_mode in ["more_prompt", "more_control"]
        cfg_injection = self._control_mode in ["more_control", "unbalanced"]

        # no negative conditioning in cfg_injection mode
        if cfg_injection:
            if ctx.conditioning_mode == ConditioningMode.Negative:
                return
            down_samples, mid_sample = self._run(ctx, soft_injection, ConditioningMode.Positive)

            if ctx.conditioning_mode == ConditioningMode.Both:
                # add zeros as samples for negative conditioning
                down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
                mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

        else:
            down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode)

        if (
            ctx.unet_kwargs.down_block_additional_residuals is None
            and ctx.unet_kwargs.mid_block_additional_residual is None
        ):
            ctx.unet_kwargs.down_block_additional_residuals = down_samples
            ctx.unet_kwargs.mid_block_additional_residual = mid_sample
        else:
            # add controlnet outputs together if have multiple controlnets
            ctx.unet_kwargs.down_block_additional_residuals = [
                samples_prev + samples_curr
                for samples_prev, samples_curr in zip(
                    ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True
                )
            ]
            ctx.unet_kwargs.mid_block_additional_residual += mid_sample

    def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode):
        total_steps = len(ctx.inputs.timesteps)

        model_input = ctx.latent_model_input
        image_tensor = self._image_tensor
        if conditioning_mode == ConditioningMode.Both:
            model_input = torch.cat([model_input] * 2)
            image_tensor = torch.cat([image_tensor] * 2)

        cn_unet_kwargs = UNetKwargs(
            sample=model_input,
            timestep=ctx.timestep,
            encoder_hidden_states=None,  # set later by conditioning
            cross_attention_kwargs=dict(  # noqa: C408
                percent_through=ctx.step_index / total_steps,
            ),
        )

        ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)

        # get static weight, or weight corresponding to current step
        weight = self._weight
        if isinstance(weight, list):
            weight = weight[ctx.step_index]

        tmp_kwargs = vars(cn_unet_kwargs)

        # Remove kwargs not related to ControlNet unet
        # ControlNet guidance fields
        del tmp_kwargs["down_block_additional_residuals"]
        del tmp_kwargs["mid_block_additional_residual"]

        # T2i Adapter guidance fields
        del tmp_kwargs["down_intrablock_additional_residuals"]

        # controlnet(s) inference
        down_samples, mid_sample = self._model(
            controlnet_cond=image_tensor,
            conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
            guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
            return_dict=False,
            **vars(cn_unet_kwargs),
        )

        return down_samples, mid_sample
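A worked example of the step-range gating in `pre_unet_step()` above, with assumed values (30 steps, ControlNet active for the first half of denoising):

import math

total_steps = 30  # len(ctx.inputs.timesteps)
begin_step_percent, end_step_percent = 0.0, 0.5

first_step = math.floor(begin_step_percent * total_steps)  # 0
last_step = math.ceil(end_step_percent * total_steps)  # 15
# the ControlNet contributes only while first_step <= step_index <= last_step
active = [s for s in range(total_steps) if first_step <= s <= last_step]
assert active == list(range(16))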
invokeai/backend/stable_diffusion/extensions/freeu.py (new file; 35 lines added)
@@ -0,0 +1,35 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Optional

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase

if TYPE_CHECKING:
    from invokeai.app.shared.models import FreeUConfig


class FreeUExt(ExtensionBase):
    def __init__(
        self,
        freeu_config: FreeUConfig,
    ):
        super().__init__()
        self._freeu_config = freeu_config

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
        unet.enable_freeu(
            b1=self._freeu_config.b1,
            b2=self._freeu_config.b2,
            s1=self._freeu_config.s1,
            s2=self._freeu_config.s2,
        )

        try:
            yield
        finally:
            unet.disable_freeu()
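A hedged usage sketch for `FreeUExt`: `unet` and `freeu_config` are assumed to come from the surrounding invocation code; `enable_freeu`/`disable_freeu` are the standard diffusers `UNet2DConditionModel` methods:

ext = FreeUExt(freeu_config)
with ext.patch_unet(unet):
    ...  # denoising here runs with FreeU backbone/skip scaling enabled
# on exit, the finally block has called unet.disable_freeu()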
invokeai/backend/stable_diffusion/extensions/rescale_cfg.py (new file; 36 lines added)
@@ -0,0 +1,36 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


class RescaleCFGExt(ExtensionBase):
    def __init__(self, rescale_multiplier: float):
        super().__init__()
        self._rescale_multiplier = rescale_multiplier

    @staticmethod
    def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
        """Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
        ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
        ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)

        x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
        x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
        return x_final

    @callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
    def rescale_noise_pred(self, ctx: DenoiseContext):
        if self._rescale_multiplier > 0:
            ctx.noise_pred = self._rescale_cfg(
                ctx.noise_pred,
                ctx.positive_noise_pred,
                self._rescale_multiplier,
            )
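In math form, `_rescale_cfg()` implements the following (with \(\phi\) the rescale multiplier, and std taken per-sample over channel and spatial dims, matching `dim=(1, 2, 3)` above):

\[
x_{\text{rescaled}} = x_{\text{cfg}} \cdot \frac{\operatorname{std}(x_{\text{pos}})}{\operatorname{std}(x_{\text{cfg}})},
\qquad
x_{\text{final}} = \phi \, x_{\text{rescaled}} + (1 - \phi) \, x_{\text{cfg}}
\]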
invokeai/backend/stable_diffusion/extensions/seamless.py (new file; 71 lines added)
@@ -0,0 +1,71 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from diffusers.models.lora import LoRACompatibleConv

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase


class SeamlessExt(ExtensionBase):
    def __init__(
        self,
        seamless_axes: List[str],
    ):
        super().__init__()
        self._seamless_axes = seamless_axes

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
        with self.static_patch_model(
            model=unet,
            seamless_axes=self._seamless_axes,
        ):
            yield

    @staticmethod
    @contextmanager
    def static_patch_model(
        model: torch.nn.Module,
        seamless_axes: List[str],
    ):
        if not seamless_axes:
            yield
            return

        x_mode = "circular" if "x" in seamless_axes else "constant"
        y_mode = "circular" if "y" in seamless_axes else "constant"

        # override conv_forward
        # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
        def _conv_forward_asymmetric(
            self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
        ):
            self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
            self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
            working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
            working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
            return torch.nn.functional.conv2d(
                working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
            )

        original_layers: List[Tuple[nn.Conv2d, Callable]] = []
        try:
            for layer in model.modules():
                if not isinstance(layer, torch.nn.Conv2d):
                    continue

                if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
                    layer.lora_layer = lambda *x: 0
                original_layers.append((layer, layer._conv_forward))
                layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)

            yield

        finally:
            for layer, orig_conv_forward in original_layers:
                layer._conv_forward = orig_conv_forward
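The core trick in `_conv_forward_asymmetric()` is padding each axis independently, `circular` on seamless axes and `constant` otherwise, before running the convolution unpadded. A minimal demonstration of circular padding along the x axis:

import torch
import torch.nn.functional as F

row = torch.arange(4.0).reshape(1, 1, 1, 4)  # one row: [0, 1, 2, 3]
padded = F.pad(row, (1, 1, 0, 0), mode="circular")  # pad width by 1 on each side
assert padded.flatten().tolist() == [3.0, 0.0, 1.0, 2.0, 3.0, 0.0]  # edges wrap around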
invokeai/backend/stable_diffusion/extensions/t2i_adapter.py (new file; 120 lines added)
@@ -0,0 +1,120 @@
from __future__ import annotations

import math
from typing import TYPE_CHECKING, List, Optional, Union

import torch
from diffusers import T2IAdapter
from PIL.Image import Image

from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
    from invokeai.app.invocations.model import ModelIdentifierField
    from invokeai.app.services.shared.invocation_context import InvocationContext
    from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


class T2IAdapterExt(ExtensionBase):
    def __init__(
        self,
        node_context: InvocationContext,
        model_id: ModelIdentifierField,
        image: Image,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
        resize_mode: CONTROLNET_RESIZE_VALUES,
    ):
        super().__init__()
        self._node_context = node_context
        self._model_id = model_id
        self._image = image
        self._weight = weight
        self._resize_mode = resize_mode
        self._begin_step_percent = begin_step_percent
        self._end_step_percent = end_step_percent

        self._adapter_state: Optional[List[torch.Tensor]] = None

        # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
        model_config = self._node_context.models.get_config(self._model_id.key)
        if model_config.base == BaseModelType.StableDiffusion1:
            self._max_unet_downscale = 8
        elif model_config.base == BaseModelType.StableDiffusionXL:
            self._max_unet_downscale = 4
        else:
            raise ValueError(f"Unexpected T2I-Adapter base model type: '{model_config.base}'.")

    @callback(ExtensionCallbackType.SETUP)
    def setup(self, ctx: DenoiseContext):
        t2i_model: T2IAdapter
        with self._node_context.models.load(self._model_id) as t2i_model:
            _, _, latents_height, latents_width = ctx.inputs.orig_latents.shape

            self._adapter_state = self._run_model(
                model=t2i_model,
                image=self._image,
                latents_height=latents_height,
                latents_width=latents_width,
            )

    def _run_model(
        self,
        model: T2IAdapter,
        image: Image,
        latents_height: int,
        latents_width: int,
    ):
        # Resize the T2I-Adapter input image.
        # We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
        # result will match the latent image's dimensions after max_unet_downscale is applied.
        input_height = latents_height // self._max_unet_downscale * model.total_downscale_factor
        input_width = latents_width // self._max_unet_downscale * model.total_downscale_factor

        # Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
        # a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
        # T2I-Adapter model.
        #
        # Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
        # of the same requirements (e.g. preserving binary masks during resize).
        t2i_image = prepare_control_image(
            image=image,
            do_classifier_free_guidance=False,
            width=input_width,
            height=input_height,
            num_channels=model.config["in_channels"],
            device=model.device,
            dtype=model.dtype,
            resize_mode=self._resize_mode,
        )

        return model(t2i_image)

    @callback(ExtensionCallbackType.PRE_UNET)
    def pre_unet_step(self, ctx: DenoiseContext):
        # skip if model not active in current step
        total_steps = len(ctx.inputs.timesteps)
        first_step = math.floor(self._begin_step_percent * total_steps)
        last_step = math.ceil(self._end_step_percent * total_steps)
        if ctx.step_index < first_step or ctx.step_index > last_step:
            return

        weight = self._weight
        if isinstance(weight, list):
            weight = weight[ctx.step_index]

        adapter_state = self._adapter_state
        if ctx.conditioning_mode == ConditioningMode.Both:
            adapter_state = [torch.cat([v] * 2) for v in adapter_state]

        if ctx.unet_kwargs.down_intrablock_additional_residuals is None:
            ctx.unet_kwargs.down_intrablock_additional_residuals = [v * weight for v in adapter_state]
        else:
            for i, value in enumerate(adapter_state):
                ctx.unet_kwargs.down_intrablock_additional_residuals[i] += value * weight
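A worked example of the resize arithmetic in `_run_model()` above, under assumed values: SD1.5 (`max_unet_downscale = 8`), a 512x512 generation (so 64x64 latents), and a full T2I-Adapter whose `total_downscale_factor` is 64 (treat that figure as an assumption about the diffusers adapter config):

latents_height = 64  # 512px image / 8x VAE scale factor
max_unet_downscale = 8  # SD1.5
total_downscale_factor = 64  # assumed for a full SD1.5 T2I-Adapter

input_height = latents_height // max_unet_downscale * total_downscale_factor
assert input_height == 512
# The adapter downsamples 512 / 64 = 8, matching the UNet's deepest latent
# resolution of 64 / 8 = 8, so the residuals line up at every scale.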
@@ -52,20 +52,24 @@ class ExtensionsManager:
             cb.function(ctx)

     @contextmanager
-    def patch_extensions(self, context: DenoiseContext):
+    def patch_extensions(self, ctx: DenoiseContext):
         if self._is_canceled and self._is_canceled():
             raise CanceledException

         with ExitStack() as exit_stack:
             for ext in self._extensions:
-                exit_stack.enter_context(ext.patch_extension(context))
+                exit_stack.enter_context(ext.patch_extension(ctx))

             yield None

     @contextmanager
-    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
         if self._is_canceled and self._is_canceled():
             raise CanceledException

-        # TODO: create logic in PR with extension which uses it
-        yield None
+        # TODO: create weight patch logic in PR with extension which uses it
+        with ExitStack() as exit_stack:
+            for ext in self._extensions:
+                exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
+
+            yield None
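Both patch methods compose per-extension context managers with `contextlib.ExitStack`, which unwinds in reverse order on exit (including on error), so patches are removed LIFO. A self-contained illustration of that behavior:

from contextlib import ExitStack, contextmanager


@contextmanager
def patch(name):
    print("enter", name)
    try:
        yield
    finally:
        print("exit", name)


with ExitStack() as stack:
    for name in ["seamless", "freeu"]:
        stack.enter_context(patch(name))
    print("denoising")
# prints: enter seamless, enter freeu, denoising, exit freeu, exit seamless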
invokeai/backend/stable_diffusion/seamless.py (deleted file; all 51 lines removed)
@@ -1,51 +0,0 @@
from contextlib import contextmanager
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.lora import LoRACompatibleConv
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel


@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
    if not seamless_axes:
        yield
        return

    # override conv_forward
    # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
    def _conv_forward_asymmetric(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
        self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
        self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
        working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
        working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
        return torch.nn.functional.conv2d(
            working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
        )

    original_layers: List[Tuple[nn.Conv2d, Callable]] = []

    try:
        x_mode = "circular" if "x" in seamless_axes else "constant"
        y_mode = "circular" if "y" in seamless_axes else "constant"

        conv_layers: List[torch.nn.Conv2d] = []

        for module in model.modules():
            if isinstance(module, torch.nn.Conv2d):
                conv_layers.append(module)

        for layer in conv_layers:
            if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
                layer.lora_layer = lambda *x: 0
            original_layers.append((layer, layer._conv_forward))
            layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)

        yield

    finally:
        for layer, orig_conv_forward in original_layers:
            layer._conv_forward = orig_conv_forward