invoke-ai/InvokeAI (mirror of https://github.com/invoke-ai/InvokeAI)

Commit 9cc852cf7f (parent 712cf00a82)
Base code from draft PR
@@ -1,5 +1,6 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 import inspect
+import os
 from contextlib import ExitStack
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 

@@ -39,6 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
     ControlNetData,
     StableDiffusionGeneratorPipeline,

@@ -53,6 +55,9 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     TextConditioningData,
     TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
+from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
+from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
 from invokeai.backend.util.devices import TorchDevice
@@ -314,9 +319,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         context: InvocationContext,
         positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
         negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
-        unet: UNet2DConditionModel,
         latent_height: int,
         latent_width: int,
+        device: torch.device,
+        dtype: torch.dtype,
         cfg_scale: float | list[float],
         steps: int,
         cfg_rescale_multiplier: float,

@@ -330,10 +336,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
             uncond_list = [uncond_list]
 
         cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
-            cond_list, context, unet.device, unet.dtype
+            cond_list, context, device, dtype
         )
         uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
-            uncond_list, context, unet.device, unet.dtype
+            uncond_list, context, device, dtype
         )
 
         cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(

@@ -341,14 +347,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
             masks=cond_text_embedding_masks,
             latent_height=latent_height,
             latent_width=latent_width,
-            dtype=unet.dtype,
+            dtype=dtype,
         )
         uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
             text_conditionings=uncond_text_embeddings,
             masks=uncond_text_embedding_masks,
             latent_height=latent_height,
             latent_width=latent_width,
-            dtype=unet.dtype,
+            dtype=dtype,
         )
 
         if isinstance(cfg_scale, list):
@@ -707,9 +713,99 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         return seed, noise, latents
 
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        if os.environ.get("USE_MODULAR_DENOISE", False):
+            return self._new_invoke(context)
+        else:
+            return self._old_invoke(context)
+
     @torch.no_grad()
     @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
+    def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
+        with ExitStack() as exit_stack:
+            ext_manager = ExtensionsManager()
+
+            device = TorchDevice.choose_torch_device()
+            dtype = TorchDevice.choose_torch_dtype()
+
+            seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
+            latents = latents.to(device=device, dtype=dtype)
+            if noise is not None:
+                noise = noise.to(device=device, dtype=dtype)
+
+            _, _, latent_height, latent_width = latents.shape
+
+            conditioning_data = self.get_conditioning_data(
+                context=context,
+                positive_conditioning_field=self.positive_conditioning,
+                negative_conditioning_field=self.negative_conditioning,
+                cfg_scale=self.cfg_scale,
+                steps=self.steps,
+                latent_height=latent_height,
+                latent_width=latent_width,
+                device=device,
+                dtype=dtype,
+                # TODO: old backend, remove
+                cfg_rescale_multiplier=self.cfg_rescale_multiplier,
+            )
+
+            scheduler = get_scheduler(
+                context=context,
+                scheduler_info=self.unet.scheduler,
+                scheduler_name=self.scheduler,
+                seed=seed,
+            )
+
+            timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
+                scheduler,
+                seed=seed,
+                device=device,
+                steps=self.steps,
+                denoising_start=self.denoising_start,
+                denoising_end=self.denoising_end,
+            )
+
+            denoise_ctx = DenoiseContext(
+                latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
+                seed=seed,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                unet=None,
+                scheduler=scheduler,
+            )
+
+            # get the unet's config so that we can pass the base to sd_step_callback()
+            unet_config = context.models.get_config(self.unet.unet.key)
+
+            # ext: t2i/ip adapter
+            ext_manager.modifiers.pre_unet_load(denoise_ctx, ext_manager)
+
+            unet_info = context.models.load(self.unet.unet)
+            assert isinstance(unet_info.model, UNet2DConditionModel)
+            with (
+                unet_info.model_on_device() as (model_state_dict, unet),
+                # ext: controlnet
+                ext_manager.patch_attention_processor(unet, CustomAttnProcessor2_0),
+                # ext: freeu, seamless, ip adapter, lora
+                ext_manager.patch_unet(model_state_dict, unet),
+            ):
+                sd_backend = StableDiffusionBackend(unet, scheduler)
+                denoise_ctx.unet = unet
+                result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        result_latents = result_latents.to("cpu")  # TODO: detach?
+        TorchDevice.empty_cache()
+
+        name = context.tensors.save(tensor=result_latents)
+        return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
+
+    @torch.no_grad()
+    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
+    def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
         seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
 
         mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -788,7 +884,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 context=context,
                 positive_conditioning_field=self.positive_conditioning,
                 negative_conditioning_field=self.negative_conditioning,
-                unet=unet,
+                device=unet.device,
+                dtype=unet.dtype,
                 latent_height=latent_height,
                 latent_width=latent_width,
                 cfg_scale=self.cfg_scale,
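Note on the dispatch above: the reworked invoke() routes to the new modular path only when the USE_MODULAR_DENOISE environment variable is set; otherwise the legacy _old_invoke() path runs unchanged. As a minimal sketch (only the variable name is taken from the diff; the rest is illustrative), opting in from a launcher script could look like:

import os

# Opt in to the modular denoise backend before any invocations run.
# os.environ.get("USE_MODULAR_DENOISE", False) returns the raw string,
# so any non-empty value such as "1" is truthy.
os.environ["USE_MODULAR_DENOISE"] = "1"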
invokeai/backend/stable_diffusion/denoise_context.py (new file, 60 lines)

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData


@dataclass
class UNetKwargs:
    sample: torch.Tensor
    timestep: Union[torch.Tensor, float, int]
    encoder_hidden_states: torch.Tensor

    class_labels: Optional[torch.Tensor] = None
    timestep_cond: Optional[torch.Tensor] = None
    attention_mask: Optional[torch.Tensor] = None
    cross_attention_kwargs: Optional[Dict[str, Any]] = None
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
    mid_block_additional_residual: Optional[torch.Tensor] = None
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
    encoder_attention_mask: Optional[torch.Tensor] = None
    # return_dict: bool = True


@dataclass
class DenoiseContext:
    latents: torch.Tensor
    scheduler_step_kwargs: dict[str, Any]
    conditioning_data: TextConditioningData
    noise: Optional[torch.Tensor]
    seed: int
    timesteps: torch.Tensor
    init_timestep: torch.Tensor

    scheduler: SchedulerMixin
    unet: Optional[UNet2DConditionModel] = None

    orig_latents: Optional[torch.Tensor] = None
    step_index: Optional[int] = None
    timestep: Optional[torch.Tensor] = None
    unet_kwargs: Optional[UNetKwargs] = None
    step_output: Optional[SchedulerOutput] = None

    latent_model_input: Optional[torch.Tensor] = None
    conditioning_mode: Optional[str] = None
    negative_noise_pred: Optional[torch.Tensor] = None
    positive_noise_pred: Optional[torch.Tensor] = None
    noise_pred: Optional[torch.Tensor] = None

    extra: dict = field(default_factory=dict)

    def __delattr__(self, name: str):
        setattr(self, name, None)
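UNetKwargs mirrors the keyword arguments of the diffusers UNet2DConditionModel forward call, so the backend can assemble a call incrementally (extensions fill in residuals, added_cond_kwargs, and so on) and finally unpack it with vars(ctx.unet_kwargs). A rough sketch of that pattern, using a stand-in callable rather than a real UNet (fake_unet and the tensor shapes below are illustrative, not part of the diff):

import torch

from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs


# Hypothetical stand-in for the UNet forward pass; it accepts the same keyword
# names that UNetKwargs defines and ignores the fields left as None.
def fake_unet(sample, timestep, encoder_hidden_states, **extra):
    return sample  # a real UNet returns an object whose .sample holds the prediction


kwargs = UNetKwargs(
    sample=torch.zeros(2, 4, 64, 64),
    timestep=10,
    encoder_hidden_states=torch.zeros(2, 77, 768),
)
kwargs.cross_attention_kwargs = {"percent_through": 0.5}  # extensions may add more fields

out = fake_unet(**vars(kwargs))  # same unpacking the backend uses: vars(ctx.unet_kwargs)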
invokeai/backend/stable_diffusion/diffusion/conditioning_data.py

@@ -5,6 +5,7 @@ from typing import List, Optional, Union
 import torch
 
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
 
 
 @dataclass

@@ -103,7 +104,7 @@ class TextConditioningData:
         uncond_regions: Optional[TextConditioningRegions],
         cond_regions: Optional[TextConditioningRegions],
         guidance_scale: Union[float, List[float]],
-        guidance_rescale_multiplier: float = 0,
+        guidance_rescale_multiplier: float = 0,  # TODO: old backend, remove
     ):
         self.uncond_text = uncond_text
         self.cond_text = cond_text

@@ -114,6 +115,7 @@ class TextConditioningData:
         # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
         # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
         self.guidance_scale = guidance_scale
+        # TODO: old backend, remove
         # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
         # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
         self.guidance_rescale_multiplier = guidance_rescale_multiplier

@@ -121,3 +123,127 @@ class TextConditioningData:
     def is_sdxl(self):
         assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
         return isinstance(self.cond_text, SDXLConditioningInfo)
+
+    def to_unet_kwargs(self, unet_kwargs, conditioning_mode):
+        if conditioning_mode == "both":
+            encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
+                self.uncond_text.embeds, self.cond_text.embeds
+            )
+        elif conditioning_mode == "positive":
+            encoder_hidden_states = self.cond_text.embeds
+            encoder_attention_mask = None
+        else:  # elif conditioning_mode == "negative":
+            encoder_hidden_states = self.uncond_text.embeds
+            encoder_attention_mask = None
+
+        unet_kwargs.encoder_hidden_states = encoder_hidden_states
+        unet_kwargs.encoder_attention_mask = encoder_attention_mask
+
+        if self.is_sdxl():
+            if conditioning_mode == "negative":
+                added_cond_kwargs = dict(  # noqa: C408
+                    text_embeds=self.cond_text.pooled_embeds,
+                    time_ids=self.cond_text.add_time_ids,
+                )
+            elif conditioning_mode == "positive":
+                added_cond_kwargs = dict(  # noqa: C408
+                    text_embeds=self.uncond_text.pooled_embeds,
+                    time_ids=self.uncond_text.add_time_ids,
+                )
+            else:  # elif conditioning_mode == "both":
+                added_cond_kwargs = dict(  # noqa: C408
+                    text_embeds=torch.cat(
+                        [
+                            # TODO: how to pad? just by zeros? or even truncate?
+                            self.uncond_text.pooled_embeds,
+                            self.cond_text.pooled_embeds,
+                        ],
+                    ),
+                    time_ids=torch.cat(
+                        [
+                            self.uncond_text.add_time_ids,
+                            self.cond_text.add_time_ids,
+                        ],
+                    ),
+                )
+
+            unet_kwargs.added_cond_kwargs = added_cond_kwargs
+
+        if self.cond_regions is not None or self.uncond_regions is not None:
+            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
+            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
+            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
+            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
+            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
+
+            _tmp_regions = self.cond_regions if self.cond_regions is not None else self.uncond_regions
+            _, _, h, w = _tmp_regions.masks.shape
+            dtype = self.cond_text.embeds.dtype
+            device = self.cond_text.embeds.device
+
+            regions = []
+            for c, r in [
+                (self.uncond_text, self.uncond_regions),
+                (self.cond_text, self.cond_regions),
+            ]:
+                if r is None:
+                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
+                    r = TextConditioningRegions(
+                        masks=torch.ones((1, 1, h, w), dtype=dtype),
+                        ranges=[Range(start=0, end=c.embeds.shape[1])],
+                    )
+                regions.append(r)
+
+            if unet_kwargs.cross_attention_kwargs is None:
+                unet_kwargs.cross_attention_kwargs = {}
+
+            unet_kwargs.cross_attention_kwargs.update(
+                regional_prompt_data=RegionalPromptData(regions=regions, device=device, dtype=dtype),
+            )
+
+    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
+        def _pad_conditioning(cond, target_len, encoder_attention_mask):
+            conditioning_attention_mask = torch.ones(
+                (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
+            )
+
+            if cond.shape[1] < max_len:
+                conditioning_attention_mask = torch.cat(
+                    [
+                        conditioning_attention_mask,
+                        torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
+                    ],
+                    dim=1,
+                )
+
+                cond = torch.cat(
+                    [
+                        cond,
+                        torch.zeros(
+                            (cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
+                            device=cond.device,
+                            dtype=cond.dtype,
+                        ),
+                    ],
+                    dim=1,
+                )
+
+            if encoder_attention_mask is None:
+                encoder_attention_mask = conditioning_attention_mask
+            else:
+                encoder_attention_mask = torch.cat(
+                    [
+                        encoder_attention_mask,
+                        conditioning_attention_mask,
+                    ]
+                )
+
+            return cond, encoder_attention_mask
+
+        encoder_attention_mask = None
+        if unconditioning.shape[1] != conditioning.shape[1]:
+            max_len = max(unconditioning.shape[1], conditioning.shape[1])
+            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
+            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
+
+        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
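When both passes run in a single batch ("both" mode), _concat_conditionings_for_batch zero-pads the shorter prompt embedding to the longer one's sequence length and builds an attention mask so the padding is ignored. A small illustration of the shapes involved, using dummy tensors rather than InvokeAI objects (all names and sizes below are assumptions for the sketch):

import torch

# Dummy prompt embeddings with different sequence lengths (e.g. 77 vs 154 tokens).
uncond = torch.randn(1, 77, 768)
cond = torch.randn(1, 154, 768)

max_len = max(uncond.shape[1], cond.shape[1])
pad = torch.zeros(1, max_len - uncond.shape[1], 768)
uncond_padded = torch.cat([uncond, pad], dim=1)  # (1, 154, 768)

# Mask is 1 over real tokens and 0 over padding, mirroring the diff's logic.
mask = torch.cat([torch.ones(1, 77), torch.zeros(1, max_len - 77)], dim=1)

batch = torch.cat([uncond_padded, cond])  # (2, 154, 768), unconditioned half first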
invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py

@@ -1,9 +1,14 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import torch
 import torch.nn.functional as F
 
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    TextConditioningRegions,
-)
+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+        TextConditioningRegions,
+    )
 
 
 class RegionalPromptData:
invokeai/backend/stable_diffusion/diffusion_backend.py (new file, 220 lines)

from __future__ import annotations

import PIL.Image
import torch
import torchvision.transforms as T
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from tqdm.auto import tqdm

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


def trim_to_multiple_of(*args, multiple_of=8):
    return tuple((x - x % multiple_of) for x in args)


def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = True, multiple_of=8) -> torch.FloatTensor:
    """
    :param image: input image
    :param normalize: scale the range to [-1, 1] instead of [0, 1]
    :param multiple_of: resize the input so both dimensions are a multiple of this
    """
    w, h = trim_to_multiple_of(*image.size, multiple_of=multiple_of)
    transformation = T.Compose(
        [
            T.Resize((h, w), T.InterpolationMode.LANCZOS, antialias=True),
            T.ToTensor(),
        ]
    )
    tensor = transformation(image)
    if normalize:
        tensor = tensor * 2.0 - 1.0
    return tensor


class StableDiffusionBackend:
    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: SchedulerMixin,
    ):
        self.unet = unet
        self.scheduler = scheduler
        config = get_config()
        self.sequential_guidance = config.sequential_guidance

    def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        if ctx.init_timestep.shape[0] == 0:
            return ctx.latents

        ctx.orig_latents = ctx.latents.clone()

        if ctx.noise is not None:
            batch_size = ctx.latents.shape[0]
            # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
            ctx.latents = ctx.scheduler.add_noise(ctx.latents, ctx.noise, ctx.init_timestep.expand(batch_size))

        # if no work to do, return latents
        if ctx.timesteps.shape[0] == 0:
            return ctx.latents

        # ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
        # ext: preview[pre_denoise_loop, priority=low]
        ext_manager.modifiers.pre_denoise_loop(ctx)

        for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)):  # noqa: B020
            # ext: inpaint (apply mask to latents on non-inpaint models)
            ext_manager.modifiers.pre_step(ctx)

            # ext: tiles? [override: step]
            ctx.step_output = ext_manager.overrides.step(self.step, ctx, ext_manager)

            # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
            # ext: preview[post_step, priority=low]
            ext_manager.modifiers.post_step(ctx)

            ctx.latents = ctx.step_output.prev_sample

        # ext: inpaint[post_denoise_loop] (restore unmasked part)
        ext_manager.modifiers.post_denoise_loop(ctx)
        return ctx.latents

    @torch.inference_mode()
    def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput:
        ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep)

        if self.sequential_guidance:
            conditioning_call = self._apply_standard_conditioning_sequentially
        else:
            conditioning_call = self._apply_standard_conditioning

        # not sure if here needed override
        ctx.negative_noise_pred, ctx.positive_noise_pred = conditioning_call(ctx, ext_manager)

        # ext: override combine_noise
        ctx.noise_pred = ext_manager.overrides.combine_noise(self.combine_noise, ctx)

        # ext: cfg_rescale [modify_noise_prediction]
        ext_manager.modifiers.modify_noise_prediction(ctx)

        # compute the previous noisy sample x_t -> x_t-1
        step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs)

        # del locals
        del ctx.latent_model_input
        del ctx.negative_noise_pred
        del ctx.positive_noise_pred
        del ctx.noise_pred

        return step_output

    @staticmethod
    def combine_noise(ctx: DenoiseContext) -> torch.Tensor:
        guidance_scale = ctx.conditioning_data.guidance_scale
        if isinstance(guidance_scale, list):
            guidance_scale = guidance_scale[ctx.step_index]

        return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
        # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)

    def _apply_standard_conditioning(
        self, ctx: DenoiseContext, ext_manager: ExtensionsManager
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
        the cost of higher memory usage.
        """

        ctx.unet_kwargs = UNetKwargs(
            sample=torch.cat([ctx.latent_model_input] * 2),
            timestep=ctx.timestep,
            encoder_hidden_states=None,  # set later by conditioning
            cross_attention_kwargs=dict(  # noqa: C408
                percent_through=ctx.step_index / len(ctx.timesteps),  # ctx.total_steps,
            ),
        )

        ctx.conditioning_mode = "both"
        ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)

        # ext: controlnet/ip/t2i [pre_unet_forward]
        ext_manager.modifiers.pre_unet_forward(ctx)

        # ext: inpaint [pre_unet_forward, priority=low]
        # or
        # ext: inpaint [override: unet_forward]
        both_results = self._unet_forward(**vars(ctx.unet_kwargs))
        negative_next_x, positive_next_x = both_results.chunk(2)
        # del locals
        del ctx.unet_kwargs
        del ctx.conditioning_mode
        return negative_next_x, positive_next_x

    def _apply_standard_conditioning_sequentially(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost
        of slower execution speed.
        """

        ###################
        # Negative pass
        ###################

        ctx.unet_kwargs = UNetKwargs(
            sample=ctx.latent_model_input,
            timestep=ctx.timestep,
            encoder_hidden_states=None,  # set later by conditioning
            cross_attention_kwargs=dict(  # noqa: C408
                percent_through=ctx.step_index / len(ctx.timesteps),  # ctx.total_steps,
            ),
        )

        ctx.conditioning_mode = "negative"
        ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "negative")

        # ext: controlnet/ip/t2i [pre_unet_forward]
        ext_manager.modifiers.pre_unet_forward(ctx)

        # ext: inpaint [pre_unet_forward, priority=low]
        # or
        # ext: inpaint [override: unet_forward]
        negative_next_x = self._unet_forward(**vars(ctx.unet_kwargs))

        del ctx.unet_kwargs
        del ctx.conditioning_mode
        # TODO: gc.collect() ?

        ###################
        # Positive pass
        ###################

        ctx.unet_kwargs = UNetKwargs(
            sample=ctx.latent_model_input,
            timestep=ctx.timestep,
            encoder_hidden_states=None,  # set later by conditioning
            cross_attention_kwargs=dict(  # noqa: C408
                percent_through=ctx.step_index / len(ctx.timesteps),  # ctx.total_steps,
            ),
        )

        ctx.conditioning_mode = "positive"
        ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "positive")

        # ext: controlnet/ip/t2i [pre_unet_forward]
        ext_manager.modifiers.pre_unet_forward(ctx)

        # ext: inpaint [pre_unet_forward, priority=low]
        # or
        # ext: inpaint [override: unet_forward]
        positive_next_x = self._unet_forward(**vars(ctx.unet_kwargs))

        del ctx.unet_kwargs
        del ctx.conditioning_mode
        # TODO: gc.collect() ?

        return negative_next_x, positive_next_x

    def _unet_forward(self, **kwargs) -> torch.Tensor:
        return self.unet(**kwargs).sample
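combine_noise() above is classifier-free guidance: torch.lerp(neg, pos, g) equals neg + g * (pos - neg), which is exactly the formula spelled out in the commented-out line. A quick numeric check with scalar tensors (values chosen only for illustration):

import torch

neg = torch.tensor(1.0)  # unconditioned noise prediction
pos = torch.tensor(3.0)  # conditioned noise prediction
g = 7.5                  # guidance scale

lerp = torch.lerp(neg, pos, g)   # 1.0 + 7.5 * (3.0 - 1.0) = 16.0
manual = neg + g * (pos - neg)   # same value, written out explicitly
assert torch.isclose(lerp, manual)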
invokeai/backend/stable_diffusion/extensions/__init__.py (new file, 9 lines)

"""
Initialization file for the invokeai.backend.stable_diffusion.extensions package
"""

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase

__all__ = [
    "ExtensionBase",
]
invokeai/backend/stable_diffusion/extensions/base.py (new file, 58 lines)

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

import torch
from diffusers import UNet2DConditionModel


@dataclass
class InjectionInfo:
    type: str
    name: str
    order: Optional[str]
    function: Callable


def modifier(name: str, order: str = "any"):
    def _decorator(func):
        func.__inj_info__ = {
            "type": "modifier",
            "name": name,
            "order": order,
        }
        return func

    return _decorator


def override(name: str):
    def _decorator(func):
        func.__inj_info__ = {
            "type": "override",
            "name": name,
            "order": None,
        }
        return func

    return _decorator


class ExtensionBase:
    def __init__(self, priority: int):
        self.priority = priority
        self.injections: List[InjectionInfo] = []
        for func_name in dir(self):
            func = getattr(self, func_name)
            if not callable(func) or not hasattr(func, "__inj_info__"):
                continue

            self.injections.append(InjectionInfo(**func.__inj_info__, function=func))

    @contextmanager
    def patch_attention_processor(self, attention_processor_cls: object):
        yield None

    @contextmanager
    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
        yield None
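The modifier/override decorators above simply stamp an __inj_info__ dict onto the decorated method; ExtensionBase.__init__ then scans dir(self) and records every decorated method as an InjectionInfo. A hypothetical extension written against this base class might look like the sketch below (the PreviewExt name and its body are illustrative, not part of the diff):

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, modifier


class PreviewExt(ExtensionBase):  # hypothetical example extension
    def __init__(self):
        super().__init__(priority=100)

    # Registered under the "post_step" injection point with the default order "any".
    @modifier("post_step")
    def emit_preview(self, ctx):
        # A real extension would decode ctx.latents into a preview image here.
        print(f"step {ctx.step_index} done")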
invokeai/backend/stable_diffusion/extensions_manager.py (new file, 195 lines)

from __future__ import annotations

from abc import ABC, abstractmethod
from contextlib import ExitStack, contextmanager
from functools import partial
from typing import TYPE_CHECKING, Callable, Dict

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.util.devices import TorchDevice

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.stable_diffusion.extensions import ExtensionBase


class ExtModifiersApi(ABC):
    @abstractmethod
    def pre_denoise_loop(self, ctx: DenoiseContext):
        pass

    @abstractmethod
    def post_denoise_loop(self, ctx: DenoiseContext):
        pass

    @abstractmethod
    def pre_step(self, ctx: DenoiseContext):
        pass

    @abstractmethod
    def post_step(self, ctx: DenoiseContext):
        pass

    @abstractmethod
    def modify_noise_prediction(self, ctx: DenoiseContext):
        pass

    @abstractmethod
    def pre_unet_forward(self, ctx: DenoiseContext):
        pass

    @abstractmethod
    def pre_unet_load(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        pass


class ExtOverridesApi(ABC):
    @abstractmethod
    def step(self, orig_func: Callable, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        pass

    @abstractmethod
    def combine_noise(self, orig_func: Callable, ctx: DenoiseContext):
        pass


class ProxyCallsClass:
    def __init__(self, handler):
        self._handler = handler

    def __getattr__(self, item):
        return partial(self._handler, item)


class ModifierInjectionPoint:
    def __init__(self):
        self.first = []
        self.any = []
        self.last = []

    def add(self, func: Callable, order: str):
        if order == "first":
            self.first.append(func)
        elif order == "last":
            self.last.append(func)
        else:  # elif order == "any":
            self.any.append(func)

    def __call__(self, *args, **kwargs):
        for func in self.first:
            func(*args, **kwargs)
        for func in self.any:
            func(*args, **kwargs)
        for func in reversed(self.last):
            func(*args, **kwargs)


class ExtensionsManager:
    def __init__(self):
        self.extensions = []

        self._overrides = {}
        self._modifiers = {}

        self.modifiers: ExtModifiersApi = ProxyCallsClass(self.call_modifier)
        self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override)

    def add_extension(self, ext: ExtensionBase):
        self.extensions.append(ext)
        ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)

        self._overrides.clear()
        self._modifiers.clear()

        for ext in ordered_extensions:
            for inj_info in ext.injections:
                if inj_info.type == "modifier":
                    if inj_info.name not in self._modifiers:
                        self._modifiers[inj_info.name] = ModifierInjectionPoint()
                    self._modifiers[inj_info.name].add(inj_info.function, inj_info.order)

                else:
                    if inj_info.name in self._overrides:
                        raise Exception(f"Already overloaded - {inj_info.name}")
                    self._overrides[inj_info.name] = inj_info.function

    def call_modifier(self, name: str, *args, **kwargs):
        if name in self._modifiers:
            self._modifiers[name](*args, **kwargs)

    def call_override(self, name: str, orig_func: Callable, *args, **kwargs):
        if name in self._overrides:
            return self._overrides[name](orig_func, *args, **kwargs)
        else:
            return orig_func(*args, **kwargs)

    # TODO: is there any need for such a high abstraction
    # @contextmanager
    # def patch_extensions(self):
    #     exit_stack = ExitStack()
    #     try:
    #         for ext in self.extensions:
    #             exit_stack.enter_context(ext.patch_extension(self))
    #
    #         yield None
    #
    #     finally:
    #         exit_stack.close()

    @contextmanager
    def patch_attention_processor(self, unet: UNet2DConditionModel, attn_processor_cls: object):
        unet_orig_processors = unet.attn_processors
        exit_stack = ExitStack()
        try:
            # just to be sure that attention layers do not share the same processor instance
            attn_procs = {}
            for name in unet.attn_processors.keys():
                attn_procs[name] = attn_processor_cls()
            unet.set_attn_processor(attn_procs)

            for ext in self.extensions:
                exit_stack.enter_context(ext.patch_attention_processor(attn_processor_cls))

            yield None

        finally:
            unet.set_attn_processor(unet_orig_processors)
            exit_stack.close()

    @contextmanager
    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
        exit_stack = ExitStack()
        try:
            changed_keys = set()
            changed_unknown_keys = {}

            ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)
            for ext in ordered_extensions:
                patch_result = exit_stack.enter_context(ext.patch_unet(state_dict, unet))
                if patch_result is None:
                    continue
                new_keys, new_unk_keys = patch_result
                changed_keys.update(new_keys)
                # skip already seen keys, as new weight might be changed
                for k, v in new_unk_keys.items():
                    if k in changed_unknown_keys:
                        continue
                    changed_unknown_keys[k] = v

            yield None

        finally:
            exit_stack.close()
            assert hasattr(unet, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
            with torch.no_grad():
                for module_key in changed_keys:
                    weight = state_dict[module_key]
                    unet.get_submodule(module_key).weight.copy_(
                        weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
                    )
                for module_key, weight in changed_unknown_keys.items():
                    unet.get_submodule(module_key).weight.copy_(
                        weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
                    )
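Putting the pieces together: the manager collects the decorated methods of each registered extension and exposes them through the proxy attributes used throughout the backend (ext_manager.modifiers.pre_step(ctx), ext_manager.overrides.step(...)). A hedged usage sketch, reusing the hypothetical PreviewExt from the earlier example:

from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager

ext_manager = ExtensionsManager()
ext_manager.add_extension(PreviewExt())  # PreviewExt is the illustrative extension defined above

# ProxyCallsClass turns attribute access into call_modifier("post_step", ctx),
# which fans out to every extension registered at that injection point:
# ext_manager.modifiers.post_step(ctx)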